Skip to content

Commit fabb6b5

Browse files
committed
Fix return cast usage
Fix #720
1 parent 6e39fbe commit fabb6b5

File tree

3 files changed

+52
-14
lines changed

3 files changed

+52
-14
lines changed

crates/emmylua_code_analysis/src/compilation/test/flow.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1284,4 +1284,35 @@ end
12841284
let a = ws.expr_ty("A");
12851285
assert_eq!(ws.humanize_type(a), "Node");
12861286
}
1287+
1288+
#[test]
1289+
fn test_return_cast_multi_file() {
1290+
let mut ws = VirtualWorkspace::new();
1291+
ws.def_file(
1292+
"test.lua",
1293+
r#"
1294+
local M = {}
1295+
1296+
--- @return boolean
1297+
--- @return_cast _obj function
1298+
function M.is_callable(_obj) end
1299+
1300+
return M
1301+
"#,
1302+
);
1303+
ws.def(
1304+
r#"
1305+
local test = require("test")
1306+
1307+
local obj
1308+
1309+
if test.is_callable(obj) then
1310+
o = obj
1311+
end
1312+
"#,
1313+
);
1314+
let a = ws.expr_ty("o");
1315+
let expected = LuaType::Function;
1316+
assert_eq!(a, expected);
1317+
}
12871318
}

crates/emmylua_code_analysis/src/db_index/flow/mod.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,10 @@ impl LuaFlowIndex {
3434
self.file_flow_tree.get(file_id)
3535
}
3636

37-
pub fn get_signature_cast(
38-
&self,
39-
file_id: &FileId,
40-
signature_id: &LuaSignatureId,
41-
) -> Option<&LuaSignatureCast> {
42-
self.signature_cast_cache.get(file_id)?.get(signature_id)
37+
pub fn get_signature_cast(&self, signature_id: &LuaSignatureId) -> Option<&LuaSignatureCast> {
38+
self.signature_cast_cache
39+
.get(&signature_id.get_file_id())?
40+
.get(signature_id)
4341
}
4442

4543
pub fn add_signature_cast(

crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,7 @@ pub fn get_type_at_call_expr(
8888
_ => {}
8989
}
9090

91-
let Some(signature_cast) = db
92-
.get_flow_index()
93-
.get_signature_cast(&cache.get_file_id(), &signature_id)
94-
else {
91+
let Some(signature_cast) = db.get_flow_index().get_signature_cast(&signature_id) else {
9592
return Ok(ResultTypeOrContinue::Continue);
9693
};
9794

@@ -105,6 +102,7 @@ pub fn get_type_at_call_expr(
105102
flow_node,
106103
prefix_expr,
107104
signature_cast,
105+
signature_id,
108106
condition_flow,
109107
),
110108
name => get_type_at_call_expr_by_signature_param_name(
@@ -200,6 +198,7 @@ fn get_type_at_call_expr_by_signature_self(
200198
flow_node: &FlowNode,
201199
call_prefix: LuaExpr,
202200
signature_cast: &LuaSignatureCast,
201+
signature_id: LuaSignatureId,
203202
condition_flow: InferConditionFlow,
204203
) -> Result<ResultTypeOrContinue, InferFailReason> {
205204
let LuaExpr::IndexExpr(call_prefix_index) = call_prefix else {
@@ -221,13 +220,18 @@ fn get_type_at_call_expr_by_signature_self(
221220
let antecedent_flow_id = get_single_antecedent(tree, flow_node)?;
222221
let antecedent_type = get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?;
223222

224-
let Some(cast_op_type) = signature_cast.cast.to_node(root) else {
223+
let Some(syntax_tree) = db.get_vfs().get_syntax_tree(&signature_id.get_file_id()) else {
224+
return Ok(ResultTypeOrContinue::Continue);
225+
};
226+
227+
let signature_root = syntax_tree.get_chunk_node();
228+
let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else {
225229
return Ok(ResultTypeOrContinue::Continue);
226230
};
227231

228232
let result_type = cast_type(
229233
db,
230-
cache.get_file_id(),
234+
signature_id.get_file_id(),
231235
cast_op_type,
232236
antecedent_type,
233237
condition_flow,
@@ -291,13 +295,18 @@ fn get_type_at_call_expr_by_signature_param_name(
291295
let antecedent_flow_id = get_single_antecedent(tree, flow_node)?;
292296
let antecedent_type = get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?;
293297

294-
let Some(cast_op_type) = signature_cast.cast.to_node(root) else {
298+
let Some(syntax_tree) = db.get_vfs().get_syntax_tree(&signature_id.get_file_id()) else {
299+
return Ok(ResultTypeOrContinue::Continue);
300+
};
301+
302+
let signature_root = syntax_tree.get_chunk_node();
303+
let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else {
295304
return Ok(ResultTypeOrContinue::Continue);
296305
};
297306

298307
let result_type = cast_type(
299308
db,
300-
cache.get_file_id(),
309+
signature_id.get_file_id(),
301310
cast_op_type,
302311
antecedent_type,
303312
condition_flow,

0 commit comments

Comments
 (0)