Skip to content

Commit 4ecf1d2

Browse files
authored
[ty] Support async/await, async with and yield from (#19595)
## Summary - Add support for the return types of `async` functions - Add type inference for `await` expressions - Add support for `async with` / async context managers - Add support for `yield from` expressions This PR is generally lacking proper error handling in some cases (e.g. illegal `__await__` attributes). I'm planning to work on this in a follow-up. part of astral-sh/ty#151 closes astral-sh/ty#736 ## Ecosystem There are a lot of true positives on `prefect` which look similar to: ```diff prefect (https://github.com/PrefectHQ/prefect) + src/integrations/prefect-aws/tests/workers/test_ecs_worker.py:406:12: error[unresolved-attribute] Type `str` has no attribute `status_code` ``` This is due to a wrong return type annotation [here](https://github.com/PrefectHQ/prefect/blob/e926b8c4c114e74533e7bd75d83e7f205c774645/src/integrations/prefect-aws/tests/workers/test_ecs_worker.py#L355-L391). ```diff mitmproxy (https://github.com/mitmproxy/mitmproxy) + test/mitmproxy/addons/test_clientplayback.py:18:1: error[invalid-argument-type] Argument to function `asynccontextmanager` is incorrect: Expected `(...) -> AsyncIterator[Unknown]`, found `def tcp_server(handle_conn, **server_args) -> Unknown | tuple[str, int]` ``` [This](https://github.com/mitmproxy/mitmproxy/blob/a4d794c59a27472d193a592d8037505a1cf6ae93/test/mitmproxy/addons/test_clientplayback.py#L18-L19) is a true positive. That function should return `AsyncIterator[Address]`, not `Address`. I looked through almost all of the other new diagnostics and they all look like known problems or true positives. ## Typing conformance The typing conformance diff looks good. ## Test Plan New Markdown tests
1 parent c5ac998 commit 4ecf1d2

File tree

12 files changed

+472
-46
lines changed

12 files changed

+472
-46
lines changed
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# `async` / `await`
2+
3+
## Basic
4+
5+
```py
6+
async def retrieve() -> int:
7+
return 42
8+
9+
async def main():
10+
result = await retrieve()
11+
12+
reveal_type(result) # revealed: int
13+
```
14+
15+
## Generic `async` functions
16+
17+
```py
18+
from typing import TypeVar
19+
20+
T = TypeVar("T")
21+
22+
async def persist(x: T) -> T:
23+
return x
24+
25+
async def f(x: int):
26+
result = await persist(x)
27+
28+
reveal_type(result) # revealed: int
29+
```
30+
31+
## Use cases
32+
33+
### `Future`
34+
35+
```py
36+
import asyncio
37+
import concurrent.futures
38+
39+
def blocking_function() -> int:
40+
return 42
41+
42+
async def main():
43+
loop = asyncio.get_event_loop()
44+
with concurrent.futures.ThreadPoolExecutor() as pool:
45+
result = await loop.run_in_executor(pool, blocking_function)
46+
47+
# TODO: should be `int`
48+
reveal_type(result) # revealed: Unknown
49+
```
50+
51+
### `asyncio.Task`
52+
53+
```py
54+
import asyncio
55+
56+
async def f() -> int:
57+
return 1
58+
59+
async def main():
60+
task = asyncio.create_task(f())
61+
62+
result = await task
63+
64+
# TODO: this should be `int`
65+
reveal_type(result) # revealed: Unknown
66+
```
67+
68+
### `asyncio.gather`
69+
70+
```py
71+
import asyncio
72+
73+
async def task(name: str) -> int:
74+
return len(name)
75+
76+
async def main():
77+
(a, b) = await asyncio.gather(
78+
task("A"),
79+
task("B"),
80+
)
81+
82+
# TODO: these should be `int`
83+
reveal_type(a) # revealed: Unknown
84+
reveal_type(b) # revealed: Unknown
85+
```
86+
87+
## Under the hood
88+
89+
```toml
90+
[environment]
91+
python-version = "3.12" # Use 3.12 to be able to use PEP 695 generics
92+
```
93+
94+
Let's look at the example from the beginning again:
95+
96+
```py
97+
async def retrieve() -> int:
98+
return 42
99+
```
100+
101+
When we look at the signature of this function, we see that it actually returns a `CoroutineType`:
102+
103+
```py
104+
reveal_type(retrieve) # revealed: def retrieve() -> CoroutineType[Any, Any, int]
105+
```
106+
107+
The expression `await retrieve()` desugars into a call to the `__await__` dunder method on the
108+
`CoroutineType` object, followed by a `yield from`. Let's first see the return type of `__await__`:
109+
110+
```py
111+
reveal_type(retrieve().__await__()) # revealed: Generator[Any, None, int]
112+
```
113+
114+
We can see that this returns a `Generator` that yields `Any`, and eventually returns `int`. For the
115+
final type of the `await` expression, we retrieve that third argument of the `Generator` type:
116+
117+
```py
118+
from typing import Generator
119+
120+
def _():
121+
result = yield from retrieve().__await__()
122+
reveal_type(result) # revealed: int
123+
```

crates/ty_python_semantic/resources/mdtest/call/function.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@ reveal_type(get_int()) # revealed: int
1515
async def get_int_async() -> int:
1616
return 42
1717

18-
# TODO: we don't yet support `types.CoroutineType`, should be generic `Coroutine[Any, Any, int]`
19-
reveal_type(get_int_async()) # revealed: @Todo(generic types.CoroutineType)
18+
reveal_type(get_int_async()) # revealed: CoroutineType[Any, Any, int]
2019
```
2120

2221
## Generic
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
# `yield` and `yield from`
2+
3+
## Basic `yield` and `yield from`
4+
5+
The type of a `yield` expression is the "send" type of the generator function. The type of a
6+
`yield from` expression is the return type of the inner generator:
7+
8+
```py
9+
from typing import Generator
10+
11+
def inner_generator() -> Generator[int, bytes, str]:
12+
yield 1
13+
yield 2
14+
x = yield 3
15+
16+
# TODO: this should be `bytes`
17+
reveal_type(x) # revealed: @Todo(yield expressions)
18+
19+
return "done"
20+
21+
def outer_generator():
22+
result = yield from inner_generator()
23+
reveal_type(result) # revealed: str
24+
```
25+
26+
## `yield from` with a custom iterable
27+
28+
`yield from` can also be used with custom iterable types. In that case, the type of the `yield from`
29+
expression can not be determined
30+
31+
```py
32+
from typing import Generator, TypeVar, Generic
33+
34+
T = TypeVar("T")
35+
36+
class OnceIterator(Generic[T]):
37+
def __init__(self, value: T):
38+
self.value = value
39+
self.returned = False
40+
41+
def __next__(self) -> T:
42+
if self.returned:
43+
raise StopIteration(42)
44+
45+
self.returned = True
46+
return self.value
47+
48+
class Once(Generic[T]):
49+
def __init__(self, value: T):
50+
self.value = value
51+
52+
def __iter__(self) -> OnceIterator[T]:
53+
return OnceIterator(self.value)
54+
55+
for x in Once("a"):
56+
reveal_type(x) # revealed: str
57+
58+
def generator() -> Generator:
59+
result = yield from Once("a")
60+
61+
# At runtime, the value of `result` will be the `.value` attribute of the `StopIteration`
62+
# error raised by `OnceIterator` to signal to the interpreter that the iterator has been
63+
# exhausted. Here that will always be 42, but this information cannot be captured in the
64+
# signature of `OnceIterator.__next__`, since exceptions lie outside the type signature.
65+
# We therefore just infer `Unknown` here.
66+
#
67+
# If the `StopIteration` error in `OnceIterator.__next__` had been simply `raise StopIteration`
68+
# (the more common case), then the `.value` attribute of the `StopIteration` instance
69+
# would default to `None`.
70+
reveal_type(result) # revealed: Unknown
71+
```
72+
73+
## `yield from` with a generator that return `types.GeneratorType`
74+
75+
`types.GeneratorType` is a nominal type that implements the `typing.Generator` protocol:
76+
77+
```py
78+
from types import GeneratorType
79+
80+
def inner_generator() -> GeneratorType[int, bytes, str]:
81+
yield 1
82+
yield 2
83+
x = yield 3
84+
85+
# TODO: this should be `bytes`
86+
reveal_type(x) # revealed: @Todo(yield expressions)
87+
88+
return "done"
89+
90+
def outer_generator():
91+
result = yield from inner_generator()
92+
reveal_type(result) # revealed: str
93+
```
94+
95+
## Error cases
96+
97+
### Non-iterable type
98+
99+
```py
100+
from typing import Generator
101+
102+
def generator() -> Generator:
103+
yield from 42 # error: [not-iterable] "Object of type `Literal[42]` is not iterable"
104+
```
105+
106+
### Invalid `yield` type
107+
108+
```py
109+
from typing import Generator
110+
111+
# TODO: This should be an error. Claims to yield `int`, but yields `str`.
112+
def invalid_generator() -> Generator[int, None, None]:
113+
yield "not an int" # This should be an `int`
114+
```
115+
116+
### Invalid return type
117+
118+
```py
119+
from typing import Generator
120+
121+
# TODO: should emit an error (does not return `str`)
122+
def invalid_generator1() -> Generator[int, None, str]:
123+
yield 1
124+
125+
# TODO: should emit an error (does not return `int`)
126+
def invalid_generator2() -> Generator[int, None, None]:
127+
yield 1
128+
129+
return "done"
130+
```

crates/ty_python_semantic/resources/mdtest/with/async.md

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,80 @@ class Manager:
1717

1818
async def test():
1919
async with Manager() as f:
20-
reveal_type(f) # revealed: @Todo(async `with` statement)
20+
reveal_type(f) # revealed: Target
21+
```
22+
23+
## Multiple targets
24+
25+
```py
26+
class Manager:
27+
async def __aenter__(self) -> tuple[int, str]:
28+
return 42, "hello"
29+
30+
async def __aexit__(self, exc_type, exc_value, traceback): ...
31+
32+
async def test():
33+
async with Manager() as (x, y):
34+
reveal_type(x) # revealed: int
35+
reveal_type(y) # revealed: str
36+
```
37+
38+
## `@asynccontextmanager`
39+
40+
```py
41+
from contextlib import asynccontextmanager
42+
from typing import AsyncGenerator
43+
44+
class Session: ...
45+
46+
@asynccontextmanager
47+
async def connect() -> AsyncGenerator[Session]:
48+
yield Session()
49+
50+
# TODO: this should be `() -> _AsyncGeneratorContextManager[Session, None]`
51+
reveal_type(connect) # revealed: (...) -> _AsyncGeneratorContextManager[Unknown, None]
52+
53+
async def main():
54+
async with connect() as session:
55+
# TODO: should be `Session`
56+
reveal_type(session) # revealed: Unknown
57+
```
58+
59+
## `asyncio.timeout`
60+
61+
```toml
62+
[environment]
63+
python-version = "3.11"
64+
```
65+
66+
```py
67+
import asyncio
68+
69+
async def long_running_task():
70+
await asyncio.sleep(5)
71+
72+
async def main():
73+
async with asyncio.timeout(1):
74+
await long_running_task()
75+
```
76+
77+
## `asyncio.TaskGroup`
78+
79+
```toml
80+
[environment]
81+
python-version = "3.11"
82+
```
83+
84+
```py
85+
import asyncio
86+
87+
async def long_running_task():
88+
await asyncio.sleep(5)
89+
90+
async def main():
91+
async with asyncio.TaskGroup() as tg:
92+
# TODO: should be `TaskGroup`
93+
reveal_type(tg) # revealed: Unknown
94+
95+
tg.create_task(long_running_task())
2196
```

crates/ty_python_semantic/src/semantic_index/builder.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2805,7 +2805,9 @@ impl<'ast> Unpackable<'ast> {
28052805
match self {
28062806
Unpackable::Assign(_) => UnpackKind::Assign,
28072807
Unpackable::For(_) | Unpackable::Comprehension { .. } => UnpackKind::Iterable,
2808-
Unpackable::WithItem { .. } => UnpackKind::ContextManager,
2808+
Unpackable::WithItem { is_async, .. } => UnpackKind::ContextManager {
2809+
is_async: *is_async,
2810+
},
28092811
}
28102812
}
28112813

0 commit comments

Comments
 (0)