Skip to content

Commit 6dda1bf

Browse files
committed
[ty] Support async/await
1 parent 8c0743d commit 6dda1bf

File tree

8 files changed

+137
-13
lines changed

8 files changed

+137
-13
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# `async` / `await`
2+
3+
## Basic
4+
5+
```py
6+
async def retrieve() -> int:
7+
return 42
8+
9+
reveal_type(retrieve) # revealed: def retrieve() -> CoroutineType[Any, Any, int]
10+
11+
async def main():
12+
result = await retrieve()
13+
14+
reveal_type(result) # revealed: int
15+
```

crates/ty_python_semantic/resources/mdtest/call/function.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@ reveal_type(get_int()) # revealed: int
1515
async def get_int_async() -> int:
1616
return 42
1717

18-
# TODO: we don't yet support `types.CoroutineType`, should be generic `Coroutine[Any, Any, int]`
19-
reveal_type(get_int_async()) # revealed: @Todo(generic types.CoroutineType)
18+
reveal_type(get_int_async()) # revealed: CoroutineType[Any, Any, int]
2019
```
2120

2221
## Generic

crates/ty_python_semantic/resources/mdtest/function/return_type.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,8 @@ def f(cond: bool) -> str:
433433

434434
<!-- snapshot-diagnostics -->
435435

436+
## Synchronous
437+
436438
A function with a `yield` or `yield from` expression anywhere in its body is a
437439
[generator function](https://docs.python.org/3/glossary.html#term-generator). A generator function
438440
implicitly returns an instance of `types.GeneratorType` even if it does not contain any `return`
@@ -461,6 +463,8 @@ def j() -> str: # error: [invalid-return-type]
461463
yield 42
462464
```
463465

