Skip to content

Commit d7d5968

Browse files
dcreagerAlexWaygood
authored andcommitted
[ty] Return a tuple spec from the iterator protocol (#19496)
This PR updates our iterator protocol machinery to return a tuple spec describing the elements that are returned, instead of a type. That allows us to track heterogeneous iterators more precisely, and consolidates the logic in unpacking and splatting, which are the two places where we can take advantage of that more precise information. (Other iterator consumers, like `for` loops, have to collapse the iterated elements down to a single type regardless, and we provide a new helper method on `TupleSpec` to perform that summarization.)
1 parent 2a364bb commit d7d5968

File tree

10 files changed

+204
-125
lines changed

10 files changed

+204
-125
lines changed

crates/ty_python_semantic/resources/mdtest/call/function.md

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,81 @@ def _(args: tuple[int, *tuple[str, ...], int]) -> None:
259259
takes_at_least_two_positional_only(*args) # error: [invalid-argument-type]
260260
```
261261

262+
### String argument
263+
264+
```py
265+
from typing import Literal
266+
267+
def takes_zero() -> None: ...
268+
def takes_one(x: str) -> None: ...
269+
def takes_two(x: str, y: str) -> None: ...
270+
def takes_two_positional_only(x: str, y: str, /) -> None: ...
271+
def takes_two_different(x: int, y: str) -> None: ...
272+
def takes_two_different_positional_only(x: int, y: str, /) -> None: ...
273+
def takes_at_least_zero(*args) -> None: ...
274+
def takes_at_least_one(x: str, *args) -> None: ...
275+
def takes_at_least_two(x: str, y: str, *args) -> None: ...
276+
def takes_at_least_two_positional_only(x: str, y: str, /, *args) -> None: ...
277+
278+
# Test all of the above with a number of different splatted argument types
279+
280+
def _(args: Literal["a"]) -> None:
281+
takes_zero(*args) # error: [too-many-positional-arguments]
282+
takes_one(*args)
283+
takes_two(*args) # error: [missing-argument]
284+
takes_two_positional_only(*args) # error: [missing-argument]
285+
# error: [invalid-argument-type]
286+
# error: [missing-argument]
287+
takes_two_different(*args)
288+
# error: [invalid-argument-type]
289+
# error: [missing-argument]
290+
takes_two_different_positional_only(*args)
291+
takes_at_least_zero(*args)
292+
takes_at_least_one(*args)
293+
takes_at_least_two(*args) # error: [missing-argument]
294+
takes_at_least_two_positional_only(*args) # error: [missing-argument]
295+
296+
def _(args: Literal["ab"]) -> None:
297+
takes_zero(*args) # error: [too-many-positional-arguments]
298+
takes_one(*args) # error: [too-many-positional-arguments]
299+
takes_two(*args)
300+
takes_two_positional_only(*args)
301+
takes_two_different(*args) # error: [invalid-argument-type]
302+
takes_two_different_positional_only(*args) # error: [invalid-argument-type]
303+
takes_at_least_zero(*args)
304+
takes_at_least_one(*args)
305+
takes_at_least_two(*args)
306+
takes_at_least_two_positional_only(*args)
307+
308+
def _(args: Literal["abc"]) -> None:
309+
takes_zero(*args) # error: [too-many-positional-arguments]
310+
takes_one(*args) # error: [too-many-positional-arguments]
311+
takes_two(*args) # error: [too-many-positional-arguments]
312+
takes_two_positional_only(*args) # error: [too-many-positional-arguments]
313+
# error: [invalid-argument-type]
314+
# error: [too-many-positional-arguments]
315+
takes_two_different(*args)
316+
# error: [invalid-argument-type]
317+
# error: [too-many-positional-arguments]
318+
takes_two_different_positional_only(*args)
319+
takes_at_least_zero(*args)
320+
takes_at_least_one(*args)
321+
takes_at_least_two(*args)
322+
takes_at_least_two_positional_only(*args)
323+
324+
def _(args: str) -> None:
325+
takes_zero(*args)
326+
takes_one(*args)
327+
takes_two(*args)
328+
takes_two_positional_only(*args)
329+
takes_two_different(*args) # error: [invalid-argument-type]
330+
takes_two_different_positional_only(*args) # error: [invalid-argument-type]
331+
takes_at_least_zero(*args)
332+
takes_at_least_one(*args)
333+
takes_at_least_two(*args)
334+
takes_at_least_two_positional_only(*args)
335+
```
336+
262337
### Argument expansion regression
263338

264339
This is a regression that was highlighted by the ecosystem check, which shows that we might need to

crates/ty_python_semantic/resources/mdtest/loops/for.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -738,12 +738,19 @@ def _(flag: bool, flag2: bool):
738738
reveal_type(y) # revealed: bytes | str | int
739739
```
740740

741+
## Empty tuple is iterable
742+
743+
```py
744+
for x in ():
745+
reveal_type(x) # revealed: Never
746+
```
747+
741748
## Never is iterable
742749

743750
```py
744751
from typing_extensions import Never
745752

746753
def f(never: Never):
747754
for x in never:
748-
reveal_type(x) # revealed: Never
755+
reveal_type(x) # revealed: Unknown
749756
```

crates/ty_python_semantic/src/types.rs

Lines changed: 51 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use infer::nearest_enclosing_class;
22
use itertools::{Either, Itertools};
33
use ruff_db::parsed::parsed_module;
44

5+
use std::borrow::Cow;
56
use std::slice::Iter;
67

78
use bitflags::bitflags;
@@ -56,7 +57,7 @@ use crate::types::infer::infer_unpack_types;
5657
use crate::types::mro::{Mro, MroError, MroIterator};
5758
pub(crate) use crate::types::narrow::infer_narrowing_constraint;
5859
use crate::types::signatures::{Parameter, ParameterForm, Parameters, walk_signature};
59-
use crate::types::tuple::TupleType;
60+
use crate::types::tuple::{TupleSpec, TupleType};
6061
pub use crate::util::diagnostics::add_inferred_python_version_hint_to_diagnostic;
6162
use crate::{Db, FxOrderSet, Module, Program};
6263
pub(crate) use class::{ClassLiteral, ClassType, GenericAlias, KnownClass};
@@ -813,13 +814,6 @@ impl<'db> Type<'db> {
813814
.expect("Expected a Type::EnumLiteral variant")
814815
}
815816

816-
pub(crate) const fn into_tuple(self) -> Option<TupleType<'db>> {
817-
match self {
818-
Type::Tuple(tuple_type) => Some(tuple_type),
819-
_ => None,
820-
}
821-
}
822-
823817
/// Turn a class literal (`Type::ClassLiteral` or `Type::GenericAlias`) into a `ClassType`.
824818
/// Since a `ClassType` must be specialized, apply the default specialization to any
825819
/// unspecialized generic class literal.
@@ -4615,35 +4609,50 @@ impl<'db> Type<'db> {
46154609
}
46164610
}
46174611

