Skip to content

Commit 7ead318

Browse files
committed
Add support for yield from
1 parent e7bccbb commit 7ead318

File tree

7 files changed

+237
-31
lines changed

7 files changed

+237
-31
lines changed

crates/ty_python_semantic/resources/mdtest/async.md

Lines changed: 94 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
async def retrieve() -> int:
77
return 42
88

9-
reveal_type(retrieve) # revealed: def retrieve() -> CoroutineType[Any, Any, int]
10-
119
async def main():
1210
result = await retrieve()
1311

@@ -24,10 +22,102 @@ T = TypeVar("T")
2422
async def persist(x: T) -> T:
2523
return x
2624

27-
reveal_type(persist) # revealed: def persist(x: T) -> CoroutineType[Any, Any, T]
28-
2925
async def f(x: int):
3026
result = await persist(x)
3127

3228
reveal_type(result) # revealed: int
3329
```
30+
31+
## Use cases
32+
33+
### `Future`
34+
35+
```py
36+
import asyncio
37+
import concurrent.futures
38+
39+
def blocking_function() -> int:
40+
return 42
41+
42+
async def main():
43+
loop = asyncio.get_event_loop()
44+
with concurrent.futures.ThreadPoolExecutor() as pool:
45+
result = await loop.run_in_executor(pool, blocking_function)
46+
47+
# TODO: should be `int`
48+
reveal_type(result) # revealed: Unknown
49+
```
50+
51+
### `asyncio.Task`
52+
53+
```py
54+
import asyncio
55+
56+
async def f() -> int:
57+
return 1
58+
59+
async def main():
60+
task = asyncio.create_task(f())
61+
62+
result = await task
63+
64+
# TODO: this should be `int`
65+
reveal_type(result) # revealed: Unknown
66+
```
67+
68+
### `asyncio.gather`
69+
70+
```py
71+
import asyncio
72+
73+
async def task(name: str) -> int:
74+
return len(name)
75+
76+
async def main():
77+
(a, b) = await asyncio.gather(
78+
task("A"),
79+
task("B"),
80+
)
81+
82+
# TODO: these should be `int`
83+
reveal_type(a) # revealed: Unknown
84+
reveal_type(b) # revealed: Unknown
85+
```
86+
87+
## Under the hood
88+
89+
```toml
90+
[environment]
91+
python-version = "3.12" # Use 3.12 to be able to use PEP 695 generics
92+
```
93+
94+
Let's look at the example from the beginning again:
95+
96+
```py
97+
async def retrieve() -> int:
98+
return 42
99+
```
100+
101+
When we look at the signature of this function, we see that it actually returns a `CoroutineType`:
102+
103+
```py
104+
reveal_type(retrieve) # revealed: def retrieve() -> CoroutineType[Any, Any, int]
105+
```
106+
107+
The expression `await retrieve()` desugars into a call to the `__await__` dunder method on the
108+
`CoroutineType` object, followed by a `yield from`. Let's first see the return type of `__await__`:
109+
110+
```py
111+
reveal_type(retrieve().__await__()) # revealed: Generator[Any, None, int]
112+
```
113+
114+
We can see that this returns a `Generator` that yields `Any`, and eventually returns `int`. For the
115+
final type of the `await` expression, we retrieve that third argument of the `Generator` type:
116+
117+
```py
118+
from typing import Generator
119+
120+
def _():
121+
result = yield from retrieve().__await__()
122+
reveal_type(result) # revealed: int
123+
```
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# `yield` and `yield from`
2+
3+
## Basic `yield` and `yield from`
4+
5+
The type of a `yield` expression is the "send" type of the generator function. The type of a
6+
`yield from` expression is the return type of the inner generator:
7+
8+
```py
9+
from typing import Generator
10+
11+
def inner_generator() -> Generator[int, bytes, str]:
12+
yield 1
13+
yield 2
14+
x = yield 3
15+
16+
# TODO: this should be `bytes`
17+
reveal_type(x) # revealed: @Todo(yield expressions)
18+
19+
return "done"
20+
21+
def outer_generator():
22+
result = yield from inner_generator()
23+
reveal_type(result) # revealed: str
24+
```
25+
26+
## `yield from` with a custom iterable
27+
28+
`yield from` can also be used with custom iterable types. In that case, the type of the `yield from`
29+
expression can not be determined
30+
31+
```py
32+
from typing import Generator, TypeVar, Generic
33+
34+
T = TypeVar("T")
35+
36+
class OnceIterator(Generic[T]):
37+
def __init__(self, value: T):
38+
self.value = value
39+
self.returned = False
40+
41+
def __next__(self) -> T:
42+
if self.returned:
43+
raise StopIteration
44+
45+
self.returned = True
46+
return self.value
47+
48+
class Once(Generic[T]):
49+
def __init__(self, value: T):
50+
self.value = value
51+
52+
def __iter__(self) -> OnceIterator[T]:
53+
return OnceIterator(self.value)
54+
55+
for x in Once("a"):
56+
reveal_type(x) # revealed: str
57+
58+
def generator() -> Generator:
59+
result = yield from Once("a")
60+
61+
# The `StopIteration` exception might have a `value` attribute which the default of `None`,
62+
# or it could have been customized. So we just return `Unknown` here:
63+
reveal_type(result) # revealed: Unknown
64+
```
65+
66+
## Error cases
67+
68+
### Non-iterable type
69+
70+
```py
71+
from typing import Generator
72+
73+
def generator() -> Generator:
74+
yield from 42 # error: [not-iterable] "Object of type `Literal[42]` is not iterable"
75+
```
76+
77+
### Invalid `yield` type
78+
79+
```py
80+
from typing import Generator
81+
82+
# TODO: This should be an error. Claims to yield `int`, but yields `str`.
83+
def invalid_generator() -> Generator[int, None, None]:
84+
yield "not an int" # This should be an `int`
85+
```
86+
87+
### Invalid return type
88+
89+
```py
90+
from typing import Generator
91+
92+
# TODO: should emit an error (does not return `str`)
93+
def invalid_generator1() -> Generator[int, None, str]:
94+
yield 1
95+
96+
# TODO: should emit an error (does not return `int`)
97+
def invalid_generator2() -> Generator[int, None, None]:
98+
yield 1
99+
100+
return "done"
101+
```

crates/ty_python_semantic/resources/mdtest/with/async.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class Manager:
3232
async def test():
3333
async with Manager() as (x, y):
3434
reveal_type(x) # revealed: int
35+
reveal_type(y) # revealed: str
3536
```
3637