466+
## Asynchronous
467+
464468
If it is an `async` function with a `yield` statement in its body, it is an
465469
[asynchronous generator function](https://docs.python.org/3/glossary.html#term-asynchronous-generator).
466470
An asynchronous generator function implicitly returns an instance of `types.AsyncGeneratorType` even

crates/ty_python_semantic/resources/mdtest/snapshots/return_type.md_-_Function_return_type_-_Generator_functions_(d9ed06b61b14fd4c).snap

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ mdtest path: crates/ty_python_semantic/resources/mdtest/function/return_type.md
2828
14 | yield 42
2929
15 |
3030
16 | def i2() -> typing.Generator:
31-
17 | yield from i()
31+
17 | yield from reveal_type(i()) # error: [not-iterable]
3232
18 |
3333
19 | def j() -> str: # error: [invalid-return-type]
3434
20 | yield 42
@@ -53,11 +53,54 @@ mdtest path: crates/ty_python_semantic/resources/mdtest/function/return_type.md
5353

5454
# Diagnostics
5555

56+
```
57+
error[not-iterable]: Object of type `Iterable[Unknown] | CoroutineType[Any, Any, AsyncIterable[Unknown]]` may not be iterable
58+
--> src/mdtest_snippet.py:17:16
59+
|
60+
16 | def i2() -> typing.Generator:
61+
17 | yield from reveal_type(i()) # error: [not-iterable]
62+
| ^^^^^^^^^^^^^^^^
63+
18 |
64+
19 | def j() -> str: # error: [invalid-return-type]
65+
|
66+
info: It may not have an `__iter__` method and it doesn't have a `__getitem__` method
67+
info: rule `not-iterable` is enabled by default
68+
69+
```
70+
71+
```
72+
warning[undefined-reveal]: `reveal_type` used without importing it
73+
--> src/mdtest_snippet.py:17:16
74+
|
75+
16 | def i2() -> typing.Generator:
76+
17 | yield from reveal_type(i()) # error: [not-iterable]
77+
| ^^^^^^^^^^^
78+
18 |
79+
19 | def j() -> str: # error: [invalid-return-type]
80+
|
81+
info: This is allowed for debugging convenience but will fail at runtime
82+
info: rule `undefined-reveal` is enabled by default
83+
84+
```
85+
86+
```
87+
info[revealed-type]: Revealed type
88+
--> src/mdtest_snippet.py:17:28
89+
|
90+
16 | def i2() -> typing.Generator:
91+
17 | yield from reveal_type(i()) # error: [not-iterable]
92+
| ^^^ `Iterable[Unknown] | CoroutineType[Any, Any, AsyncIterable[Unknown]]`
93+
18 |
94+
19 | def j() -> str: # error: [invalid-return-type]
95+
|
96+
97+
```
98+
5699
```
57100
error[invalid-return-type]: Return type does not match returned value
58101
--> src/mdtest_snippet.py:19:12
59102
|
60-
17 | yield from i()
103+
17 | yield from reveal_type(i()) # error: [not-iterable]
61104
18 |
62105
19 | def j() -> str: # error: [invalid-return-type]
63106
| ^^^ expected `str`, found `types.GeneratorType`

crates/ty_python_semantic/src/types/class.rs

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2610,10 +2610,13 @@ pub enum KnownClass {
26102610
UnionType,
26112611
GeneratorType,
26122612
AsyncGeneratorType,
2613+
CoroutineType,
26132614
// Typeshed
26142615
NoneType, // Part of `types` for Python >= 3.10
26152616
// Typing
26162617
Any,
2618+
Awaitable,
2619+
Generator,
26172620
Deprecated,
26182621
StdlibAlias,
26192622
SpecialForm,
@@ -2736,6 +2739,9 @@ impl KnownClass {
27362739
| Self::NotImplementedType
27372740
| Self::Staticmethod
27382741
| Self::Classmethod
2742+
| Self::Awaitable
2743+
| Self::Generator
2744+
| Self::CoroutineType
27392745
| Self::Deprecated
27402746
| Self::Field
27412747
| Self::KwOnly
@@ -2771,7 +2777,8 @@ impl KnownClass {
27712777
| Self::Super
27722778
| Self::GenericAlias
27732779
| Self::Deque
2774-
| Self::Bytes => true,
2780+
| Self::Bytes
2781+
| Self::CoroutineType => true,
27752782

27762783
// It doesn't really make sense to ask the question for `@final` types,
27772784
// since these are "more than solid bases". But we'll anyway infer a `@final`
@@ -2807,6 +2814,8 @@ impl KnownClass {
28072814
// with length >2, or anything that is implemented in pure Python, is not a solid base.
28082815
Self::ABCMeta
28092816
| Self::Any
2817+
| Self::Awaitable
2818+
| Self::Generator
28102819
| Self::Enum
28112820
| Self::EnumType
28122821
| Self::Auto
@@ -2855,6 +2864,8 @@ impl KnownClass {
28552864
| KnownClass::ExceptionGroup
28562865
| KnownClass::Staticmethod
28572866
| KnownClass::Classmethod
2867+
| KnownClass::Awaitable
2868+
| KnownClass::Generator
28582869
| KnownClass::Deprecated
28592870
| KnownClass::Super
28602871
| KnownClass::Enum
@@ -2872,6 +2883,7 @@ impl KnownClass {
28722883
| KnownClass::UnionType
28732884
| KnownClass::GeneratorType
28742885
| KnownClass::AsyncGeneratorType
2886+
| KnownClass::CoroutineType
28752887
| KnownClass::NoneType
28762888
| KnownClass::Any
28772889
| KnownClass::StdlibAlias
@@ -2917,7 +2929,11 @@ impl KnownClass {
29172929
/// 2. It's probably more performant.
29182930
const fn is_protocol(self) -> bool {
29192931
match self {
2920-
Self::SupportsIndex | Self::Iterable | Self::Iterator => true,
2932+
Self::SupportsIndex
2933+
| Self::Iterable
2934+
| Self::Iterator
2935+
| Self::Awaitable
2936+
| Self::Generator => true,
29212937

29222938
Self::Any
29232939
| Self::Bool
@@ -2946,6 +2962,7 @@ impl KnownClass {
29462962
| Self::GenericAlias
29472963
| Self::GeneratorType
29482964
| Self::AsyncGeneratorType
2965+
| Self::CoroutineType
29492966
| Self::ModuleType
29502967
| Self::FunctionType
29512968
| Self::MethodType
@@ -3011,6 +3028,8 @@ impl KnownClass {
30113028
Self::ExceptionGroup => "ExceptionGroup",
30123029
Self::Staticmethod => "staticmethod",
30133030
Self::Classmethod => "classmethod",
3031+
Self::Awaitable => "Awaitable",
3032+
Self::Generator => "Generator",
30143033
Self::Deprecated => "deprecated",
30153034
Self::GenericAlias => "GenericAlias",
30163035
Self::ModuleType => "ModuleType",
@@ -3021,6 +3040,7 @@ impl KnownClass {
30213040
Self::WrapperDescriptorType => "WrapperDescriptorType",
30223041
Self::GeneratorType => "GeneratorType",
30233042
Self::AsyncGeneratorType => "AsyncGeneratorType",
3043+
Self::CoroutineType => "CoroutineType",
30243044
Self::NamedTuple => "NamedTuple",
30253045
Self::NoneType => "NoneType",
30263046
Self::SpecialForm => "_SpecialForm",
@@ -3281,11 +3301,14 @@ impl KnownClass {
32813301
| Self::MethodType
32823302
| Self::GeneratorType
32833303
| Self::AsyncGeneratorType
3304+
| Self::CoroutineType
32843305
| Self::MethodWrapperType
32853306
| Self::UnionType
32863307
| Self::WrapperDescriptorType => KnownModule::Types,
32873308
Self::NoneType => KnownModule::Typeshed,
32883309
Self::Any
3310+
| Self::Awaitable
3311+
| Self::Generator
32893312
| Self::SpecialForm
32903313
| Self::TypeVar
32913314
| Self::NamedTuple
@@ -3366,12 +3389,15 @@ impl KnownClass {
33663389
| Self::ExceptionGroup
33673390
| Self::Staticmethod
33683391
| Self::Classmethod
3392+
| Self::Awaitable
3393+
| Self::Generator
33693394
| Self::Deprecated
33703395
| Self::GenericAlias
33713396
| Self::ModuleType
33723397
| Self::FunctionType
33733398
| Self::GeneratorType
33743399
| Self::AsyncGeneratorType
3400+
| Self::CoroutineType
33753401
| Self::MethodType
33763402
| Self::MethodWrapperType
33773403
| Self::WrapperDescriptorType
@@ -3443,6 +3469,7 @@ impl KnownClass {
34433469
| Self::WrapperDescriptorType
34443470
| Self::GeneratorType
34453471
| Self::AsyncGeneratorType
3472+
| Self::CoroutineType
34463473
| Self::SpecialForm
34473474
| Self::ChainMap
34483475
| Self::Counter
@@ -3457,6 +3484,8 @@ impl KnownClass {
34573484
| Self::ExceptionGroup
34583485
| Self::Staticmethod
34593486
| Self::Classmethod
3487+
| Self::Awaitable
3488+
| Self::Generator
34603489
| Self::Deprecated
34613490
| Self::TypeVar
34623491
| Self::ParamSpec
@@ -3513,12 +3542,15 @@ impl KnownClass {
35133542
"ExceptionGroup" => Self::ExceptionGroup,
35143543
"staticmethod" => Self::Staticmethod,
35153544
"classmethod" => Self::Classmethod,
3545+
"Awaitable" => Self::Awaitable,
3546+
"Generator" => Self::Generator,
35163547
"deprecated" => Self::Deprecated,
35173548
"GenericAlias" => Self::GenericAlias,
35183549
"NoneType" => Self::NoneType,
35193550
"ModuleType" => Self::ModuleType,
35203551
"GeneratorType" => Self::GeneratorType,
35213552
"AsyncGeneratorType" => Self::AsyncGeneratorType,
3553+
"CoroutineType" => Self::CoroutineType,
35223554
"FunctionType" => Self::FunctionType,
35233555
"MethodType" => Self::MethodType,
35243556
"UnionType" => Self::UnionType,
@@ -3623,6 +3655,7 @@ impl KnownClass {
36233655
| Self::UnionType
36243656
| Self::GeneratorType
36253657
| Self::AsyncGeneratorType
3658+
| Self::CoroutineType
36263659
| Self::WrapperDescriptorType
36273660
| Self::Field
36283661
| Self::KwOnly
@@ -3642,6 +3675,7 @@ impl KnownClass {
36423675
| Self::Iterable
36433676
| Self::Iterator
36443677
| Self::NewType => matches!(module, KnownModule::Typing | KnownModule::TypingExtensions),
3678+
Self::Awaitable | Self::Generator => matches!(module, KnownModule::Typing | KnownModule::TypingExtensions | KnownModule::Abc),
36453679
Self::Deprecated => matches!(module, KnownModule::Warnings | KnownModule::TypingExtensions),
36463680

36473681
}

crates/ty_python_semantic/src/types/function.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,12 +341,16 @@ impl<'db> OverloadLiteral<'db> {
341341
GenericContext::from_type_params(db, index, type_params)
342342
});
343343

344+
let index = semantic_index(db, scope.file(db));
345+
let is_generator = scope.file_scope_id(db).is_generator_function(index);
346+
344347
Signature::from_function(
345348
db,
346349
generic_context,
347350
inherited_generic_context,
348351
definition,
349352
function_stmt_node,
353+
is_generator,
350354
)
351355
}
352356

crates/ty_python_semantic/src/types/infer.rs

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ use crate::types::{
119119
CallDunderError, CallableType, ClassLiteral, ClassType, DataclassParams, DynamicType,
120120
IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, LintDiagnosticGuard,
121121
MemberLookupPolicy, MetaclassCandidate, PEP695TypeAliasType, Parameter, ParameterForm,
122-
Parameters, SpecialFormType, SubclassOfType, Truthiness, Type, TypeAliasType,
122+
Parameters, Protocol, SpecialFormType, SubclassOfType, Truthiness, Type, TypeAliasType,
123123
TypeAndQualifiers, TypeIsType, TypeQualifiers, TypeVarBoundOrConstraints, TypeVarInstance,
124124
TypeVarKind, TypeVarVariance, UnionBuilder, UnionType, binding_type, todo_type,
125125
};
@@ -6112,8 +6112,29 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
61126112
node_index: _,
61136113
value,
61146114
} = await_expression;
6115-
self.infer_expression(value);
6116-
todo_type!("generic `typing.Awaitable` type")
6115+
let value_ty = self.infer_expression(value);
6116+
6117+
value_ty
6118+
.try_call_dunder(self.db(), "__await__", CallArguments::none())
6119+
.map_or(Type::unknown(), |result| {
6120+
let generator_ty = result.return_type(self.db());
6121+
6122+
if let Type::ProtocolInstance(instance) = generator_ty {
6123+
if let Protocol::FromClass(class) = instance.inner {
6124+
if class.is_known(self.db(), KnownClass::Generator) {
6125+
if let Some(specialization) =
6126+
class.class_literal_specialized(self.db(), None).1
6127+
{
6128+
if let [_, _, return_ty] = specialization.types(self.db()) {
6129+
return *return_ty;
6130+
}
6131+
}
6132+
}
6133+
}
6134+
}
6135+
6136+
Type::unknown()
6137+
})
61176138
}
61186139

61196140
// Perform narrowing with applicable constraints between the current scope and the enclosing scope.

crates/ty_python_semantic/src/types/signatures.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use smallvec::{SmallVec, smallvec_inline};
1818
use super::{DynamicType, Type, TypeTransformer, TypeVarVariance, definition_expression_type};
1919
use crate::semantic_index::definition::Definition;
2020
use crate::types::generics::{GenericContext, walk_generic_context};
21-
use crate::types::{TypeMapping, TypeRelation, TypeVarInstance, todo_type};
21+
use crate::types::{KnownClass, TypeMapping, TypeRelation, TypeVarInstance, todo_type};
2222
use crate::{Db, FxOrderSet};
2323
use ruff_python_ast::{self as ast, name::Name};
2424

@@ -320,14 +320,18 @@ impl<'db> Signature<'db> {
320320
inherited_generic_context: Option<GenericContext<'db>>,
321321
definition: Definition<'db>,
322322
function_node: &ast::StmtFunctionDef,
323+
is_async_generator: bool,
323324
) -> Self {
324325
let parameters =
325326
Parameters::from_parameters(db, definition, function_node.parameters.as_ref());
326327
let return_ty = function_node.returns.as_ref().map(|returns| {
327-
if function_node.is_async {
328-
todo_type!("generic types.CoroutineType")
328+
let plain_return_ty = definition_expression_type(db, definition, returns.as_ref());
329+
330+
if function_node.is_async && !is_async_generator {
331+
KnownClass::CoroutineType
332+
.to_specialized_instance(db, [Type::any(), Type::any(), plain_return_ty])
329333
} else {
330-
definition_expression_type(db, definition, returns.as_ref())
334+
plain_return_ty
331335
}
332336
});
333337
let legacy_generic_context =

0 commit comments

Comments
 (0)