Skip to content

Commit dc63efa

Browse files
committed
[ty] Support async/await
1 parent 2680f2e commit dc63efa

File tree

8 files changed

+183
-32
lines changed

8 files changed

+183
-32
lines changed
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# `async` / `await`
2+
3+
## Basic
4+
5+
```py
6+
async def retrieve() -> int:
7+
return 42
8+
9+
reveal_type(retrieve) # revealed: def retrieve() -> CoroutineType[Any, Any, int]
10+
11+
async def main():
12+
result = await retrieve()
13+
14+
reveal_type(result) # revealed: int
15+
```
16+
17+
## Generic `async` functions
18+
19+
```py
20+
from typing import TypeVar
21+
22+
T = TypeVar("T")
23+
24+
async def persist(x: T) -> T:
25+
return x
26+
27+
reveal_type(persist) # revealed: def persist(x: T) -> CoroutineType[Any, Any, T]
28+
29+
async def f(x: int):
30+
result = await persist(x)
31+
32+
reveal_type(result) # revealed: int
33+
```

crates/ty_python_semantic/resources/mdtest/call/function.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@ reveal_type(get_int()) # revealed: int
1515
async def get_int_async() -> int:
1616
return 42
1717

18-
# TODO: we don't yet support `types.CoroutineType`, should be generic `Coroutine[Any, Any, int]`
19-
reveal_type(get_int_async()) # revealed: @Todo(generic types.CoroutineType)
18+
reveal_type(get_int_async()) # revealed: CoroutineType[Any, Any, int]
2019
```
2120

2221
## Generic

crates/ty_python_semantic/resources/mdtest/with/async.md

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,65 @@ class Manager:
1717

1818
async def test():
1919
async with Manager() as f:
20-
reveal_type(f) # revealed: @Todo(async `with` statement)
20+
reveal_type(f) # revealed: Target
21+
```
22+
23+
## `@asynccontextmanager`
24+
25+
```py
26+
from contextlib import asynccontextmanager
27+
from typing import AsyncGenerator
28+
29+
class Session: ...
30+
31+
@asynccontextmanager
32+
async def connect() -> AsyncGenerator[Session]:
33+
yield Session()
34+
35+
# TODO: this should be `() -> _AsyncGeneratorContextManager[Session, None]`
36+
reveal_type(connect) # revealed: (...) -> _AsyncGeneratorContextManager[Unknown, None]
37+
38+
async def main():
39+
async with connect() as session:
40+
# TODO: should be `Session`
41+
reveal_type(session) # revealed: Unknown
42+
```
43+
44+
## `asyncio.timeout`
45+
46+
```toml
47+
[environment]
48+
python-version = "3.11"
49+
```
50+
51+
```py
52+
import asyncio
53+
54+
async def long_running_task():
55+
await asyncio.sleep(5)
56+
57+
async def main():
58+
async with asyncio.timeout(1):
59+
await long_running_task()
60+
```
61+
62+
## `asyncio.TaskGroup`
63+
64+
```toml
65+
[environment]
66+
python-version = "3.11"
67+
```
68+
69+
```py
70+
import asyncio
71+
72+
async def long_running_task():
73+
await asyncio.sleep(5)
74+
75+
async def main():
76+
async with asyncio.TaskGroup() as tg:
77+
# TODO: should be `TaskGroup`
78+
reveal_type(tg) # revealed: Unknown
79+
80+
tg.create_task(long_running_task())
2181
```

crates/ty_python_semantic/src/types.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4790,6 +4790,29 @@ impl<'db> Type<'db> {
47904790
}
47914791
}
47924792

