Skip to content

Commit 7e83283

Browse files
committed
[ty] Support async/await
1 parent 8c0743d commit 7e83283

File tree

7 files changed

+99
-12
lines changed

7 files changed

+99
-12
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# `async` / `await`
2+
3+
## Basic
4+
5+
```py
6+
async def retrieve() -> int:
7+
return 42
8+
9+
reveal_type(retrieve) # revealed: def retrieve() -> CoroutineType[Any, Any, int]
10+
11+
async def main():
12+
result = await retrieve()
13+
14+
reveal_type(result) # revealed: int
15+
```

crates/ty_python_semantic/resources/mdtest/call/function.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@ reveal_type(get_int()) # revealed: int
1515
async def get_int_async() -> int:
1616
return 42
1717

18-
# TODO: we don't yet support `types.CoroutineType`, should be generic `Coroutine[Any, Any, int]`
19-
reveal_type(get_int_async()) # revealed: @Todo(generic types.CoroutineType)
18+
reveal_type(get_int_async()) # revealed: CoroutineType[Any, Any, int]
2019
```
2120

2221
## Generic

crates/ty_python_semantic/resources/mdtest/function/return_type.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ def i() -> typing.Iterable:
455455
yield 42
456456

457457
def i2() -> typing.Generator:
458-
yield from i()
458+
yield from i() # error: [not-iterable]
459459

460460
def j() -> str: # error: [invalid-return-type]
461461
yield 42

crates/ty_python_semantic/resources/mdtest/snapshots/return_type.md_-_Function_return_type_-_Generator_functions_(d9ed06b61b14fd4c).snap

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ mdtest path: crates/ty_python_semantic/resources/mdtest/function/return_type.md
2828
14 | yield 42
2929
15 |
3030
16 | def i2() -> typing.Generator:
31-
17 | yield from i()
31+
17 | yield from i() # error: [not-iterable]
3232
18 |
3333
19 | def j() -> str: # error: [invalid-return-type]
3434
20 | yield 42
@@ -53,11 +53,26 @@ mdtest path: crates/ty_python_semantic/resources/mdtest/function/return_type.md
5353

5454
# Diagnostics
5555

56+
```
57+
error[not-iterable]: Object of type `Iterable[Unknown] | CoroutineType[Any, Any, AsyncIterable[Unknown]]` may not be iterable
58+
--> src/mdtest_snippet.py:17:16
59+
|
60+
16 | def i2() -> typing.Generator:
61+
17 | yield from i() # error: [not-iterable]
62+
| ^^^
63+
18 |
64+
19 | def j() -> str: # error: [invalid-return-type]
65+
|
66+
info: It may not have an `__iter__` method and it doesn't have a `__getitem__` method
67+
info: rule `not-iterable` is enabled by default
68+
69+
```
70+
5671
```
5772
error[invalid-return-type]: Return type does not match returned value
5873
--> src/mdtest_snippet.py:19:12
5974
|
60-
17 | yield from i()
75+
17 | yield from i() # error: [not-iterable]
6176
18 |
6277
19 | def j() -> str: # error: [invalid-return-type]
6378
| ^^^ expected `str`, found `types.GeneratorType`

