Skip to content

Commit bc6ea68

Browse files
authored
[ty] Add precise iteration and unpacking inference for string literals and bytes literals (#20023)
## Summary Previously we held off from doing this because we weren't sure that it was worth the added complexity cost. But our code has changed in the months since we made that initial decision, and I think the structure of the code is such that it no longer really leads to much added complexity to add precise inference when unpacking a string literal or a bytes literal. The improved inference we gain from this has real benefits to users (see the mypy_primer report), and this PR doesn't appear to have a performance impact. ## Test plan mdtests
1 parent 796819e commit bc6ea68

File tree

4 files changed

+203
-43
lines changed

4 files changed

+203
-43
lines changed

crates/ty_python_semantic/resources/mdtest/exhaustiveness_checking.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,22 @@ def match_non_exhaustive(x: Literal[0, 1, "a"]):
7474

7575
# this diagnostic is correct: the inferred type of `x` is `Literal[1]`
7676
assert_never(x) # error: [type-assertion-failure]
77+
78+
# This is based on real-world code:
79+
# https://github.com/scipy/scipy/blob/99c0ef6af161a4d8157cae5276a20c30b7677c6f/scipy/linalg/tests/test_lapack.py#L147-L171
80+
def exhaustiveness_using_containment_checks():
81+
for norm_str in "Mm1OoIiFfEe":
82+
if norm_str in "FfEe":
83+
return
84+
else:
85+
if norm_str in "Mm":
86+
return
87+
elif norm_str in "1Oo":
88+
return
89+
elif norm_str in "Ii":
90+
return
91+
92+
assert_never(norm_str)
7793
```
7894

7995
## Checks on enum literals

crates/ty_python_semantic/resources/mdtest/loops/for.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,6 +755,18 @@ def f(never: Never):
755755
reveal_type(x) # revealed: Unknown
756756
```
757757

758+
## Iterating over literals
759+
760+
```py
761+
from typing import Literal
762+
763+
for char in "abcde":
764+
reveal_type(char) # revealed: Literal["a", "b", "c", "d", "e"]
765+
766+
for char in b"abcde":
767+
reveal_type(char) # revealed: Literal[97, 98, 99, 100, 101]
768+
```
769+
758770
## A class literal is iterable if it inherits from `Any`
759771

760772
A class literal can be iterated over if it has `Any` or `Unknown` in its MRO, since the

crates/ty_python_semantic/resources/mdtest/unpacking.md

Lines changed: 124 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -523,8 +523,8 @@ def f(x: MixedTupleSubclass):
523523

524524
```py
525525
a, b = "ab"
526-
reveal_type(a) # revealed: LiteralString
527-
reveal_type(b) # revealed: LiteralString
526+
reveal_type(a) # revealed: Literal["a"]
527+
reveal_type(b) # revealed: Literal["b"]
528528
```
529529

530530
### Uneven unpacking (1)
@@ -570,37 +570,37 @@ reveal_type(d) # revealed: Unknown
570570

571571
```py
572572
(a, *b, c) = "ab"
573-
reveal_type(a) # revealed: LiteralString
573+
reveal_type(a) # revealed: Literal["a"]
574574
reveal_type(b) # revealed: list[Never]
575-
reveal_type(c) # revealed: LiteralString
575+
reveal_type(c) # revealed: Literal["b"]
576576
```
577577

578578
### Starred expression (3)
579579

580580
```py
581581
(a, *b, c) = "abc"
582-
reveal_type(a) # revealed: LiteralString
583-
reveal_type(b) # revealed: list[LiteralString]
584-
reveal_type(c) # revealed: LiteralString
582+
reveal_type(a) # revealed: Literal["a"]
583+
reveal_type(b) # revealed: list[Literal["b"]]
584+
reveal_type(c) # revealed: Literal["c"]
585585
```
586586

587587
### Starred expression (4)
588588

589589
```py
590590
(a, *b, c, d) = "abcdef"
591-
reveal_type(a) # revealed: LiteralString
592-
reveal_type(b) # revealed: list[LiteralString]
593-
reveal_type(c) # revealed: LiteralString
594-
reveal_type(d) # revealed: LiteralString
591+
reveal_type(a) # revealed: Literal["a"]
592+
reveal_type(b) # revealed: list[Literal["b", "c", "d"]]
593+
reveal_type(c) # revealed: Literal["e"]
594+
reveal_type(d) # revealed: Literal["f"]
595595
```
596596

597597
### Starred expression (5)
598598

599599
```py
600600
(a, b, *c) = "abcd"
601-
reveal_type(a) # revealed: LiteralString
602-
reveal_type(b) # revealed: LiteralString
603-
reveal_type(c) # revealed: list[LiteralString]
601+
reveal_type(a) # revealed: Literal["a"]
602+
reveal_type(b) # revealed: Literal["b"]
603+
reveal_type(c) # revealed: list[Literal["c", "d"]]
604604
```
605605

606606
### Starred expression (6)
@@ -650,8 +650,114 @@ reveal_type(b) # revealed: Unknown
650650
```py
651651
(a, b) = "\ud800\udfff"
652652

653+
reveal_type(a) # revealed: Literal["�"]
654+
reveal_type(b) # revealed: Literal["�"]
655+
```
656+
657+
### Very long literal
658+
659+
```py
660+
string = "very long stringgggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg"
661+
662+
a, *b = string
653663
reveal_type(a) # revealed: LiteralString
654-
reveal_type(b) # revealed: LiteralString
664+
reveal_type(b) # revealed: list[LiteralString]
665+
```
666+
667+
## Bytes
668+
669+
### Simple unpacking
670+
671+
```py
672+
a, b = b"ab"
673+
reveal_type(a) # revealed: Literal[97]
674+
reveal_type(b) # revealed: Literal[98]
675+
```
676+
677+
### Uneven unpacking (1)
678+
679+
```py
680+
# error: [invalid-assignment] "Not enough values to unpack: Expected 3"
681+
a, b, c = b"ab"
682+
reveal_type(a) # revealed: Unknown
683+
reveal_type(b) # revealed: Unknown
684+
reveal_type(c) # revealed: Unknown
685+
```
686+
687+
### Uneven unpacking (2)
688+
689+
```py
690+
# error: [invalid-assignment] "Too many values to unpack: Expected 2"
691+
a, b = b"abc"
692+
reveal_type(a) # revealed: Unknown
693+
reveal_type(b) # revealed: Unknown
694+
```
695+
696+
### Starred expression (1)
697+
698+
```py
699+
# error: [invalid-assignment] "Not enough values to unpack: Expected at least 3"
700+
(a, *b, c, d) = b"ab"
701+
reveal_type(a) # revealed: Unknown
702+
reveal_type(b) # revealed: list[Unknown]
703+
reveal_type(c) # revealed: Unknown
704+
reveal_type(d) # revealed: Unknown
705+
```
706+
707+
```py
708+
# error: [invalid-assignment] "Not enough values to unpack: Expected at least 3"
709+
(a, b, *c, d) = b"a"
710+
reveal_type(a) # revealed: Unknown
711+
reveal_type(b) # revealed: Unknown
712+
reveal_type(c) # revealed: list[Unknown]
713+
reveal_type(d) # revealed: Unknown
714+
```
715+
716+
### Starred expression (2)
717+
718+
```py
719+
(a, *b, c) = b"ab"
720+
reveal_type(a) # revealed: Literal[97]
721+
reveal_type(b) # revealed: list[Never]
722+
reveal_type(c) # revealed: Literal[98]
723+
```
724+
725+
### Starred expression (3)
726+
727+
```py
728+
(a, *b, c) = b"abc"
729+
reveal_type(a) # revealed: Literal[97]
730+
reveal_type(b) # revealed: list[Literal[98]]
731+
reveal_type(c) # revealed: Literal[99]
732+
```
733+
734+
### Starred expression (4)
735+
736+
```py
737+
(a, *b, c, d) = b"abcdef"
738+
reveal_type(a) # revealed: Literal[97]
739+
reveal_type(b) # revealed: list[Literal[98, 99, 100]]
740+
reveal_type(c) # revealed: Literal[101]
741+
reveal_type(d) # revealed: Literal[102]
742+
```
743+
744+
### Starred expression (5)
745+
746+
```py
747+
(a, b, *c) = b"abcd"
748+
reveal_type(a) # revealed: Literal[97]
749+
reveal_type(b) # revealed: Literal[98]
750+
reveal_type(c) # revealed: list[Literal[99, 100]]
751+
```
752+
753+
### Very long literal
754+
755+
```py
756+
too_long = b"very long bytes stringggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggggg"
757+
758+
a, *b = too_long
759+
reveal_type(a) # revealed: int
760+
reveal_type(b) # revealed: list[int]
655761
```
656762

657763
## Union
@@ -714,7 +820,7 @@ def _(arg: tuple[int, tuple[str, bytes]] | tuple[tuple[int, bytes], Literal["ab"
714820
a, (b, c) = arg
715821
reveal_type(a) # revealed: int | tuple[int, bytes]
716822
reveal_type(b) # revealed: str
717-
reveal_type(c) # revealed: bytes | LiteralString
823+
reveal_type(c) # revealed: bytes | Literal["b"]
718824
```
719825

720826
### Starred expression
@@ -785,8 +891,8 @@ from typing import Literal
785891

786892
def _(arg: tuple[int, int] | Literal["ab"]):
787893
a, b = arg
788-
reveal_type(a) # revealed: int | LiteralString
789-
reveal_type(b) # revealed: int | LiteralString
894+
reveal_type(a) # revealed: int | Literal["a"]
895+
reveal_type(b) # revealed: int | Literal["b"]
790896
```
791897

792898
### Custom iterator (1)

crates/ty_python_semantic/src/types.rs

Lines changed: 51 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4917,6 +4917,12 @@ impl<'db> Type<'db> {
49174917
db: &'db dyn Db,
49184918
mode: EvaluationMode,
49194919
) -> Result<Cow<'db, TupleSpec<'db>>, IterationError<'db>> {
4920+
// We will not infer precise heterogeneous tuple specs for literals with lengths above this threshold.
4921+
// The threshold here is somewhat arbitrary and conservative; it could be increased if needed.
4922+
// However, it's probably very rare to need heterogeneous unpacking inference for long string literals
4923+
// or bytes literals, and creating long heterogeneous tuple specs has a performance cost.
4924+
const MAX_TUPLE_LENGTH: usize = 128;
4925+
49204926
if mode.is_async() {
49214927
let try_call_dunder_anext_on_iterator = |iterator: Type<'db>| -> Result<
49224928
Result<Type<'db>, AwaitError<'db>>,
@@ -4972,52 +4978,66 @@ impl<'db> Type<'db> {
49724978
};
49734979
}
49744980

4975-
match self {
4976-
Type::NominalInstance(nominal) => {
4977-
if let Some(spec) = nominal.tuple_spec(db) {
4978-
return Ok(spec);
4979-
}
4980-
}
4981+
let special_case = match self {
4982+
Type::NominalInstance(nominal) => nominal.tuple_spec(db),
49814983
Type::GenericAlias(alias) if alias.origin(db).is_tuple(db) => {
4982-
return Ok(Cow::Owned(TupleSpec::homogeneous(todo_type!(
4984+
Some(Cow::Owned(TupleSpec::homogeneous(todo_type!(
49834985
"*tuple[] annotations"
4984-
))));
4986+
))))
49854987
}
49864988
Type::StringLiteral(string_literal_ty) => {
4987-
// We could go further and deconstruct to an array of `StringLiteral`
4988-
// with each individual character, instead of just an array of
4989-
// `LiteralString`, but there would be a cost and it's not clear that
4990-
// it's worth it.
4991-
return Ok(Cow::Owned(TupleSpec::heterogeneous(std::iter::repeat_n(
4992-
Type::LiteralString,
4993-
string_literal_ty.python_len(db),
4994-
))));
4989+
let string_literal = string_literal_ty.value(db);
4990+
let spec = if string_literal.len() < MAX_TUPLE_LENGTH {
4991+
TupleSpec::heterogeneous(
4992+
string_literal
4993+
.chars()
4994+
.map(|c| Type::string_literal(db, &c.to_string())),
4995+
)
4996+
} else {
4997+
TupleSpec::homogeneous(Type::LiteralString)
4998+
};
4999+
Some(Cow::Owned(spec))
5000+
}
5001+
Type::BytesLiteral(bytes) => {
5002+
let bytes_literal = bytes.value(db);
5003+
let spec = if bytes_literal.len() < MAX_TUPLE_LENGTH {
5004+
TupleSpec::heterogeneous(
5005+
bytes_literal
5006+
.iter()
5007+
.map(|b| Type::IntLiteral(i64::from(*b))),
5008+
)
5009+
} else {
5010+
TupleSpec::homogeneous(KnownClass::Int.to_instance(db))
5011+
};
5012+
Some(Cow::Owned(spec))
49955013
}
49965014
Type::Never => {
49975015
// The dunder logic below would have us return `tuple[Never, ...]`, which eagerly
49985016
// simplifies to `tuple[()]`. That will will cause us to emit false positives if we
49995017
// index into the tuple. Using `tuple[Unknown, ...]` avoids these false positives.
50005018
// TODO: Consider removing this special case, and instead hide the indexing
50015019
// diagnostic in unreachable code.
5002-
return Ok(Cow::Owned(TupleSpec::homogeneous(Type::unknown())));
5020+
Some(Cow::Owned(TupleSpec::homogeneous(Type::unknown())))
50035021
}
50045022
Type::TypeAlias(alias) => {
5005-
return alias.value_type(db).try_iterate_with_mode(db, mode);
5023+
Some(alias.value_type(db).try_iterate_with_mode(db, mode)?)
50065024
}
50075025
Type::NonInferableTypeVar(tvar) => match tvar.typevar(db).bound_or_constraints(db) {
50085026
Some(TypeVarBoundOrConstraints::UpperBound(bound)) => {
5009-
return bound.try_iterate_with_mode(db, mode);
5027+
Some(bound.try_iterate_with_mode(db, mode)?)
50105028
}
50115029
// TODO: could we create a "union of tuple specs"...?
50125030
// (Same question applies to the `Type::Union()` branch lower down)
5013-
Some(TypeVarBoundOrConstraints::Constraints(_)) | None => {}
5031+
Some(TypeVarBoundOrConstraints::Constraints(_)) | None => None
50145032
},
50155033
Type::TypeVar(_) => unreachable!(
50165034
"should not be able to iterate over type variable {} in inferable position",
50175035
self.display(db)
50185036
),
5019-
Type::Dynamic(_)
5020-
| Type::FunctionLiteral(_)
5037+
// N.B. These special cases aren't strictly necessary, they're just obvious optimizations
5038+
Type::LiteralString | Type::Dynamic(_) => Some(Cow::Owned(TupleSpec::homogeneous(self))),
5039+
5040+
Type::FunctionLiteral(_)
50215041
| Type::GenericAlias(_)
50225042
| Type::BoundMethod(_)
50235043
| Type::MethodWrapper(_)
@@ -5026,6 +5046,10 @@ impl<'db> Type<'db> {
50265046
| Type::DataclassTransformer(_)
50275047
| Type::Callable(_)
50285048
| Type::ModuleLiteral(_)
5049+
// We could infer a precise tuple spec for enum classes with members,
5050+
// but it's not clear whether that's worth the added complexity:
5051+
// you'd have to check that `EnumMeta.__iter__` is not overridden for it to be sound
5052+
// (enums can have `EnumMeta` subclasses as their metaclasses).
50295053
| Type::ClassLiteral(_)
50305054
| Type::SubclassOf(_)
50315055
| Type::ProtocolInstance(_)
@@ -5039,11 +5063,13 @@ impl<'db> Type<'db> {
50395063
| Type::IntLiteral(_)
50405064
| Type::BooleanLiteral(_)
50415065
| Type::EnumLiteral(_)
5042-
| Type::LiteralString
5043-
| Type::BytesLiteral(_)
50445066
| Type::BoundSuper(_)
50455067
| Type::TypeIs(_)
5046-
| Type::TypedDict(_) => {}
5068+
| Type::TypedDict(_) => None
5069+
};
5070+
5071+
if let Some(special_case) = special_case {
5072+
return Ok(special_case);
50475073
}
50485074

50495075
let try_call_dunder_getitem = || {

0 commit comments

Comments
 (0)