Skip to content

Commit 28815f1

Browse files
Michael0x2ailevkivskyi
authored andcommitted
Fix incorrect tracking of "final" Instances (#6763)
* Refine how Literal and Final interact This diff makes three changes: it fixes a bug where we incorrectly track "final" Instances, does some related refactoring, and finally modifies tuple indexing to be aware of literal contexts. Specifically, here is an example of the bug. Note that mypy ignores the mutable nature of `bar`: def expect_3(x: Literal[3]) -> None: ... foo: Final = 3 bar = foo for i in range(10): bar = i # Currently type-check; this PR makes mypy correctly report an error expect_3(bar) To fix this bug, I decided to adjust the variable assignment logic: if the variable is non-final, we now scan the inferred type we try assigning and recursively erase all set `instance.final_value` fields. This change ended up making the `in_final_declaration` field redundant -- after all, we're going to be actively erasing types on non-final assignments anyways. So, I decided to just remove this field. I suspect this change will also result in some nice dividends down the road: defaulting to preserving the underlying literal when inferring expression types will probably make it easier to add more sophisticated literal-related inference down the road. In the process of implementing the above two, I discovered that "nested" Instance types are effectively ignored. So, the following program does not type check, despite the `Final` and despite that tuples are immutable: bar: Final = (3, 2, 1) # 'bar[0] == 3' is always true, but we currently report an error expect_3(bar[0]) This is mildly annoying, and also made it slightly harder for me to verify my changes above, so I decided to modify `visit_index_expr` to also examine the literal context. (Actually, I found I could move this check directly into the 'accept' method instead of special-casing things within `visit_index_expr` and `analyze_var_ref`. But I decided against this approach: the special-casing feels less intrusive, easier to audit, and slightly more efficient.) * Rename 'final_value' field to 'last_known_value' * Adjust one existing test
1 parent 04d9649 commit 28815f1

File tree

17 files changed

+135
-70
lines changed

17 files changed

+135
-70
lines changed

mypy/checker.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353
from mypy.typevars import fill_typevars, has_no_typevars, fill_typevars_with_any
5454
from mypy.semanal import set_callable_name, refers_to_fullname
5555
from mypy.mro import calculate_mro
56-
from mypy.erasetype import erase_typevars
56+
from mypy.erasetype import erase_typevars, remove_instance_last_known_values
5757
from mypy.expandtype import expand_type, expand_type_by_instance
5858
from mypy.visitor import NodeVisitor
5959
from mypy.join import join_types
@@ -1898,10 +1898,9 @@ def check_assignment(self, lvalue: Lvalue, rvalue: Expression, infer_lvalue_type
18981898
self.check_indexed_assignment(index_lvalue, rvalue, lvalue)
18991899

19001900
if inferred:
1901-
rvalue_type = self.expr_checker.accept(
1902-
rvalue,
1903-
in_final_declaration=inferred.is_final,
1904-
)
1901+
rvalue_type = self.expr_checker.accept(rvalue)
1902+
if not inferred.is_final:
1903+
rvalue_type = remove_instance_last_known_values(rvalue_type)
19051904
self.infer_variable_type(inferred, lvalue, rvalue_type, rvalue)
19061905

19071906
def check_compatibility_all_supers(self, lvalue: RefExpr, lvalue_type: Optional[Type],

mypy/checkexpr.py

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -141,15 +141,6 @@ def __init__(self,
141141
self.plugin = plugin
142142
self.type_context = [None]
143143

144-
# Set to 'True' whenever we are checking the expression in some 'Final' declaration.
145-
# For example, if we're checking the "3" in a statement like "var: Final = 3".
146-
#
147-
# This flag changes the type that eventually gets inferred for "var". Instead of
148-
# inferring *just* a 'builtins.int' instance, we infer an instance that keeps track
149-
# of the underlying literal value. See the comments in Instance's constructors for
150-
# more details.
151-
self.in_final_declaration = False
152-
153144
# Temporary overrides for expression types. This is currently
154145
# used by the union math in overloads.
155146
# TODO: refactor this to use a pattern similar to one in
@@ -224,8 +215,8 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:
224215
def analyze_var_ref(self, var: Var, context: Context) -> Type:
225216
if var.type:
226217
if isinstance(var.type, Instance):
227-
if self.is_literal_context() and var.type.final_value is not None:
228-
return var.type.final_value
218+
if self.is_literal_context() and var.type.last_known_value is not None:
219+
return var.type.last_known_value
229220
if var.name() in {'True', 'False'}:
230221
return self.infer_literal_expr_type(var.name() == 'True', 'builtins.bool')
231222
return var.type
@@ -1812,15 +1803,13 @@ def infer_literal_expr_type(self, value: LiteralValue, fallback_name: str) -> Ty
18121803
typ = self.named_type(fallback_name)
18131804
if self.is_literal_context():
18141805
return LiteralType(value=value, fallback=typ)
1815-
elif self.in_final_declaration:
1806+
else:
18161807
return typ.copy_modified(final_value=LiteralType(
18171808
value=value,
18181809
fallback=typ,
18191810
line=typ.line,
18201811
column=typ.column,
18211812
))
1822-
else:
1823-
return typ
18241813

18251814
def visit_int_expr(self, e: IntExpr) -> Type:
18261815
"""Type check an integer literal (trivial)."""
@@ -2450,7 +2439,11 @@ def visit_index_expr(self, e: IndexExpr) -> Type:
24502439
It may also represent type application.
24512440
"""
24522441
result = self.visit_index_expr_helper(e)
2453-
return self.narrow_type_from_binder(e, result)
2442+
result = self.narrow_type_from_binder(e, result)
2443+
if (self.is_literal_context() and isinstance(result, Instance)
2444+
and result.last_known_value is not None):
2445+
result = result.last_known_value
2446+
return result
24542447

24552448
def visit_index_expr_helper(self, e: IndexExpr) -> Type:
24562449
if e.analyzed:
@@ -2542,8 +2535,8 @@ def _get_value(self, index: Expression) -> Optional[int]:
25422535
if isinstance(operand, IntExpr):
25432536
return -1 * operand.value
25442537
typ = self.accept(index)
2545-
if isinstance(typ, Instance) and typ.final_value is not None:
2546-
typ = typ.final_value
2538+
if isinstance(typ, Instance) and typ.last_known_value is not None:
2539+
typ = typ.last_known_value
25472540
if isinstance(typ, LiteralType) and isinstance(typ.value, int):
25482541
return typ.value
25492542
return None
@@ -2553,8 +2546,8 @@ def visit_typeddict_index_expr(self, td_type: TypedDictType, index: Expression)
25532546
item_name = index.value
25542547
else:
25552548
typ = self.accept(index)
2556-
if isinstance(typ, Instance) and typ.final_value is not None:
2557-
typ = typ.final_value
2549+
if isinstance(typ, Instance) and typ.last_known_value is not None:
2550+
typ = typ.last_known_value
25582551

25592552
if isinstance(typ, LiteralType) and isinstance(typ.value, str):
25602553
item_name = typ.value
@@ -3253,16 +3246,13 @@ def accept(self,
32533246
type_context: Optional[Type] = None,
32543247
allow_none_return: bool = False,
32553248
always_allow_any: bool = False,
3256-
in_final_declaration: bool = False,
32573249
) -> Type:
32583250
"""Type check a node in the given type context. If allow_none_return
32593251
is True and this expression is a call, allow it to return None. This
32603252
applies only to this expression and not any subexpressions.
32613253
"""
32623254
if node in self.type_overrides:
32633255
return self.type_overrides[node]
3264-
old_in_final_declaration = self.in_final_declaration
3265-
self.in_final_declaration = in_final_declaration
32663256
self.type_context.append(type_context)
32673257
try:
32683258
if allow_none_return and isinstance(node, CallExpr):
@@ -3274,8 +3264,8 @@ def accept(self,
32743264
except Exception as err:
32753265
report_internal_error(err, self.chk.errors.file,
32763266
node.line, self.chk.errors, self.chk.options)
3267+
32773268
self.type_context.pop()
3278-
self.in_final_declaration = old_in_final_declaration
32793269
assert typ is not None
32803270
self.chk.store_type(node, typ)
32813271

mypy/checkmember.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,8 @@ def analyze_member_access(name: str,
101101
msg,
102102
chk=chk)
103103
result = _analyze_member_access(name, typ, mx, override_info)
104-
if in_literal_context and isinstance(result, Instance) and result.final_value is not None:
105-
return result.final_value
104+
if in_literal_context and isinstance(result, Instance) and result.last_known_value is not None:
105+
return result.last_known_value
106106
else:
107107
return result
108108

mypy/erasetype.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,17 @@ def visit_type_var(self, t: TypeVarType) -> Type:
119119
if self.erase_id(t.id):
120120
return self.replacement
121121
return t
122+
123+
124+
def remove_instance_last_known_values(t: Type) -> Type:
125+
return t.accept(LastKnownValueEraser())
126+
127+
128+
class LastKnownValueEraser(TypeTranslator):
129+
"""Removes the Literal[...] type that may be associated with any
130+
Instance types."""
131+
132+
def visit_instance(self, t: Instance) -> Type:
133+
if t.last_known_value:
134+
return t.copy_modified(final_value=None)
135+
return t

mypy/fixup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,8 @@ def visit_instance(self, inst: Instance) -> None:
155155
base.accept(self)
156156
for a in inst.args:
157157
a.accept(self)
158-
if inst.final_value is not None:
159-
inst.final_value.accept(self)
158+
if inst.last_known_value is not None:
159+
inst.last_known_value.accept(self)
160160

161161
def visit_any(self, o: Any) -> None:
162162
pass # Nothing to descend into.

mypy/newsemanal/typeanal.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -700,9 +700,9 @@ def analyze_literal_param(self, idx: int, arg: Type, ctx: Context) -> Optional[L
700700
elif isinstance(arg, (NoneType, LiteralType)):
701701
# Types that we can just add directly to the literal/potential union of literals.
702702
return [arg]
703-
elif isinstance(arg, Instance) and arg.final_value is not None:
703+
elif isinstance(arg, Instance) and arg.last_known_value is not None:
704704
# Types generated from declarations like "var: Final = 4".
705-
return [arg.final_value]
705+
return [arg.last_known_value]
706706
elif isinstance(arg, UnionType):
707707
out = []
708708
for union_arg in arg.items:

mypy/plugins/common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,8 @@ def try_getting_str_literal(expr: Expression, typ: Type) -> Optional[str]:
125125
"""If this expression is a string literal, or if the corresponding type
126126
is something like 'Literal["some string here"]', returns the underlying
127127
string value. Otherwise, returns None."""
128-
if isinstance(typ, Instance) and typ.final_value is not None:
129-
typ = typ.final_value
128+
if isinstance(typ, Instance) and typ.last_known_value is not None:
129+
typ = typ.last_known_value
130130

131131
if isinstance(typ, LiteralType) and typ.fallback.type.fullname() == 'builtins.str':
132132
val = typ.value

mypy/sametypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def visit_instance(self, left: Instance) -> bool:
7979
return (isinstance(self.right, Instance) and
8080
left.type == self.right.type and
8181
is_same_types(left.args, self.right.args) and
82-
left.final_value == self.right.final_value)
82+
left.last_known_value == self.right.last_known_value)
8383

8484
def visit_type_var(self, left: TypeVarType) -> bool:
8585
return (isinstance(self.right, TypeVarType) and

mypy/server/astdiff.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ def visit_instance(self, typ: Instance) -> SnapshotItem:
284284
return ('Instance',
285285
typ.type.fullname(),
286286
snapshot_types(typ.args),
287-
None if typ.final_value is None else snapshot_type(typ.final_value))
287+
None if typ.last_known_value is None else snapshot_type(typ.last_known_value))
288288

289289
def visit_type_var(self, typ: TypeVarType) -> SnapshotItem:
290290
return ('TypeVar',

mypy/server/astmerge.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -342,8 +342,8 @@ def visit_instance(self, typ: Instance) -> None:
342342
typ.type = self.fixup(typ.type)
343343
for arg in typ.args:
344344
arg.accept(self)
345-
if typ.final_value:
346-
typ.final_value.accept(self)
345+
if typ.last_known_value:
346+
typ.last_known_value.accept(self)
347347

348348
def visit_any(self, typ: AnyType) -> None:
349349
pass

0 commit comments

Comments
 (0)