4793+
fn resolve_await(self, db: &'db dyn Db) -> Type<'db> {
4794+
self.try_call_dunder(db, "__await__", CallArguments::none())
4795+
.map_or(Type::unknown(), |result| {
4796+
let generator_ty = result.return_type(db);
4797+
4798+
if let Type::ProtocolInstance(instance) = generator_ty {
4799+
if let Protocol::FromClass(class) = instance.inner {
4800+
if class.is_known(db, KnownClass::Generator) {
4801+
if let Some(specialization) =
4802+
class.class_literal_specialized(db, None).1
4803+
{
4804+
if let [_, _, return_ty] = specialization.types(db) {
4805+
return *return_ty;
4806+
}
4807+
}
4808+
}
4809+
}
4810+
}
4811+
4812+
Type::unknown()
4813+
})
4814+
}
4815+
47934816
/// Given a class literal or non-dynamic SubclassOf type, try calling it (creating an instance)
47944817
/// and return the resulting instance type.
47954818
///

crates/ty_python_semantic/src/types/class.rs

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2610,10 +2610,13 @@ pub enum KnownClass {
26102610
UnionType,
26112611
GeneratorType,
26122612
AsyncGeneratorType,
2613+
CoroutineType,
26132614
// Typeshed
26142615
NoneType, // Part of `types` for Python >= 3.10
26152616
// Typing
26162617
Any,
2618+
Awaitable,
2619+
Generator,
26172620
Deprecated,
26182621
StdlibAlias,
26192622
SpecialForm,
@@ -2736,6 +2739,9 @@ impl KnownClass {
27362739
| Self::NotImplementedType
27372740
| Self::Staticmethod
27382741
| Self::Classmethod
2742+
| Self::Awaitable
2743+
| Self::Generator
2744+
| Self::CoroutineType
27392745
| Self::Deprecated
27402746
| Self::Field
27412747
| Self::KwOnly
@@ -2771,7 +2777,8 @@ impl KnownClass {
27712777
| Self::Super
27722778
| Self::GenericAlias
27732779
| Self::Deque
2774-
| Self::Bytes => true,
2780+
| Self::Bytes
2781+
| Self::CoroutineType => true,
27752782

27762783
// It doesn't really make sense to ask the question for `@final` types,
27772784
// since these are "more than solid bases". But we'll anyway infer a `@final`
@@ -2807,6 +2814,8 @@ impl KnownClass {
28072814
// with length >2, or anything that is implemented in pure Python, is not a solid base.
28082815
Self::ABCMeta
28092816
| Self::Any
2817+
| Self::Awaitable
2818+
| Self::Generator
28102819
| Self::Enum
28112820
| Self::EnumType
28122821
| Self::Auto
@@ -2855,6 +2864,8 @@ impl KnownClass {
28552864
| KnownClass::ExceptionGroup
28562865
| KnownClass::Staticmethod
28572866
| KnownClass::Classmethod
2867+
| KnownClass::Awaitable
2868+
| KnownClass::Generator
28582869
| KnownClass::Deprecated
28592870
| KnownClass::Super
28602871
| KnownClass::Enum
@@ -2872,6 +2883,7 @@ impl KnownClass {
28722883
| KnownClass::UnionType
28732884
| KnownClass::GeneratorType
28742885
| KnownClass::AsyncGeneratorType
2886+
| KnownClass::CoroutineType
28752887
| KnownClass::NoneType
28762888
| KnownClass::Any
28772889
| KnownClass::StdlibAlias
@@ -2917,7 +2929,11 @@ impl KnownClass {
29172929
/// 2. It's probably more performant.
29182930
const fn is_protocol(self) -> bool {
29192931
match self {
2920-
Self::SupportsIndex | Self::Iterable | Self::Iterator => true,
2932+
Self::SupportsIndex
2933+
| Self::Iterable
2934+
| Self::Iterator
2935+
| Self::Awaitable
2936+
| Self::Generator => true,
29212937

29222938
Self::Any
29232939
| Self::Bool
@@ -2946,6 +2962,7 @@ impl KnownClass {
29462962
| Self::GenericAlias
29472963
| Self::GeneratorType
29482964
| Self::AsyncGeneratorType
2965+
| Self::CoroutineType
29492966
| Self::ModuleType
29502967
| Self::FunctionType
29512968
| Self::MethodType
@@ -3011,6 +3028,8 @@ impl KnownClass {
30113028
Self::ExceptionGroup => "ExceptionGroup",
30123029
Self::Staticmethod => "staticmethod",
30133030
Self::Classmethod => "classmethod",
3031+
Self::Awaitable => "Awaitable",
3032+
Self::Generator => "Generator",
30143033
Self::Deprecated => "deprecated",
30153034
Self::GenericAlias => "GenericAlias",
30163035
Self::ModuleType => "ModuleType",
@@ -3021,6 +3040,7 @@ impl KnownClass {
30213040
Self::WrapperDescriptorType => "WrapperDescriptorType",
30223041
Self::GeneratorType => "GeneratorType",
30233042
Self::AsyncGeneratorType => "AsyncGeneratorType",
3043+
Self::CoroutineType => "CoroutineType",
30243044
Self::NamedTuple => "NamedTuple",
30253045
Self::NoneType => "NoneType",
30263046
Self::SpecialForm => "_SpecialForm",
@@ -3281,11 +3301,14 @@ impl KnownClass {
32813301
| Self::MethodType
32823302
| Self::GeneratorType
32833303
| Self::AsyncGeneratorType
3304+
| Self::CoroutineType
32843305
| Self::MethodWrapperType
32853306
| Self::UnionType
32863307
| Self::WrapperDescriptorType => KnownModule::Types,
32873308
Self::NoneType => KnownModule::Typeshed,
32883309
Self::Any
3310+
| Self::Awaitable
3311+
| Self::Generator
32893312
| Self::SpecialForm
32903313
| Self::TypeVar
32913314
| Self::NamedTuple
@@ -3366,12 +3389,15 @@ impl KnownClass {
33663389
| Self::ExceptionGroup
33673390
| Self::Staticmethod
33683391
| Self::Classmethod
3392+
| Self::Awaitable
3393+
| Self::Generator
33693394
| Self::Deprecated
33703395
| Self::GenericAlias
33713396
| Self::ModuleType
33723397
| Self::FunctionType
33733398
| Self::GeneratorType
33743399
| Self::AsyncGeneratorType
3400+
| Self::CoroutineType
33753401
| Self::MethodType
33763402
| Self::MethodWrapperType
33773403
| Self::WrapperDescriptorType
@@ -3443,6 +3469,7 @@ impl KnownClass {
34433469
| Self::WrapperDescriptorType
34443470
| Self::GeneratorType
34453471
| Self::AsyncGeneratorType
3472+
| Self::CoroutineType
34463473
| Self::SpecialForm
34473474
| Self::ChainMap
34483475
| Self::Counter
@@ -3457,6 +3484,8 @@ impl KnownClass {
34573484
| Self::ExceptionGroup
34583485
| Self::Staticmethod
34593486
| Self::Classmethod
3487+
| Self::Awaitable
3488+
| Self::Generator
34603489
| Self::Deprecated
34613490
| Self::TypeVar
34623491
| Self::ParamSpec
@@ -3513,12 +3542,15 @@ impl KnownClass {
35133542
"ExceptionGroup" => Self::ExceptionGroup,
35143543
"staticmethod" => Self::Staticmethod,
35153544
"classmethod" => Self::Classmethod,
3545+
"Awaitable" => Self::Awaitable,
3546+
"Generator" => Self::Generator,
35163547
"deprecated" => Self::Deprecated,
35173548
"GenericAlias" => Self::GenericAlias,
35183549
"NoneType" => Self::NoneType,
35193550
"ModuleType" => Self::ModuleType,
35203551
"GeneratorType" => Self::GeneratorType,
35213552
"AsyncGeneratorType" => Self::AsyncGeneratorType,
3553+
"CoroutineType" => Self::CoroutineType,
35223554
"FunctionType" => Self::FunctionType,
35233555
"MethodType" => Self::MethodType,
35243556
"UnionType" => Self::UnionType,
@@ -3623,6 +3655,7 @@ impl KnownClass {
36233655
| Self::UnionType
36243656
| Self::GeneratorType
36253657
| Self::AsyncGeneratorType
3658+
| Self::CoroutineType
36263659
| Self::WrapperDescriptorType
36273660
| Self::Field
36283661
| Self::KwOnly
@@ -3642,6 +3675,7 @@ impl KnownClass {
36423675
| Self::Iterable
36433676
| Self::Iterator
36443677
| Self::NewType => matches!(module, KnownModule::Typing | KnownModule::TypingExtensions),
3678+
Self::Awaitable | Self::Generator => matches!(module, KnownModule::Typing | KnownModule::TypingExtensions | KnownModule::Abc),
36453679
Self::Deprecated => matches!(module, KnownModule::Warnings | KnownModule::TypingExtensions),
36463680

36473681
}

crates/ty_python_semantic/src/types/function.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,12 +341,16 @@ impl<'db> OverloadLiteral<'db> {
341341
GenericContext::from_type_params(db, index, type_params)
342342
});
343343

344+
let index = semantic_index(db, scope.file(db));
345+
let is_generator = scope.file_scope_id(db).is_generator_function(index);
346+
344347
Signature::from_function(
345348
db,
346349
generic_context,
347350
inherited_generic_context,
348351
definition,
349352
function_stmt_node,
353+
is_generator,
350354
)
351355
}
352356

crates/ty_python_semantic/src/types/infer.rs

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3169,26 +3169,17 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
31693169
let context_expr = with_item.context_expr(self.module());
31703170
let target = with_item.target(self.module());
31713171

3172-
let target_ty = if with_item.is_async() {
3173-
let _context_expr_ty = self.infer_standalone_expression(context_expr);
3174-
todo_type!("async `with` statement")
3175-
} else {
3176-
match with_item.target_kind() {
3177-
TargetKind::Sequence(unpack_position, unpack) => {
3178-
let unpacked = infer_unpack_types(self.db(), unpack);
3179-
if unpack_position == UnpackPosition::First {
3180-
self.context.extend(unpacked.diagnostics());
3181-
}
3182-
unpacked.expression_type(target)
3183-
}
3184-
TargetKind::Single => {
3185-
let context_expr_ty = self.infer_standalone_expression(context_expr);
3186-
self.infer_context_expression(
3187-
context_expr,
3188-
context_expr_ty,
3189-
with_item.is_async(),
3190-
)
3172+
let target_ty = match with_item.target_kind() {
3173+
TargetKind::Sequence(unpack_position, unpack) => {
3174+
let unpacked = infer_unpack_types(self.db(), unpack);
3175+
if unpack_position == UnpackPosition::First {
3176+
self.context.extend(unpacked.diagnostics());
31913177
}
3178+
unpacked.expression_type(target)
3179+
}
3180+
TargetKind::Single => {
3181+
let context_expr_ty = self.infer_standalone_expression(context_expr);
3182+
self.infer_context_expression(context_expr, context_expr_ty, with_item.is_async())
31923183
}
31933184
};
31943185

@@ -3208,9 +3199,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
32083199
context_expression_type: Type<'db>,
32093200
is_async: bool,
32103201
) -> Type<'db> {
3211-
// TODO: Handle async with statements (they use `aenter` and `aexit`)
32123202
if is_async {
3213-
return todo_type!("async `with` statement");
3203+
// TODO: proper error handling for `__aenter__`/`__aexit__` calls
3204+
return context_expression_type
3205+
.try_call_dunder(self.db(), "__aenter__", CallArguments::none())
3206+
.map_or(Type::unknown(), |bindings| {
3207+
bindings.return_type(self.db()).resolve_await(self.db())
3208+
});
32143209
}
32153210

32163211
context_expression_type
@@ -6112,8 +6107,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
61126107
node_index: _,
61136108
value,
61146109
} = await_expression;
6115-
self.infer_expression(value);
6116-
todo_type!("generic `typing.Awaitable` type")
6110+
self.infer_expression(value).resolve_await(self.db())
61176111
}
61186112

61196113
// Perform narrowing with applicable constraints between the current scope and the enclosing scope.

0 commit comments

Comments
 (0)