3738
## `@asynccontextmanager`

crates/ty_python_semantic/src/types.rs

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -4790,35 +4790,48 @@ impl<'db> Type<'db> {
47904790
}
47914791
}
47924792

4793+
/// Similar to [`Self::try_enter`], but for async context managers.
4794+
fn aenter(self, db: &'db dyn Db) -> Type<'db> {
4795+
// TODO: Add proper error handling and rename this method to `try_aenter`.
4796+
self.try_call_dunder(db, "__aenter__", CallArguments::none())
4797+
.map_or(Type::unknown(), |result| {
4798+
result.return_type(db).resolve_await(db)
4799+
})
4800+
}
4801+
4802+
/// Resolve the type of an `await …` expression where `self` is the type of the awaitable.
47934803
fn resolve_await(self, db: &'db dyn Db) -> Type<'db> {
4804+
// TODO: Add proper error handling and rename this method to `try_await`.
47944805
self.try_call_dunder(db, "__await__", CallArguments::none())
47954806
.map_or(Type::unknown(), |result| {
4796-
let generator_ty = result.return_type(db);
4797-
4798-
if let Type::ProtocolInstance(instance) = generator_ty {
4799-
if let Protocol::FromClass(class) = instance.inner {
4800-
if class.is_known(db, KnownClass::Generator) {
4801-
if let Some(specialization) =
4802-
class.class_literal_specialized(db, None).1
4803-
{
4804-
if let [_, _, return_ty] = specialization.types(db) {
4805-
return *return_ty;
4806-
}
4807-
}
4807+
result
4808+
.return_type(db)
4809+
.generator_return_type(db)
4810+
.unwrap_or_else(Type::unknown)
4811+
})
4812+
}
4813+
4814+
/// Get the return type of a `yield from …` expression where `self` is the type of the generator.
4815+
///
4816+
/// This corresponds to the `ReturnT` parameter of the generic `typing.Generator[YieldT, SendT, ReturnT]`
4817+
/// protocol.
4818+
fn generator_return_type(self, db: &'db dyn Db) -> Option<Type<'db>> {
4819+
// TODO: Ideally, we would first try to upcast `self` to an instance of `Generator` and *then*
4820+
// match on the protocol instance to get the `ReturnType` type parameter.
4821+
4822+
if let Type::ProtocolInstance(instance) = self {
4823+
if let Protocol::FromClass(class) = instance.inner {
4824+
if class.is_known(db, KnownClass::Generator) {
4825+
if let Some(specialization) = class.class_literal_specialized(db, None).1 {
4826+
if let [_, _, return_ty] = specialization.types(db) {
4827+
return Some(*return_ty);
48084828
}
48094829
}
48104830
}
4831+
}
4832+
}
48114833

4812-
Type::unknown()
4813-
})
4814-
}
4815-
4816-
fn aenter(self, db: &'db dyn Db) -> Type<'db> {
4817-
// TODO: Rename this method to `try_aenter` and add error handling
4818-
self.try_call_dunder(db, "__aenter__", CallArguments::none())
4819-
.map_or(Type::unknown(), |result| {
4820-
result.return_type(db).resolve_await(db)
4821-
})
4834+
None
48224835
}
48234836

48244837
/// Given a class literal or non-dynamic SubclassOf type, try calling it (creating an instance)

crates/ty_python_semantic/src/types/infer.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6092,8 +6092,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
60926092
err.fallback_element_type(self.db())
60936093
});
60946094

6095-
// TODO get type from `ReturnType` of generator
6096-
todo_type!("Generic `typing.Generator` type")
6095+
iterable_type
6096+
.generator_return_type(self.db())
6097+
.unwrap_or_else(Type::unknown)
60976098
}
60986099

60996100
fn infer_await_expression(&mut self, await_expression: &ast::ExprAwait) -> Type<'db> {

crates/ty_python_semantic/src/types/signatures.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -320,14 +320,14 @@ impl<'db> Signature<'db> {
320320
inherited_generic_context: Option<GenericContext<'db>>,
321321
definition: Definition<'db>,
322322
function_node: &ast::StmtFunctionDef,
323-
is_async_generator: bool,
323+
is_generator: bool,
324324
) -> Self {
325325
let parameters =
326326
Parameters::from_parameters(db, definition, function_node.parameters.as_ref());
327327
let return_ty = function_node.returns.as_ref().map(|returns| {
328328
let plain_return_ty = definition_expression_type(db, definition, returns.as_ref());
329329

330-
if function_node.is_async && !is_async_generator {
330+
if function_node.is_async && !is_generator {
331331
KnownClass::CoroutineType
332332
.to_specialized_instance(db, [Type::any(), Type::any(), plain_return_ty])
333333
} else {

crates/ty_python_semantic/src/types/unpacker.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use crate::Db;
99
use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey;
1010
use crate::semantic_index::scope::ScopeId;
1111
use crate::types::tuple::{ResizeTupleError, Tuple, TupleLength, TupleSpec, TupleUnpacker};
12-
use crate::types::{CallArguments, Type, TypeCheckDiagnostics, infer_expression_types};
12+
use crate::types::{Type, TypeCheckDiagnostics, infer_expression_types};
1313
use crate::unpack::{UnpackKind, UnpackValue};
1414

1515
use super::context::InferContext;

0 commit comments

Comments
 (0)