Skip to content
Merged

fix #639

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ pub fn bind_call_expr_stat(

bind_assert_stat(binder, arg_list, current)
} else if call_expr.is_error() {
if let Some(ast) = LuaAst::cast(call_expr.syntax().clone()) {
bind_each_child(binder, ast, current);
}
let return_flow_id = binder.create_return();
binder.add_antecedent(return_flow_id, current);
return_flow_id
Expand Down
69 changes: 69 additions & 0 deletions crates/emmylua_code_analysis/src/compilation/test/flow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1158,4 +1158,73 @@ end
let a = ws.expr_ty("A");
assert_eq!(ws.humanize_type(a), "number");
}

#[test]
fn test_type_narrow() {
let mut ws = VirtualWorkspace::new();
ws.def(
r#"
---@generic T: table
---@param obj T | function
---@return T?
function bindGC(obj)
if type(obj) == 'table' then
A = obj
end
end
"#,
);
let a = ws.expr_ty("A");
assert_eq!(ws.humanize_type(a), "T");
}

#[test]
fn test_issue_630() {
let mut ws = VirtualWorkspace::new_with_init_std_lib();
ws.def(
r#"
---@class A
---@field Abc string?
A = {}
"#,
);
ws.def(
r#"
function A:test()
if not rawget(self, 'Abc') then
self.Abc = "a"
end

B = self.Abc
C = self
end
"#,
);
let a = ws.expr_ty("B");
assert_eq!(ws.humanize_type(a), "string");
let c = ws.expr_ty("C");
assert_eq!(ws.humanize_type(c), "A");
}

#[test]
fn test_error_function() {
let mut ws = VirtualWorkspace::new_with_init_std_lib();
assert!(ws.check_code_for(
DiagnosticCode::NeedCheckNil,
r#"
---@class Result
---@field value string?
Result = {}

function getValue()
---@type Result?
local result

if result then
error(result.value)
end
end
"#,
));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ use crate::{
narrow::{
condition_flow::InferConditionFlow, get_single_antecedent,
get_type_at_cast_flow::cast_type, get_type_at_flow::get_type_at_flow,
var_ref_id::get_var_expr_var_ref_id, ResultTypeOrContinue,
narrow_false_or_nil, remove_false_or_nil, var_ref_id::get_var_expr_var_ref_id,
ResultTypeOrContinue,
},
VarRefId,
},
DbIndex, FlowNode, FlowTree, InferFailReason, InferGuard, LuaFunctionType, LuaInferCache,
LuaSignatureCast, LuaSignatureId, LuaType, TypeOps,
DbIndex, FlowNode, FlowTree, InferFailReason, InferGuard, LuaAliasCallKind, LuaAliasCallType,
LuaFunctionType, LuaInferCache, LuaSignatureCast, LuaSignatureId, LuaType, TypeOps,
};

