diff --git a/crates/ty_python_semantic/src/dunder_all.rs b/crates/ty_python_semantic/src/dunder_all.rs index 39dabcb098fc0d..caf71b0a4dc6c0 100644 --- a/crates/ty_python_semantic/src/dunder_all.rs +++ b/crates/ty_python_semantic/src/dunder_all.rs @@ -6,9 +6,7 @@ use ruff_python_ast::name::Name; use ruff_python_ast::statement_visitor::{StatementVisitor, walk_stmt}; use ruff_python_ast::{self as ast}; -use crate::semantic_index::ast_ids::HasScopedExpressionId; -use crate::semantic_index::place::ScopeId; -use crate::semantic_index::{SemanticIndex, global_scope, semantic_index}; +use crate::semantic_index::{SemanticIndex, semantic_index}; use crate::types::{Truthiness, Type, infer_expression_types}; use crate::{Db, ModuleName, resolve_module}; @@ -44,11 +42,6 @@ struct DunderAllNamesCollector<'db> { db: &'db dyn Db, file: File, - /// The scope in which the `__all__` names are being collected from. - /// - /// This is always going to be the global scope of the module. - scope: ScopeId<'db>, - /// The semantic index for the module. index: &'db SemanticIndex<'db>, @@ -68,7 +61,6 @@ impl<'db> DunderAllNamesCollector<'db> { Self { db, file, - scope: global_scope(db, file), index, origin: None, invalid: false, @@ -190,8 +182,7 @@ impl<'db> DunderAllNamesCollector<'db> { /// /// This function panics if `expr` was not marked as a standalone expression during semantic indexing. fn standalone_expression_type(&self, expr: &ast::Expr) -> Type<'db> { - infer_expression_types(self.db, self.index.expression(expr)) - .expression_type(expr.scoped_expression_id(self.db, self.scope)) + infer_expression_types(self.db, self.index.expression(expr)).expression_type(expr) } /// Evaluate the given expression and return its truthiness. diff --git a/crates/ty_python_semantic/src/semantic_index/ast_ids.rs b/crates/ty_python_semantic/src/semantic_index/ast_ids.rs index 829c62c877a244..6d12d2291612e2 100644 --- a/crates/ty_python_semantic/src/semantic_index/ast_ids.rs +++ b/crates/ty_python_semantic/src/semantic_index/ast_ids.rs @@ -26,20 +26,11 @@ use crate::semantic_index::semantic_index; /// ``` #[derive(Debug, salsa::Update, get_size2::GetSize)] pub(crate) struct AstIds { - /// Maps expressions to their expression id. - expressions_map: FxHashMap, /// Maps expressions which "use" a place (that is, [`ast::ExprName`], [`ast::ExprAttribute`] or [`ast::ExprSubscript`]) to a use id. uses_map: FxHashMap, } impl AstIds { - fn expression_id(&self, key: impl Into) -> ScopedExpressionId { - let key = &key.into(); - *self.expressions_map.get(key).unwrap_or_else(|| { - panic!("Could not find expression ID for {key:?}"); - }) - } - fn use_id(&self, key: impl Into) -> ScopedUseId { self.uses_map[&key.into()] } @@ -94,90 +85,12 @@ impl HasScopedUseId for ast::ExprRef<'_> { } } -/// Uniquely identifies an [`ast::Expr`] in a [`crate::semantic_index::place::FileScopeId`]. -#[newtype_index] -#[derive(salsa::Update, get_size2::GetSize)] -pub struct ScopedExpressionId; - -pub trait HasScopedExpressionId { - /// Returns the ID that uniquely identifies the node in `scope`. - fn scoped_expression_id(&self, db: &dyn Db, scope: ScopeId) -> ScopedExpressionId; -} - -impl HasScopedExpressionId for Box { - fn scoped_expression_id(&self, db: &dyn Db, scope: ScopeId) -> ScopedExpressionId { - self.as_ref().scoped_expression_id(db, scope) - } -} - -macro_rules! impl_has_scoped_expression_id { - ($ty: ty) => { - impl HasScopedExpressionId for $ty { - fn scoped_expression_id(&self, db: &dyn Db, scope: ScopeId) -> ScopedExpressionId { - let expression_ref = ExprRef::from(self); - expression_ref.scoped_expression_id(db, scope) - } - } - }; -} - -impl_has_scoped_expression_id!(ast::ExprBoolOp); -impl_has_scoped_expression_id!(ast::ExprName); -impl_has_scoped_expression_id!(ast::ExprBinOp); -impl_has_scoped_expression_id!(ast::ExprUnaryOp); -impl_has_scoped_expression_id!(ast::ExprLambda); -impl_has_scoped_expression_id!(ast::ExprIf); -impl_has_scoped_expression_id!(ast::ExprDict); -impl_has_scoped_expression_id!(ast::ExprSet); -impl_has_scoped_expression_id!(ast::ExprListComp); -impl_has_scoped_expression_id!(ast::ExprSetComp); -impl_has_scoped_expression_id!(ast::ExprDictComp); -impl_has_scoped_expression_id!(ast::ExprGenerator); -impl_has_scoped_expression_id!(ast::ExprAwait); -impl_has_scoped_expression_id!(ast::ExprYield); -impl_has_scoped_expression_id!(ast::ExprYieldFrom); -impl_has_scoped_expression_id!(ast::ExprCompare); -impl_has_scoped_expression_id!(ast::ExprCall); -impl_has_scoped_expression_id!(ast::ExprFString); -impl_has_scoped_expression_id!(ast::ExprStringLiteral); -impl_has_scoped_expression_id!(ast::ExprBytesLiteral); -impl_has_scoped_expression_id!(ast::ExprNumberLiteral); -impl_has_scoped_expression_id!(ast::ExprBooleanLiteral); -impl_has_scoped_expression_id!(ast::ExprNoneLiteral); -impl_has_scoped_expression_id!(ast::ExprEllipsisLiteral); -impl_has_scoped_expression_id!(ast::ExprAttribute); -impl_has_scoped_expression_id!(ast::ExprSubscript); -impl_has_scoped_expression_id!(ast::ExprStarred); -impl_has_scoped_expression_id!(ast::ExprNamed); -impl_has_scoped_expression_id!(ast::ExprList); -impl_has_scoped_expression_id!(ast::ExprTuple); -impl_has_scoped_expression_id!(ast::ExprSlice); -impl_has_scoped_expression_id!(ast::ExprIpyEscapeCommand); -impl_has_scoped_expression_id!(ast::Expr); - -impl HasScopedExpressionId for ast::ExprRef<'_> { - fn scoped_expression_id(&self, db: &dyn Db, scope: ScopeId) -> ScopedExpressionId { - let ast_ids = ast_ids(db, scope); - ast_ids.expression_id(*self) - } -} - #[derive(Debug, Default)] pub(super) struct AstIdsBuilder { - expressions_map: FxHashMap, uses_map: FxHashMap, } impl AstIdsBuilder { - /// Adds `expr` to the expression ids map and returns its id. - pub(super) fn record_expression(&mut self, expr: &ast::Expr) -> ScopedExpressionId { - let expression_id = self.expressions_map.len().into(); - - self.expressions_map.insert(expr.into(), expression_id); - - expression_id - } - /// Adds `expr` to the use ids map and returns its id. pub(super) fn record_use(&mut self, expr: impl Into) -> ScopedUseId { let use_id = self.uses_map.len().into(); @@ -188,11 +101,9 @@ impl AstIdsBuilder { } pub(super) fn finish(mut self) -> AstIds { - self.expressions_map.shrink_to_fit(); self.uses_map.shrink_to_fit(); AstIds { - expressions_map: self.expressions_map, uses_map: self.uses_map, } } @@ -219,6 +130,12 @@ pub(crate) mod node_key { } } + impl From<&ast::ExprCall> for ExpressionNodeKey { + fn from(value: &ast::ExprCall) -> Self { + Self(NodeKey::from_node(value)) + } + } + impl From<&ast::Identifier> for ExpressionNodeKey { fn from(value: &ast::Identifier) -> Self { Self(NodeKey::from_node(value)) diff --git a/crates/ty_python_semantic/src/semantic_index/builder.rs b/crates/ty_python_semantic/src/semantic_index/builder.rs index 1c291e45117de0..73bfee7aa2913d 100644 --- a/crates/ty_python_semantic/src/semantic_index/builder.rs +++ b/crates/ty_python_semantic/src/semantic_index/builder.rs @@ -1902,7 +1902,6 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> { self.scopes_by_expression .insert(expr.into(), self.current_scope()); - self.current_ast_ids().record_expression(expr); let node_key = NodeKey::from_node(expr); diff --git a/crates/ty_python_semantic/src/semantic_model.rs b/crates/ty_python_semantic/src/semantic_model.rs index 62de4b60cf840b..cff4a3fe762e22 100644 --- a/crates/ty_python_semantic/src/semantic_model.rs +++ b/crates/ty_python_semantic/src/semantic_model.rs @@ -7,7 +7,6 @@ use ruff_source_file::LineIndex; use crate::Db; use crate::module_name::ModuleName; use crate::module_resolver::{KnownModule, Module, resolve_module}; -use crate::semantic_index::ast_ids::HasScopedExpressionId; use crate::semantic_index::place::FileScopeId; use crate::semantic_index::semantic_index; use crate::types::ide_support::all_declarations_and_bindings; @@ -159,8 +158,7 @@ impl HasType for ast::ExprRef<'_> { let file_scope = index.expression_scope_id(*self); let scope = file_scope.to_scope_id(model.db, model.file); - let expression_id = self.scoped_expression_id(model.db, scope); - infer_scope_types(model.db, scope).expression_type(expression_id) + infer_scope_types(model.db, scope).expression_type(*self) } } diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 54b76f562abf20..624b64fdcf7f9d 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -32,7 +32,6 @@ pub(crate) use self::subclass_of::{SubclassOfInner, SubclassOfType}; use crate::module_name::ModuleName; use crate::module_resolver::{KnownModule, resolve_module}; use crate::place::{Boundness, Place, PlaceAndQualifiers, imported_symbol}; -use crate::semantic_index::ast_ids::HasScopedExpressionId; use crate::semantic_index::definition::Definition; use crate::semantic_index::place::{ScopeId, ScopedPlaceId}; use crate::semantic_index::{imported_modules, place_table, semantic_index}; @@ -141,18 +140,17 @@ fn definition_expression_type<'db>( let index = semantic_index(db, file); let file_scope = index.expression_scope_id(expression); let scope = file_scope.to_scope_id(db, file); - let expr_id = expression.scoped_expression_id(db, scope); if scope == definition.scope(db) { // expression is in the definition scope let inference = infer_definition_types(db, definition); - if let Some(ty) = inference.try_expression_type(expr_id) { + if let Some(ty) = inference.try_expression_type(expression) { ty } else { - infer_deferred_types(db, definition).expression_type(expr_id) + infer_deferred_types(db, definition).expression_type(expression) } } else { // expression is in a type-params sub-scope - infer_scope_types(db, scope).expression_type(expr_id) + infer_scope_types(db, scope).expression_type(expression) } } diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index a6fdebda5ebb14..e457269f2785b8 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -32,7 +32,6 @@ use crate::{ known_module_symbol, place_from_bindings, place_from_declarations, }, semantic_index::{ - ast_ids::HasScopedExpressionId, attribute_assignments, definition::{DefinitionKind, TargetKind}, place::ScopeId, @@ -1867,10 +1866,8 @@ impl<'db> ClassLiteral<'db> { // [.., self.name, ..] = let unpacked = infer_unpack_types(db, unpack); - let target_ast_id = assign - .target(&module) - .scoped_expression_id(db, method_scope); - let inferred_ty = unpacked.expression_type(target_ast_id); + + let inferred_ty = unpacked.expression_type(assign.target(&module)); union_of_inferred_types = union_of_inferred_types.add(inferred_ty); } @@ -1896,10 +1893,8 @@ impl<'db> ClassLiteral<'db> { // for .., self.name, .. in : let unpacked = infer_unpack_types(db, unpack); - let target_ast_id = for_stmt - .target(&module) - .scoped_expression_id(db, method_scope); - let inferred_ty = unpacked.expression_type(target_ast_id); + let inferred_ty = + unpacked.expression_type(for_stmt.target(&module)); union_of_inferred_types = union_of_inferred_types.add(inferred_ty); } @@ -1927,10 +1922,8 @@ impl<'db> ClassLiteral<'db> { // with as .., self.name, ..: let unpacked = infer_unpack_types(db, unpack); - let target_ast_id = with_item - .target(&module) - .scoped_expression_id(db, method_scope); - let inferred_ty = unpacked.expression_type(target_ast_id); + let inferred_ty = + unpacked.expression_type(with_item.target(&module)); union_of_inferred_types = union_of_inferred_types.add(inferred_ty); } @@ -1957,10 +1950,9 @@ impl<'db> ClassLiteral<'db> { // [... for .., self.name, .. in ] let unpacked = infer_unpack_types(db, unpack); - let target_ast_id = comprehension - .target(&module) - .scoped_expression_id(db, unpack.target_scope(db)); - let inferred_ty = unpacked.expression_type(target_ast_id); + + let inferred_ty = + unpacked.expression_type(comprehension.target(&module)); union_of_inferred_types = union_of_inferred_types.add(inferred_ty); } diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 35d421e6da8c12..df7b96595e039d 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -45,6 +45,21 @@ use rustc_hash::{FxHashMap, FxHashSet}; use salsa; use salsa::plumbing::AsId; +use super::context::{InNoTypeCheck, InferContext}; +use super::diagnostic::{ + INVALID_METACLASS, INVALID_OVERLOAD, INVALID_PROTOCOL, SUBCLASS_OF_FINAL_CLASS, + hint_if_stdlib_submodule_exists_on_other_versions, report_attempted_protocol_instantiation, + report_duplicate_bases, report_index_out_of_bounds, report_invalid_exception_caught, + report_invalid_exception_cause, report_invalid_exception_raised, + report_invalid_or_unsupported_base, report_invalid_type_checking_constant, + report_non_subscriptable, report_possibly_unresolved_reference, report_slice_step_size_zero, +}; +use super::generics::LegacyGenericBase; +use super::string_annotation::{ + BYTE_STRING_TYPE_ANNOTATION, FSTRING_TYPE_ANNOTATION, parse_string_annotation, +}; +use super::subclass_of::SubclassOfInner; +use super::{ClassBase, NominalInstanceType, add_inferred_python_version_hint_to_diagnostic}; use crate::module_name::{ModuleName, ModuleNameResolutionError}; use crate::module_resolver::resolve_module; use crate::node_key::NodeKey; @@ -54,9 +69,8 @@ use crate::place::{ module_type_implicit_global_declaration, module_type_implicit_global_symbol, place, place_from_bindings, place_from_declarations, typing_extensions_symbol, }; -use crate::semantic_index::ast_ids::{ - HasScopedExpressionId, HasScopedUseId, ScopedExpressionId, ScopedUseId, -}; +use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey; +use crate::semantic_index::ast_ids::{HasScopedUseId, ScopedUseId}; use crate::semantic_index::definition::{ AnnotatedAssignmentDefinitionKind, AssignmentDefinitionKind, ComprehensionDefinitionKind, Definition, DefinitionKind, DefinitionNodeKey, DefinitionState, ExceptHandlerDefinitionKind, @@ -110,22 +124,6 @@ use crate::util::diagnostics::format_enumeration; use crate::util::subscript::{PyIndex, PySlice}; use crate::{Db, FxOrderSet, Program}; -use super::context::{InNoTypeCheck, InferContext}; -use super::diagnostic::{ - INVALID_METACLASS, INVALID_OVERLOAD, INVALID_PROTOCOL, SUBCLASS_OF_FINAL_CLASS, - hint_if_stdlib_submodule_exists_on_other_versions, report_attempted_protocol_instantiation, - report_duplicate_bases, report_index_out_of_bounds, report_invalid_exception_caught, - report_invalid_exception_cause, report_invalid_exception_raised, - report_invalid_or_unsupported_base, report_invalid_type_checking_constant, - report_non_subscriptable, report_possibly_unresolved_reference, report_slice_step_size_zero, -}; -use super::generics::LegacyGenericBase; -use super::string_annotation::{ - BYTE_STRING_TYPE_ANNOTATION, FSTRING_TYPE_ANNOTATION, parse_string_annotation, -}; -use super::subclass_of::SubclassOfInner; -use super::{ClassBase, NominalInstanceType, add_inferred_python_version_hint_to_diagnostic}; - /// Infer all types for a [`ScopeId`], including all definitions and expressions in that scope. /// Use when checking a scope, or needing to provide a type for an arbitrary expression in the /// scope. @@ -281,12 +279,7 @@ pub(super) fn infer_same_file_expression_type<'db>( parsed: &ParsedModuleRef, ) -> Type<'db> { let inference = infer_expression_types(db, expression); - let scope = expression.scope(db); - inference.expression_type( - expression - .node_ref(db, parsed) - .scoped_expression_id(db, scope), - ) + inference.expression_type(expression.node_ref(db, parsed)) } /// Infers the type of an expression where the expression might come from another file. @@ -337,7 +330,7 @@ pub(super) fn infer_unpack_types<'db>(db: &'db dyn Db, unpack: Unpack<'db>) -> U let _span = tracing::trace_span!("infer_unpack_types", range=?unpack.range(db, &module), ?file) .entered(); - let mut unpacker = Unpacker::new(db, unpack.target_scope(db), unpack.value_scope(db), &module); + let mut unpacker = Unpacker::new(db, unpack.target_scope(db), &module); unpacker.unpack(unpack.target(db, &module), unpack.value(db)); unpacker.finish() } @@ -417,7 +410,7 @@ struct TypeAndRange<'db> { #[derive(Debug, Eq, PartialEq, salsa::Update, get_size2::GetSize)] pub(crate) struct TypeInference<'db> { /// The types of every expression in this region. - expressions: FxHashMap>, + expressions: FxHashMap>, /// The types of every binding in this region. bindings: FxHashMap, Type<'db>>, @@ -466,7 +459,7 @@ impl<'db> TypeInference<'db> { } #[track_caller] - pub(crate) fn expression_type(&self, expression: ScopedExpressionId) -> Type<'db> { + pub(crate) fn expression_type(&self, expression: impl Into) -> Type<'db> { self.try_expression_type(expression).expect( "Failed to retrieve the inferred type for an `ast::Expr` node \ passed to `TypeInference::expression_type()`. The `TypeInferenceBuilder` \ @@ -475,9 +468,12 @@ impl<'db> TypeInference<'db> { ) } - pub(crate) fn try_expression_type(&self, expression: ScopedExpressionId) -> Option> { + pub(crate) fn try_expression_type( + &self, + expression: impl Into, + ) -> Option> { self.expressions - .get(&expression) + .get(&expression.into()) .copied() .or(self.cycle_fallback_type) } @@ -738,13 +734,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { /// this node. #[track_caller] fn expression_type(&self, expr: &ast::Expr) -> Type<'db> { - self.types - .expression_type(expr.scoped_expression_id(self.db(), self.scope())) + self.types.expression_type(expr) } fn try_expression_type(&self, expr: &ast::Expr) -> Option> { - self.types - .try_expression_type(expr.scoped_expression_id(self.db(), self.scope())) + self.types.try_expression_type(expr) } /// Get the type of an expression from any scope in the same file. @@ -762,12 +756,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { fn file_expression_type(&self, expression: &ast::Expr) -> Type<'db> { let file_scope = self.index.expression_scope_id(expression); let expr_scope = file_scope.to_scope_id(self.db(), self.file()); - let expr_id = expression.scoped_expression_id(self.db(), expr_scope); match self.region { InferenceRegion::Scope(scope) if scope == expr_scope => { self.expression_type(expression) } - _ => infer_scope_types(self.db(), expr_scope).expression_type(expr_id), + _ => infer_scope_types(self.db(), expr_scope).expression_type(expression), } } @@ -1954,13 +1947,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { function: &'a ast::StmtFunctionDef, ) -> impl Iterator> + 'a { let definition = self.index.expect_single_definition(function); - let scope = definition.scope(self.db()); + let definition_types = infer_definition_types(self.db(), definition); - function.decorator_list.iter().map(move |decorator| { - definition_types - .expression_type(decorator.expression.scoped_expression_id(self.db(), scope)) - }) + function + .decorator_list + .iter() + .map(move |decorator| definition_types.expression_type(&decorator.expression)) } /// Returns `true` if the current scope is the function body scope of a function overload (that @@ -2759,11 +2752,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { match with_item.target_kind() { TargetKind::Sequence(unpack_position, unpack) => { let unpacked = infer_unpack_types(self.db(), unpack); - let target_ast_id = target.scoped_expression_id(self.db(), self.scope()); if unpack_position == UnpackPosition::First { self.context.extend(unpacked.diagnostics()); } - unpacked.expression_type(target_ast_id) + unpacked.expression_type(target) } TargetKind::Single => { let context_expr_ty = self.infer_standalone_expression(context_expr); @@ -3757,8 +3749,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { self.context.extend(unpacked.diagnostics()); } - let target_ast_id = target.scoped_expression_id(self.db(), self.scope()); - unpacked.expression_type(target_ast_id) + unpacked.expression_type(target) } TargetKind::Single => { let value_ty = self.infer_standalone_expression(value); @@ -3816,10 +3807,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // But here we explicitly overwrite the type for the overall `self.attr` node with // the annotated type. We do no use `store_expression_type` here, because it checks // that no type has been stored for the expression before. - let expr_id = target.scoped_expression_id(self.db(), self.scope()); self.types .expressions - .insert(expr_id, annotated.inner_type()); + .insert((&**target).into(), annotated.inner_type()); } } @@ -4077,8 +4067,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { if unpack_position == UnpackPosition::First { self.context.extend(unpacked.diagnostics()); } - let target_ast_id = target.scoped_expression_id(self.db(), self.scope()); - unpacked.expression_type(target_ast_id) + + unpacked.expression_type(target) } TargetKind::Single => { let iterable_type = self.infer_standalone_expression(iterable); @@ -4628,7 +4618,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // the result from `types` directly because we might be in cycle recovery where // `types.cycle_fallback_type` is `Some(fallback_ty)`, which we can retrieve by // using `expression_type` on `types`: - types.expression_type(expression.scoped_expression_id(self.db(), self.scope())) + types.expression_type(expression) } fn infer_expression_impl(&mut self, expression: &ast::Expr) -> Type<'db> { @@ -4680,15 +4670,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ty } - fn store_expression_type(&mut self, expression: &impl HasScopedExpressionId, ty: Type<'db>) { + fn store_expression_type(&mut self, expression: &ast::Expr, ty: Type<'db>) { if self.deferred_state.in_string_annotation() { // Avoid storing the type of expressions that are part of a string annotation because // the expression ids don't exists in the semantic index. Instead, we'll store the type // on the string expression itself that represents the annotation. return; } - let expr_id = expression.scoped_expression_id(self.db(), self.scope()); - let previous = self.types.expressions.insert(expr_id, ty); + let previous = self.types.expressions.insert(expression.into(), ty); assert_eq!(previous, None); } @@ -5093,20 +5082,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // because `ScopedExpressionId`s are only meaningful within their own scope, so // we'd add types for random wrong expressions in the current scope if comprehension.is_first() && target.is_name_expr() { - let lookup_scope = self - .index - .parent_scope_id(self.scope().file_scope_id(self.db())) - .expect("A comprehension should never be the top-level scope") - .to_scope_id(self.db(), self.file()); - result.expression_type(iterable.scoped_expression_id(self.db(), lookup_scope)) + result.expression_type(iterable) } else { let scope = self.types.scope; self.types.scope = result.scope; self.extend(result); self.types.scope = scope; - result.expression_type( - iterable.scoped_expression_id(self.db(), expression.scope(self.db())), - ) + result.expression_type(iterable) } }; @@ -5121,9 +5103,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { if unpack_position == UnpackPosition::First { self.context.extend(unpacked.diagnostics()); } - let target_ast_id = - target.scoped_expression_id(self.db(), unpack.target_scope(self.db())); - unpacked.expression_type(target_ast_id) + + unpacked.expression_type(target) } TargetKind::Single => { let iterable_type = infer_iterable_type(); @@ -5135,10 +5116,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } }; - self.types.expressions.insert( - target.scoped_expression_id(self.db(), self.scope()), - target_type, - ); + self.types.expressions.insert(target.into(), target_type); self.add_binding(target.into(), definition, target_type); } diff --git a/crates/ty_python_semantic/src/types/narrow.rs b/crates/ty_python_semantic/src/types/narrow.rs index 228f83563ac715..28f18b2dc86530 100644 --- a/crates/ty_python_semantic/src/types/narrow.rs +++ b/crates/ty_python_semantic/src/types/narrow.rs @@ -1,5 +1,4 @@ use crate::Db; -use crate::semantic_index::ast_ids::HasScopedExpressionId; use crate::semantic_index::expression::Expression; use crate::semantic_index::place::{PlaceExpr, PlaceTable, ScopeId, ScopedPlaceId}; use crate::semantic_index::place_table; @@ -687,7 +686,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { // and that requires cross-symbol constraints, which we don't support yet. return None; } - let scope = self.scope(); + let inference = infer_expression_types(self.db, expression); let comparator_tuples = std::iter::once(&**left) @@ -698,10 +697,8 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { let mut last_rhs_ty: Option = None; for (op, (left, right)) in std::iter::zip(&**ops, comparator_tuples) { - let lhs_ty = last_rhs_ty.unwrap_or_else(|| { - inference.expression_type(left.scoped_expression_id(self.db, scope)) - }); - let rhs_ty = inference.expression_type(right.scoped_expression_id(self.db, scope)); + let lhs_ty = last_rhs_ty.unwrap_or_else(|| inference.expression_type(left)); + let rhs_ty = inference.expression_type(right); last_rhs_ty = Some(rhs_ty); match left { @@ -756,8 +753,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { continue; } - let callable_type = - inference.expression_type(callable.scoped_expression_id(self.db, scope)); + let callable_type = inference.expression_type(&**callable); if callable_type .into_class_literal() @@ -782,11 +778,9 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { expression: Expression<'db>, is_positive: bool, ) -> Option> { - let scope = self.scope(); let inference = infer_expression_types(self.db, expression); - let callable_ty = - inference.expression_type(expr_call.func.scoped_expression_id(self.db, scope)); + let callable_ty = inference.expression_type(&*expr_call.func); // TODO: add support for PEP 604 union types on the right hand side of `isinstance` // and `issubclass`, for example `isinstance(x, str | (int | float))`. @@ -797,8 +791,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { None | Some(KnownFunction::RevealType) ) => { - let return_ty = - inference.expression_type(expr_call.scoped_expression_id(self.db, scope)); + let return_ty = inference.expression_type(expr_call); let (guarded_ty, place) = match return_ty { // TODO: TypeGuard @@ -824,7 +817,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { if function == KnownFunction::HasAttr { let attr = inference - .expression_type(second_arg.scoped_expression_id(self.db, scope)) + .expression_type(second_arg) .into_string_literal()? .value(self.db); @@ -845,8 +838,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { let function = function.into_classinfo_constraint_function()?; - let class_info_ty = - inference.expression_type(second_arg.scoped_expression_id(self.db, scope)); + let class_info_ty = inference.expression_type(second_arg); function .generate_constraint(self.db, class_info_ty) @@ -937,15 +929,12 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> { is_positive: bool, ) -> Option> { let inference = infer_expression_types(self.db, expression); - let scope = self.scope(); let mut sub_constraints = expr_bool_op .values .iter() // filter our arms with statically known truthiness .filter(|expr| { - inference - .expression_type(expr.scoped_expression_id(self.db, scope)) - .bool(self.db) + inference.expression_type(*expr).bool(self.db) != match expr_bool_op.op { BoolOp::And => Truthiness::AlwaysTrue, BoolOp::Or => Truthiness::AlwaysFalse, diff --git a/crates/ty_python_semantic/src/types/unpacker.rs b/crates/ty_python_semantic/src/types/unpacker.rs index b9df27a78f99d0..e3c74dd25ac56b 100644 --- a/crates/ty_python_semantic/src/types/unpacker.rs +++ b/crates/ty_python_semantic/src/types/unpacker.rs @@ -6,7 +6,7 @@ use rustc_hash::FxHashMap; use ruff_python_ast::{self as ast, AnyNodeRef}; use crate::Db; -use crate::semantic_index::ast_ids::{HasScopedExpressionId, ScopedExpressionId}; +use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey; use crate::semantic_index::place::ScopeId; use crate::types::tuple::{ResizeTupleError, Tuple, TupleLength, TupleUnpacker}; use crate::types::{Type, TypeCheckDiagnostics, infer_expression_types}; @@ -18,23 +18,18 @@ use super::diagnostic::INVALID_ASSIGNMENT; /// Unpacks the value expression type to their respective targets. pub(crate) struct Unpacker<'db, 'ast> { context: InferContext<'db, 'ast>, - target_scope: ScopeId<'db>, - value_scope: ScopeId<'db>, - targets: FxHashMap>, + targets: FxHashMap>, } impl<'db, 'ast> Unpacker<'db, 'ast> { pub(crate) fn new( db: &'db dyn Db, target_scope: ScopeId<'db>, - value_scope: ScopeId<'db>, module: &'ast ParsedModuleRef, ) -> Self { Self { context: InferContext::new(db, target_scope, module), targets: FxHashMap::default(), - target_scope, - value_scope, } } @@ -53,9 +48,8 @@ impl<'db, 'ast> Unpacker<'db, 'ast> { "Unpacking target must be a list or tuple expression" ); - let value_type = infer_expression_types(self.db(), value.expression()).expression_type( - value.scoped_expression_id(self.db(), self.value_scope, self.module()), - ); + let value_type = infer_expression_types(self.db(), value.expression()) + .expression_type(value.expression().node_ref(self.db(), self.module())); let value_type = match value.kind() { UnpackKind::Assign => { @@ -103,10 +97,7 @@ impl<'db, 'ast> Unpacker<'db, 'ast> { ) { match target { ast::Expr::Name(_) | ast::Expr::Attribute(_) | ast::Expr::Subscript(_) => { - self.targets.insert( - target.scoped_expression_id(self.db(), self.target_scope), - value_ty, - ); + self.targets.insert(target.into(), value_ty); } ast::Expr::Starred(ast::ExprStarred { value, .. }) => { self.unpack_inner(value, value_expr, value_ty); @@ -208,7 +199,7 @@ impl<'db, 'ast> Unpacker<'db, 'ast> { #[derive(Debug, Default, PartialEq, Eq, salsa::Update, get_size2::GetSize)] pub(crate) struct UnpackResult<'db> { - targets: FxHashMap>, + targets: FxHashMap>, diagnostics: TypeCheckDiagnostics, /// The fallback type for missing expressions. @@ -226,16 +217,19 @@ impl<'db> UnpackResult<'db> { /// May panic if a scoped expression ID is passed in that does not correspond to a sub- /// expression of the target. #[track_caller] - pub(crate) fn expression_type(&self, expr_id: ScopedExpressionId) -> Type<'db> { + pub(crate) fn expression_type(&self, expr_id: impl Into) -> Type<'db> { self.try_expression_type(expr_id).expect( "expression should belong to this `UnpackResult` and \ `Unpacker` should have inferred a type for it", ) } - pub(crate) fn try_expression_type(&self, expr_id: ScopedExpressionId) -> Option> { + pub(crate) fn try_expression_type( + &self, + expr: impl Into, + ) -> Option> { self.targets - .get(&expr_id) + .get(&expr.into()) .copied() .or(self.cycle_fallback_type) } diff --git a/crates/ty_python_semantic/src/unpack.rs b/crates/ty_python_semantic/src/unpack.rs index 3dda4dd2f2336f..8a3a5d3a5758bf 100644 --- a/crates/ty_python_semantic/src/unpack.rs +++ b/crates/ty_python_semantic/src/unpack.rs @@ -5,7 +5,6 @@ use ruff_text_size::{Ranged, TextRange}; use crate::Db; use crate::ast_node_ref::AstNodeRef; -use crate::semantic_index::ast_ids::{HasScopedExpressionId, ScopedExpressionId}; use crate::semantic_index::expression::Expression; use crate::semantic_index::place::{FileScopeId, ScopeId}; @@ -58,16 +57,6 @@ impl<'db> Unpack<'db> { self._target(db).node(parsed) } - /// Returns the scope in which the unpack value expression belongs. - /// - /// The scope in which the target and value expression belongs to are usually the same - /// except in generator expressions and comprehensions (list/dict/set), where the value - /// expression of the first generator is evaluated in the outer scope, while the ones in the subsequent - /// generators are evaluated in the comprehension scope. - pub(crate) fn value_scope(self, db: &'db dyn Db) -> ScopeId<'db> { - self.value_file_scope(db).to_scope_id(db, self.file(db)) - } - /// Returns the scope where the unpack target expression belongs to. pub(crate) fn target_scope(self, db: &'db dyn Db) -> ScopeId<'db> { self.target_file_scope(db).to_scope_id(db, self.file(db)) @@ -98,18 +87,6 @@ impl<'db> UnpackValue<'db> { self.expression } - /// Returns the [`ScopedExpressionId`] of the underlying expression. - pub(crate) fn scoped_expression_id( - self, - db: &'db dyn Db, - scope: ScopeId<'db>, - module: &ParsedModuleRef, - ) -> ScopedExpressionId { - self.expression() - .node_ref(db, module) - .scoped_expression_id(db, scope) - } - /// Returns the expression as an [`AnyNodeRef`]. pub(crate) fn as_any_node_ref<'ast>( self,