crates/ty_python_semantic/src/types/class.rs

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2610,10 +2610,13 @@ pub enum KnownClass {
26102610
UnionType,
26112611
GeneratorType,
26122612
AsyncGeneratorType,
2613+
CoroutineType,
26132614
// Typeshed
26142615
NoneType, // Part of `types` for Python >= 3.10
26152616
// Typing
26162617
Any,
2618+
Awaitable,
2619+
Generator,
26172620
Deprecated,
26182621
StdlibAlias,
26192622
SpecialForm,
@@ -2736,6 +2739,9 @@ impl KnownClass {
27362739
| Self::NotImplementedType
27372740
| Self::Staticmethod
27382741
| Self::Classmethod
2742+
| Self::Awaitable
2743+
| Self::Generator
2744+
| Self::CoroutineType
27392745
| Self::Deprecated
27402746
| Self::Field
27412747
| Self::KwOnly
@@ -2771,7 +2777,8 @@ impl KnownClass {
27712777
| Self::Super
27722778
| Self::GenericAlias
27732779
| Self::Deque
2774-
| Self::Bytes => true,
2780+
| Self::Bytes
2781+
| Self::CoroutineType => true,
27752782

27762783
// It doesn't really make sense to ask the question for `@final` types,
27772784
// since these are "more than solid bases". But we'll anyway infer a `@final`
@@ -2807,6 +2814,8 @@ impl KnownClass {
28072814
// with length >2, or anything that is implemented in pure Python, is not a solid base.
28082815
Self::ABCMeta
28092816
| Self::Any
2817+
| Self::Awaitable
2818+
| Self::Generator
28102819
| Self::Enum
28112820
| Self::EnumType
28122821
| Self::Auto
@@ -2855,6 +2864,8 @@ impl KnownClass {
28552864
| KnownClass::ExceptionGroup
28562865
| KnownClass::Staticmethod
28572866
| KnownClass::Classmethod
2867+
| KnownClass::Awaitable
2868+
| KnownClass::Generator
28582869
| KnownClass::Deprecated
28592870
| KnownClass::Super
28602871
| KnownClass::Enum
@@ -2872,6 +2883,7 @@ impl KnownClass {
28722883
| KnownClass::UnionType
28732884
| KnownClass::GeneratorType
28742885
| KnownClass::AsyncGeneratorType
2886+
| KnownClass::CoroutineType
28752887
| KnownClass::NoneType
28762888
| KnownClass::Any
28772889
| KnownClass::StdlibAlias
@@ -2917,7 +2929,11 @@ impl KnownClass {
29172929
/// 2. It's probably more performant.
29182930
const fn is_protocol(self) -> bool {
29192931
match self {
2920-
Self::SupportsIndex | Self::Iterable | Self::Iterator => true,
2932+
Self::SupportsIndex
2933+
| Self::Iterable
2934+
| Self::Iterator
2935+
| Self::Awaitable
2936+
| Self::Generator => true,
29212937

29222938
Self::Any
29232939
| Self::Bool
@@ -2946,6 +2962,7 @@ impl KnownClass {
29462962
| Self::GenericAlias
29472963
| Self::GeneratorType
29482964
| Self::AsyncGeneratorType
2965+
| Self::CoroutineType
29492966
| Self::ModuleType
29502967
| Self::FunctionType
29512968
| Self::MethodType
@@ -3011,6 +3028,8 @@ impl KnownClass {
30113028
Self::ExceptionGroup => "ExceptionGroup",
30123029
Self::Staticmethod => "staticmethod",
30133030
Self::Classmethod => "classmethod",
3031+
Self::Awaitable => "Awaitable",
3032+
Self::Generator => "Generator",
30143033
Self::Deprecated => "deprecated",
30153034
Self::GenericAlias => "GenericAlias",
30163035
Self::ModuleType => "ModuleType",
@@ -3021,6 +3040,7 @@ impl KnownClass {
30213040
Self::WrapperDescriptorType => "WrapperDescriptorType",
30223041
Self::GeneratorType => "GeneratorType",
30233042
Self::AsyncGeneratorType => "AsyncGeneratorType",
3043+
Self::CoroutineType => "CoroutineType",
30243044
Self::NamedTuple => "NamedTuple",
30253045
Self::NoneType => "NoneType",
30263046
Self::SpecialForm => "_SpecialForm",
@@ -3281,11 +3301,14 @@ impl KnownClass {
32813301
| Self::MethodType
32823302
| Self::GeneratorType
32833303
| Self::AsyncGeneratorType
3304+
| Self::CoroutineType
32843305
| Self::MethodWrapperType
32853306
| Self::UnionType
32863307
| Self::WrapperDescriptorType => KnownModule::Types,
32873308
Self::NoneType => KnownModule::Typeshed,
32883309
Self::Any
3310+
| Self::Awaitable
3311+
| Self::Generator
32893312
| Self::SpecialForm
32903313
| Self::TypeVar
32913314
| Self::NamedTuple
@@ -3366,12 +3389,15 @@ impl KnownClass {
33663389
| Self::ExceptionGroup
33673390
| Self::Staticmethod
33683391
| Self::Classmethod
3392+
| Self::Awaitable
3393+
| Self::Generator
33693394
| Self::Deprecated
33703395
| Self::GenericAlias
33713396
| Self::ModuleType
33723397
| Self::FunctionType
33733398
| Self::GeneratorType
33743399
| Self::AsyncGeneratorType
3400+
| Self::CoroutineType
33753401
| Self::MethodType
33763402
| Self::MethodWrapperType
33773403
| Self::WrapperDescriptorType
@@ -3443,6 +3469,7 @@ impl KnownClass {
34433469
| Self::WrapperDescriptorType
34443470
| Self::GeneratorType
34453471
| Self::AsyncGeneratorType
3472+
| Self::CoroutineType
34463473
| Self::SpecialForm
34473474
| Self::ChainMap
34483475
| Self::Counter
@@ -3457,6 +3484,8 @@ impl KnownClass {
34573484
| Self::ExceptionGroup
34583485
| Self::Staticmethod
34593486
| Self::Classmethod
3487+
| Self::Awaitable
3488+
| Self::Generator
34603489
| Self::Deprecated
34613490
| Self::TypeVar
34623491
| Self::ParamSpec
@@ -3513,12 +3542,15 @@ impl KnownClass {
35133542
"ExceptionGroup" => Self::ExceptionGroup,
35143543
"staticmethod" => Self::Staticmethod,
35153544
"classmethod" => Self::Classmethod,
3545+
"Awaitable" => Self::Awaitable,
3546+
"Generator" => Self::Generator,
35163547
"deprecated" => Self::Deprecated,
35173548
"GenericAlias" => Self::GenericAlias,
35183549
"NoneType" => Self::NoneType,
35193550
"ModuleType" => Self::ModuleType,
35203551
"GeneratorType" => Self::GeneratorType,
35213552
"AsyncGeneratorType" => Self::AsyncGeneratorType,
3553+
"CoroutineType" => Self::CoroutineType,
35223554
"FunctionType" => Self::FunctionType,
35233555
"MethodType" => Self::MethodType,
35243556
"UnionType" => Self::UnionType,
@@ -3623,6 +3655,7 @@ impl KnownClass {
36233655
| Self::UnionType
36243656
| Self::GeneratorType
36253657
| Self::AsyncGeneratorType
3658+
| Self::CoroutineType
36263659
| Self::WrapperDescriptorType
36273660
| Self::Field
36283661
| Self::KwOnly
@@ -3642,6 +3675,7 @@ impl KnownClass {
36423675
| Self::Iterable
36433676
| Self::Iterator
36443677
| Self::NewType => matches!(module, KnownModule::Typing | KnownModule::TypingExtensions),
3678+
Self::Awaitable | Self::Generator => matches!(module, KnownModule::Typing | KnownModule::TypingExtensions | KnownModule::Abc),
36453679
Self::Deprecated => matches!(module, KnownModule::Warnings | KnownModule::TypingExtensions),
36463680

36473681
}

crates/ty_python_semantic/src/types/infer.rs

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ use crate::types::{
119119
CallDunderError, CallableType, ClassLiteral, ClassType, DataclassParams, DynamicType,
120120
IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, LintDiagnosticGuard,
121121
MemberLookupPolicy, MetaclassCandidate, PEP695TypeAliasType, Parameter, ParameterForm,
122-
Parameters, SpecialFormType, SubclassOfType, Truthiness, Type, TypeAliasType,
122+
Parameters, Protocol, SpecialFormType, SubclassOfType, Truthiness, Type, TypeAliasType,
123123
TypeAndQualifiers, TypeIsType, TypeQualifiers, TypeVarBoundOrConstraints, TypeVarInstance,
124124
TypeVarKind, TypeVarVariance, UnionBuilder, UnionType, binding_type, todo_type,
125125
};
@@ -6112,8 +6112,29 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
61126112
node_index: _,
61136113
value,
61146114
} = await_expression;
6115-
self.infer_expression(value);
6116-
todo_type!("generic `typing.Awaitable` type")
6115+
let value_ty = self.infer_expression(value);
6116+
6117+
value_ty
6118+
.try_call_dunder(self.db(), "__await__", CallArguments::none())
6119+
.map_or(Type::unknown(), |result| {
6120+
let generator_ty = result.return_type(self.db());
6121+
6122+
if let Type::ProtocolInstance(instance) = generator_ty {
6123+
if let Protocol::FromClass(class) = instance.inner {
6124+
if class.is_known(self.db(), KnownClass::Generator) {
6125+
if let Some(specialization) =
6126+
class.class_literal_specialized(self.db(), None).1
6127+
{
6128+
if let [_, _, return_ty] = specialization.types(self.db()) {
6129+
return *return_ty;
6130+
}
6131+
}
6132+
}
6133+
}
6134+
}
6135+
6136+
Type::unknown()
6137+
})
61176138
}
61186139

61196140
// Perform narrowing with applicable constraints between the current scope and the enclosing scope.

crates/ty_python_semantic/src/types/signatures.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use smallvec::{SmallVec, smallvec_inline};
1818
use super::{DynamicType, Type, TypeTransformer, TypeVarVariance, definition_expression_type};
1919
use crate::semantic_index::definition::Definition;
2020
use crate::types::generics::{GenericContext, walk_generic_context};
21-
use crate::types::{TypeMapping, TypeRelation, TypeVarInstance, todo_type};
21+
use crate::types::{KnownClass, TypeMapping, TypeRelation, TypeVarInstance, todo_type};
2222
use crate::{Db, FxOrderSet};
2323
use ruff_python_ast::{self as ast, name::Name};
2424

@@ -325,7 +325,10 @@ impl<'db> Signature<'db> {
325325
Parameters::from_parameters(db, definition, function_node.parameters.as_ref());
326326
let return_ty = function_node.returns.as_ref().map(|returns| {
327327
if function_node.is_async {
328-
todo_type!("generic types.CoroutineType")
328+
let plain_return_ty = definition_expression_type(db, definition, returns.as_ref());
329+
330+
KnownClass::CoroutineType
331+
.to_specialized_instance(db, [Type::any(), Type::any(), plain_return_ty])
329332
} else {
330333
definition_expression_type(db, definition, returns.as_ref())
331334
}

0 commit comments

Comments
 (0)