4618-
/// Returns the element type when iterating over `self`.
4612+
/// Returns a tuple spec describing the elements that are produced when iterating over `self`.
46194613
///
46204614
/// This method should only be used outside of type checking because it omits any errors.
46214615
/// For type checking, use [`try_iterate`](Self::try_iterate) instead.
4622-
fn iterate(self, db: &'db dyn Db) -> Type<'db> {
4616+
fn iterate(self, db: &'db dyn Db) -> Cow<'db, TupleSpec<'db>> {
46234617
self.try_iterate(db)
4624-
.unwrap_or_else(|err| err.fallback_element_type(db))
4618+
.unwrap_or_else(|err| Cow::Owned(TupleSpec::homogeneous(err.fallback_element_type(db))))
46254619
}
46264620

46274621
/// Given the type of an object that is iterated over in some way,
4628-
/// return the type of objects that are yielded by that iteration.
4622+
/// return a tuple spec describing the type of objects that are yielded by that iteration.
46294623
///
4630-
/// E.g., for the following loop, given the type of `x`, infer the type of `y`:
4624+
/// E.g., for the following call, given the type of `x`, infer the types of the values that are
4625+
/// splatted into `y`'s positional arguments:
46314626
/// ```python
4632-
/// for y in x:
4633-
/// pass
4627+
/// y(*x)
46344628
/// ```
4635-
fn try_iterate(self, db: &'db dyn Db) -> Result<Type<'db>, IterationError<'db>> {
4636-
if let Type::Tuple(tuple_type) = self {
4637-
return Ok(UnionType::from_elements(
4638-
db,
4639-
tuple_type.tuple(db).all_elements(),
4640-
));
4641-
}
4642-
4643-
if let Type::GenericAlias(alias) = self {
4644-
if alias.origin(db).is_tuple(db) {
4645-
return Ok(todo_type!("*tuple[] annotations"));
4629+
fn try_iterate(self, db: &'db dyn Db) -> Result<Cow<'db, TupleSpec<'db>>, IterationError<'db>> {
4630+
match self {
4631+
Type::Tuple(tuple_type) => return Ok(Cow::Borrowed(tuple_type.tuple(db))),
4632+
Type::GenericAlias(alias) if alias.origin(db).is_tuple(db) => {
4633+
return Ok(Cow::Owned(TupleSpec::homogeneous(todo_type!(
4634+
"*tuple[] annotations"
4635+
))));
4636+
}
4637+
Type::StringLiteral(string_literal_ty) => {
4638+
// We could go further and deconstruct to an array of `StringLiteral`
4639+
// with each individual character, instead of just an array of
4640+
// `LiteralString`, but there would be a cost and it's not clear that
4641+
// it's worth it.
4642+
return Ok(Cow::Owned(TupleSpec::from_elements(std::iter::repeat_n(
4643+
Type::LiteralString,
4644+
string_literal_ty.python_len(db),
4645+
))));
4646+
}
4647+
Type::Never => {
4648+
// The dunder logic below would have us return `tuple[Never, ...]`, which eagerly
4649+
// simplifies to `tuple[()]`. That will will cause us to emit false positives if we
4650+
// index into the tuple. Using `tuple[Unknown, ...]` avoids these false positives.
4651+
// TODO: Consider removing this special case, and instead hide the indexing
4652+
// diagnostic in unreachable code.
4653+
return Ok(Cow::Owned(TupleSpec::homogeneous(Type::unknown())));
46464654
}
4655+
_ => {}
46474656
}
46484657

46494658
let try_call_dunder_getitem = || {
@@ -4669,12 +4678,14 @@ impl<'db> Type<'db> {
46694678
Ok(iterator) => {
46704679
// `__iter__` is definitely bound and calling it succeeds.
46714680
// See what calling `__next__` on the object returned by `__iter__` gives us...
4672-
try_call_dunder_next_on_iterator(iterator).map_err(|dunder_next_error| {
4673-
IterationError::IterReturnsInvalidIterator {
4674-
iterator,
4675-
dunder_next_error,
4676-
}
4677-
})
4681+
try_call_dunder_next_on_iterator(iterator)
4682+
.map(|ty| Cow::Owned(TupleSpec::homogeneous(ty)))
4683+
.map_err(
4684+
|dunder_next_error| IterationError::IterReturnsInvalidIterator {
4685+
iterator,
4686+
dunder_next_error,
4687+
},
4688+
)
46784689
}
46794690

46804691
// `__iter__` is possibly unbound...
@@ -4692,10 +4703,10 @@ impl<'db> Type<'db> {
46924703
// and the type returned by the `__getitem__` method.
46934704
//
46944705
// No diagnostic is emitted; iteration will always succeed!
4695-
UnionType::from_elements(
4706+
Cow::Owned(TupleSpec::homogeneous(UnionType::from_elements(
46964707
db,
46974708
[dunder_next_return, dunder_getitem_return_type],
4698-
)
4709+
)))
46994710
})
47004711
.map_err(|dunder_getitem_error| {
47014712
IterationError::PossiblyUnboundIterAndGetitemError {
@@ -4718,13 +4729,13 @@ impl<'db> Type<'db> {
47184729
}
47194730

47204731
// There's no `__iter__` method. Try `__getitem__` instead...
4721-
Err(CallDunderError::MethodNotAvailable) => {
4722-
try_call_dunder_getitem().map_err(|dunder_getitem_error| {
4723-
IterationError::UnboundIterAndGetitemError {
4732+
Err(CallDunderError::MethodNotAvailable) => try_call_dunder_getitem()
4733+
.map(|ty| Cow::Owned(TupleSpec::homogeneous(ty)))
4734+
.map_err(
4735+
|dunder_getitem_error| IterationError::UnboundIterAndGetitemError {
47244736
dunder_getitem_error,
4725-
}
4726-
})
4727-
}
4737+
},
4738+
),
47284739
}
47294740
}
47304741

crates/ty_python_semantic/src/types/call/arguments.rs

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,10 @@ impl<'a, 'db> CallArguments<'a, 'db> {
5151
ast::ArgOrKeyword::Arg(arg) => match arg {
5252
ast::Expr::Starred(ast::ExprStarred { value, .. }) => {
5353
let ty = infer_argument_type(arg, value);
54-
let length = match ty {
55-
Type::Tuple(tuple) => tuple.tuple(db).len(),
56-
// TODO: have `Type::try_iterator` return a tuple spec, and use its
57-
// length as this argument's arity
58-
_ => TupleLength::unknown(),
59-
};
54+
let length = ty
55+
.try_iterate(db)
56+
.map(|tuple| tuple.len())
57+
.unwrap_or(TupleLength::unknown());
6058
(Argument::Variadic(length), Some(ty))
6159
}
6260
_ => (Argument::Positional, None),

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -977,24 +977,14 @@ impl<'db> Bindings<'db> {
977977
// `tuple(range(42))` => `tuple[int, ...]`
978978
// BUT `tuple((1, 2))` => `tuple[Literal[1], Literal[2]]` rather than `tuple[Literal[1, 2], ...]`
979979
if let [Some(argument)] = overload.parameter_types() {
980-
let overridden_return =
981-
argument.into_tuple().map(Type::Tuple).unwrap_or_else(|| {
982-
// Some awkward special handling is required here because of the fact
983-
// that calling `try_iterate()` on `Never` returns `Never`,
984-
// but `tuple[Never, ...]` eagerly simplifies to `tuple[()]`,
985-
// which will cause us to emit false positives if we index into the tuple.
986-
// Using `tuple[Unknown, ...]` avoids these false positives.
987-
let specialization = if argument.is_never() {
988-
Type::unknown()
989-
} else {
990-
argument.try_iterate(db).expect(
991-
"try_iterate() should not fail on a type \
992-
assignable to `Iterable`",
993-
)
994-
};
995-
TupleType::homogeneous(db, specialization)
996-
});
997-
overload.set_return_type(overridden_return);
980+
let tuple_spec = argument.try_iterate(db).expect(
981+
"try_iterate() should not fail on a type \
982+
assignable to `Iterable`",
983+
);
984+
overload.set_return_type(Type::tuple(TupleType::new(
985+
db,
986+
tuple_spec.as_ref(),
987+
)));
998988
}
999989
}
1000990

@@ -2097,14 +2087,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
20972087
// elements. For tuples, we don't have to do anything! For other types, we treat it as
20982088
// an iterator, and create a homogeneous tuple of its output type, since we don't know
20992089
// how many elements the iterator will produce.
2100-
// TODO: update `Type::try_iterate` to return this tuple type for us.
2101-
let argument_types = match argument_type {
2102-
Type::Tuple(tuple) => Cow::Borrowed(tuple.tuple(self.db)),
2103-
_ => {
2104-
let element_type = argument_type.iterate(self.db);
2105-
Cow::Owned(Tuple::homogeneous(element_type))
2106-
}
2107-
};
2090+
let argument_types = argument_type.iterate(self.db);
21082091

21092092
// TODO: When we perform argument expansion during overload resolution, we might need
21102093
// to retry both `match_parameters` _and_ `check_types` for each expansion. Currently

crates/ty_python_semantic/src/types/class.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -745,6 +745,8 @@ impl<'db> ClassType<'db> {
745745

746746
match specialization {
747747
Some(spec) => {
748+
// TODO: Once we support PEP 646 annotations for `*args` parameters, we can
749+
// use the tuple itself as the argument type.
748750
let tuple = spec.tuple(db);
749751
let tuple_len = tuple.len();
750752

@@ -2262,7 +2264,8 @@ impl<'db> ClassLiteral<'db> {
22622264
index.expression(for_stmt.iterable(&module)),
22632265
);
22642266
// TODO: Potential diagnostics resulting from the iterable are currently not reported.
2265-
let inferred_ty = iterable_ty.iterate(db);
2267+
let inferred_ty =
2268+
iterable_ty.iterate(db).homogeneous_element_type(db);
22662269

22672270
union_of_inferred_types = union_of_inferred_types.add(inferred_ty);
22682271
}
@@ -2320,7 +2323,8 @@ impl<'db> ClassLiteral<'db> {
23202323
index.expression(comprehension.iterable(&module)),
23212324
);
23222325
// TODO: Potential diagnostics resulting from the iterable are currently not reported.
2323-
let inferred_ty = iterable_ty.iterate(db);
2326+
let inferred_ty =
2327+
iterable_ty.iterate(db).homogeneous_element_type(db);
23242328

23252329
union_of_inferred_types = union_of_inferred_types.add(inferred_ty);
23262330
}

crates/ty_python_semantic/src/types/generics.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ impl<'db> GenericContext<'db> {
194194
db: &'db dyn Db,
195195
tuple: TupleType<'db>,
196196
) -> Specialization<'db> {
197-
let element_type = UnionType::from_elements(db, tuple.tuple(db).all_elements());
197+
let element_type = tuple.tuple(db).homogeneous_element_type(db);
198198
Specialization::new(db, self, Box::from([element_type]), Some(tuple))
199199
}
200200

0 commit comments

Comments
 (0)