pub fn get_type_at_call_expr(
Expand Down Expand Up @@ -58,18 +59,34 @@ pub fn get_type_at_call_expr(
};

let ret = signature.get_return_type();
if let LuaType::TypeGuard(_) = ret {
return get_type_at_call_expr_by_type_guard(
db,
tree,
cache,
root,
var_ref_id,
flow_node,
call_expr,
signature.to_doc_func_type(),
condition_flow,
);
match ret {
LuaType::TypeGuard(_) => {
return get_type_at_call_expr_by_type_guard(
db,
tree,
cache,
root,
var_ref_id,
flow_node,
call_expr,
signature.to_doc_func_type(),
condition_flow,
);
}
LuaType::Call(call) => {
return get_type_at_call_expr_by_call(
db,
tree,
cache,
root,
var_ref_id,
flow_node,
call_expr,
&call,
condition_flow,
);
}
_ => {}
}

let Some(signature_cast) = db
Expand Down Expand Up @@ -288,3 +305,40 @@ fn get_type_at_call_expr_by_signature_param_name(
)?;
Ok(ResultTypeOrContinue::Result(result_type))
}

#[allow(unused)]
fn get_type_at_call_expr_by_call(
db: &DbIndex,
tree: &FlowTree,
cache: &mut LuaInferCache,
root: &LuaChunk,
var_ref_id: &VarRefId,
flow_node: &FlowNode,
call_expr: LuaCallExpr,
alias_call_type: &Arc<LuaAliasCallType>,
condition_flow: InferConditionFlow,
) -> Result<ResultTypeOrContinue, InferFailReason> {
let Some(maybe_ref_id) =
get_var_expr_var_ref_id(db, cache, LuaExpr::CallExpr(call_expr.clone()))
else {
return Ok(ResultTypeOrContinue::Continue);
};

if maybe_ref_id != *var_ref_id {
return Ok(ResultTypeOrContinue::Continue);
}

match alias_call_type.get_call_kind() {
LuaAliasCallKind::RawGet => {
let antecedent_type = infer_expr(db, cache, LuaExpr::CallExpr(call_expr))?;
let result_type = match condition_flow {
InferConditionFlow::FalseCondition => narrow_false_or_nil(db, antecedent_type),
InferConditionFlow::TrueCondition => remove_false_or_nil(antecedent_type),
};
return Ok(ResultTypeOrContinue::Result(result_type));
}
_ => {}
};

Ok(ResultTypeOrContinue::Continue)
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ pub fn narrow_down_type(db: &DbIndex, source: LuaType, target: LuaType) -> Optio
LuaType::Table | LuaType::Userdata | LuaType::Any | LuaType::Unknown => {
return Some(LuaType::Table);
}
// TODO: 应该根据模板约束进行精确匹配
LuaType::TplRef(_) => return Some(source),
LuaType::Global
| LuaType::Array(_)
| LuaType::Tuple(_)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use std::ops::Deref;

use emmylua_parser::{LuaAstNode, LuaExpr};
use emmylua_parser::{LuaAstNode, LuaCallExpr, LuaExpr, LuaLiteralToken, PathTrait};
use internment::ArcIntern;
use rowan::TextSize;
use smol_str::SmolStr;

use crate::{
infer_expr,
semantic::infer::{
infer_index::get_index_expr_var_ref_id, infer_name::get_name_expr_var_ref_id,
},
DbIndex, LuaDeclId, LuaDeclOrMemberId, LuaInferCache, LuaMemberId,
DbIndex, LuaAliasCallKind, LuaDeclId, LuaDeclOrMemberId, LuaInferCache, LuaMemberId, LuaType,
};

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -62,6 +63,72 @@ impl VarRefId {
}
}

fn get_call_expr_var_ref_id(
db: &DbIndex,
cache: &mut LuaInferCache,
call_expr: &LuaCallExpr,
) -> Option<VarRefId> {
let Some(prefix_expr) = call_expr.get_prefix_expr() else {
return None;
};
let maybe_func = infer_expr(db, cache, prefix_expr.clone()).ok()?;

let ret = match maybe_func {
LuaType::DocFunction(f) => f.get_ret().clone(),
LuaType::Signature(signature_id) => db
.get_signature_index()
.get(&signature_id)?
.get_return_type(),
_ => return None,
};
let LuaType::Call(alias_call_type) = ret else {
return None;
};

match alias_call_type.get_call_kind() {
LuaAliasCallKind::RawGet => {
let args_list = call_expr.get_args_list()?;
let mut args_iter = args_list.get_args();

let obj_expr = args_iter.next()?;
let decl_or_member_id = match get_var_expr_var_ref_id(db, cache, obj_expr.clone()) {
Some(VarRefId::SelfRef(decl_or_id)) => decl_or_id,
Some(VarRefId::VarRef(decl_id)) => LuaDeclOrMemberId::Decl(decl_id),
_ => return None,
};
// 开始构建 access_path
let mut access_path = String::new();
access_path.push_str(obj_expr.syntax().text().to_string().as_str()); // 这里不需要精确的文本
access_path.push_str(".");
let key_expr = args_iter.next()?;
match key_expr {
LuaExpr::LiteralExpr(literal_expr) => match literal_expr.get_literal()? {
LuaLiteralToken::String(string_token) => {
access_path.push_str(string_token.get_value().as_str());
}
LuaLiteralToken::Number(number_token) => {
access_path.push_str(number_token.get_int_value().to_string().as_str());
}
_ => return None,
},
LuaExpr::NameExpr(name_expr) => {
access_path.push_str(name_expr.get_access_path()?.as_str());
}
LuaExpr::IndexExpr(index_expr) => {
access_path.push_str(index_expr.get_access_path()?.as_str());
}
_ => return None,
}

Some(VarRefId::IndexRef(
decl_or_member_id,
ArcIntern::new(SmolStr::new(access_path)),
))
}
_ => return None,
}
}

pub fn get_var_expr_var_ref_id(
db: &DbIndex,
cache: &mut LuaInferCache,
Expand All @@ -74,6 +141,7 @@ pub fn get_var_expr_var_ref_id(
let ref_id = match &var_expr {
LuaExpr::NameExpr(name_expr) => get_name_expr_var_ref_id(db, cache, name_expr),
LuaExpr::IndexExpr(index_expr) => get_index_expr_var_ref_id(db, cache, index_expr),
LuaExpr::CallExpr(call_expr) => get_call_expr_var_ref_id(db, cache, call_expr),
_ => None,
}?;

Expand Down
Loading