Skip to content

Commit 64d2e16

Browse files
committed
[ty] Fix panics when pulling types for various special forms that have the wrong number of parameters
1 parent d38866f commit 64d2e16

File tree

3 files changed

+152
-80
lines changed

3 files changed

+152
-80
lines changed

crates/ty_python_semantic/resources/mdtest/type_api.md

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,12 @@ def negate(n1: Not[int], n2: Not[Not[int]], n3: Not[Not[Not[int]]]) -> None:
2323
reveal_type(n2) # revealed: int
2424
reveal_type(n3) # revealed: ~int
2525

26-
# error: "Special form `ty_extensions.Not` expected exactly one type parameter"
26+
# error: "Special form `ty_extensions.Not` expected exactly 1 type argument, got 2"
2727
n: Not[int, str]
28+
# error: [invalid-type-form] "Special form `ty_extensions.Not` expected exactly 1 type argument, got 0"
29+
o: Not[()]
30+
31+
p: Not[(int,)]
2832

2933
def static_truthiness(not_one: Not[Literal[1]]) -> None:
3034
# TODO: `bool` is not incorrect, but these would ideally be `Literal[True]` and `Literal[False]`
@@ -371,8 +375,6 @@ static_assert(not is_single_valued(Literal["a"] | Literal["b"]))
371375

372376
## `TypeOf`
373377

374-
<!-- pull-types:skip -->
375-
376378
We use `TypeOf` to get the inferred type of an expression. This is useful when we want to refer to
377379
it in a type expression. For example, if we want to make sure that the class literal type `str` is a
378380
subtype of `type[str]`, we can not use `is_subtype_of(str, type[str])`, as that would test if the
@@ -398,13 +400,13 @@ class Derived(Base): ...
398400
```py
399401
def type_of_annotation() -> None:
400402
t1: TypeOf[Base] = Base
401-
t2: TypeOf[Base] = Derived # error: [invalid-assignment]
403+
t2: TypeOf[(Base,)] = Derived # error: [invalid-assignment]
402404

403405
# Note how this is different from `type[…]` which includes subclasses:
404406
s1: type[Base] = Base
405407
s2: type[Base] = Derived # no error here
406408

407-
# error: "Special form `ty_extensions.TypeOf` expected exactly one type parameter"
409+
# error: "Special form `ty_extensions.TypeOf` expected exactly 1 type argument, got 3"
408410
t: TypeOf[int, str, bytes]
409411

410412
# error: [invalid-type-form] "`ty_extensions.TypeOf` requires exactly one argument when used in a type expression"
@@ -414,8 +416,6 @@ def f(x: TypeOf) -> None:
414416

415417
## `CallableTypeOf`
416418

417-
<!-- pull-types:skip -->
418-
419419
The `CallableTypeOf` special form can be used to extract the `Callable` structural type inhabited by
420420
a given callable object. This can be used to get the externally visibly signature of the object,
421421
which can then be used to test various type properties.
@@ -434,15 +434,23 @@ def f2() -> int:
434434
def f3(x: int, y: str) -> None:
435435
return
436436

437-
# error: [invalid-type-form] "Special form `ty_extensions.CallableTypeOf` expected exactly one type parameter"
437+
# error: [invalid-type-form] "Special form `ty_extensions.CallableTypeOf` expected exactly 1 type argument, got 2"
438438
c1: CallableTypeOf[f1, f2]
439439

440440
# error: [invalid-type-form] "Expected the first argument to `ty_extensions.CallableTypeOf` to be a callable object, but got an object of type `Literal["foo"]`"
441441
c2: CallableTypeOf["foo"]
442442

443+
# error: [invalid-type-form] "Expected the first argument to `ty_extensions.CallableTypeOf` to be a callable object, but got an object of type `Literal["foo"]`"
444+
c20: CallableTypeOf[("foo",)]
445+
443446
# error: [invalid-type-form] "`ty_extensions.CallableTypeOf` requires exactly one argument when used in a type expression"
444447
def f(x: CallableTypeOf) -> None:
445448
reveal_type(x) # revealed: Unknown
449+
450+
c3: CallableTypeOf[(f3,)]
451+
452+
# error: [invalid-type-form] "Special form `ty_extensions.CallableTypeOf` expected exactly 1 type argument, got 0"
453+
c4: CallableTypeOf[()]
446454
```
447455

448456
Using it in annotation to reveal the signature of the callable object:

crates/ty_python_semantic/src/types/diagnostic.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1888,6 +1888,26 @@ pub(crate) fn report_invalid_arguments_to_annotated(
18881888
));
18891889
}
18901890

1891+
pub(crate) fn report_invalid_argument_number_to_special_form(
1892+
context: &InferContext,
1893+
subscript: &ast::ExprSubscript,
1894+
special_form: SpecialFormType,
1895+
received_arguments: usize,
1896+
expected_arguments: u8,
1897+
) {
1898+
let noun = if expected_arguments == 1 {
1899+
"type argument"
1900+
} else {
1901+
"type arguments"
1902+
};
1903+
if let Some(builder) = context.report_lint(&INVALID_TYPE_FORM, subscript) {
1904+
builder.into_diagnostic(format_args!(
1905+
"Special form `{special_form}` expected exactly {expected_arguments} {noun}, \
1906+
got {received_arguments}",
1907+
));
1908+
}
1909+
}
1910+
18911911
pub(crate) fn report_bad_argument_to_get_protocol_members(
18921912
context: &InferContext,
18931913
call: &ast::ExprCall,

crates/ty_python_semantic/src/types/infer.rs

Lines changed: 116 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,10 @@ use crate::types::diagnostic::{
8282
INVALID_TYPE_VARIABLE_CONSTRAINTS, POSSIBLY_UNBOUND_IMPLICIT_CALL, POSSIBLY_UNBOUND_IMPORT,
8383
TypeCheckDiagnostics, UNDEFINED_REVEAL, UNRESOLVED_ATTRIBUTE, UNRESOLVED_IMPORT,
8484
UNRESOLVED_REFERENCE, UNSUPPORTED_OPERATOR, report_implicit_return_type,
85-
report_invalid_arguments_to_annotated, report_invalid_arguments_to_callable,
86-
report_invalid_assignment, report_invalid_attribute_assignment,
87-
report_invalid_generator_function_return_type, report_invalid_return_type,
88-
report_possibly_unbound_attribute,
85+
report_invalid_argument_number_to_special_form, report_invalid_arguments_to_annotated,
86+
report_invalid_arguments_to_callable, report_invalid_assignment,
87+
report_invalid_attribute_assignment, report_invalid_generator_function_return_type,
88+
report_invalid_return_type, report_possibly_unbound_attribute,
8989
};
9090
use crate::types::function::{
9191
FunctionDecorators, FunctionLiteral, FunctionType, KnownFunction, OverloadLiteral,
@@ -9431,24 +9431,33 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
94319431
}
94329432

94339433
// Type API special forms
9434-
SpecialFormType::Not => match arguments_slice {
9435-
ast::Expr::Tuple(tuple) => {
9436-
for element in tuple {
9437-
self.infer_type_expression(element);
9438-
}
9439-
if let Some(builder) = self.context.report_lint(&INVALID_TYPE_FORM, subscript) {
9440-
builder.into_diagnostic(format_args!(
9441-
"Special form `{special_form}` expected exactly one type parameter",
9442-
));
9434+
SpecialFormType::Not => {
9435+
let arguments = if let ast::Expr::Tuple(tuple) = arguments_slice {
9436+
&*tuple.elts
9437+
} else {
9438+
std::slice::from_ref(arguments_slice)
9439+
};
9440+
let num_arguments = arguments.len();
9441+
let negated_type = if num_arguments == 1 {
9442+
self.infer_type_expression(&arguments[0]).negate(db)
9443+
} else {
9444+
for argument in arguments {
9445+
self.infer_type_expression(argument);
94439446
}
9444-
self.store_expression_type(arguments_slice, Type::unknown());
9447+
report_invalid_argument_number_to_special_form(
9448+
&self.context,
9449+
subscript,
9450+
special_form,
9451+
num_arguments,
9452+
1,
9453+
);
94459454
Type::unknown()
9455+
};
9456+
if arguments_slice.is_tuple_expr() {
9457+
self.store_expression_type(arguments_slice, negated_type);
94469458
}
9447-
_ => {
9448-
let argument_type = self.infer_type_expression(arguments_slice);
9449-
argument_type.negate(db)
9450-
}
9451-
},
9459+
negated_type
9460+
}
94529461
SpecialFormType::Intersection => {
94539462
let elements = match arguments_slice {
94549463
ast::Expr::Tuple(tuple) => Either::Left(tuple.iter()),
@@ -9466,70 +9475,105 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
94669475
}
94679476
ty
94689477
}
9469-
SpecialFormType::TypeOf => match arguments_slice {
9470-
ast::Expr::Tuple(_) => {
9471-
if let Some(builder) = self.context.report_lint(&INVALID_TYPE_FORM, subscript) {
9472-
builder.into_diagnostic(format_args!(
9473-
"Special form `{special_form}` expected exactly one type parameter",
9474-
));
9478+
SpecialFormType::TypeOf => {
9479+
let arguments = if let ast::Expr::Tuple(tuple) = arguments_slice {
9480+
&*tuple.elts
9481+
} else {
9482+
std::slice::from_ref(arguments_slice)
9483+
};
9484+
let num_arguments = arguments.len();
9485+
let type_of_type = if num_arguments == 1 {
9486+
// N.B. This uses `infer_expression` rather than `infer_type_expression`
9487+
self.infer_expression(&arguments[0])
9488+
} else {
9489+
for argument in arguments {
9490+
self.infer_type_expression(argument);
94759491
}
9492+
report_invalid_argument_number_to_special_form(
9493+
&self.context,
9494+
subscript,
9495+
special_form,
9496+
num_arguments,
9497+
1,
9498+
);
94769499
Type::unknown()
9500+
};
9501+
if arguments_slice.is_tuple_expr() {
9502+
self.store_expression_type(arguments_slice, type_of_type);
94779503
}
9478-
_ => {
9479-
// NB: This calls `infer_expression` instead of `infer_type_expression`.
9504+
type_of_type
9505+
}
94809506

9481-
self.infer_expression(arguments_slice)
9482-
}
9483-
},
9484-
SpecialFormType::CallableTypeOf => match arguments_slice {
9485-
ast::Expr::Tuple(_) => {
9486-
if let Some(builder) = self.context.report_lint(&INVALID_TYPE_FORM, subscript) {
9487-
builder.into_diagnostic(format_args!(
9488-
"Special form `{special_form}` expected exactly one type parameter",
9489-
));
9507+
SpecialFormType::CallableTypeOf => {
9508+
let arguments = if let ast::Expr::Tuple(tuple) = arguments_slice {
9509+
&*tuple.elts
9510+
} else {
9511+
std::slice::from_ref(arguments_slice)
9512+
};
9513+
let num_arguments = arguments.len();
9514+
9515+
if num_arguments != 1 {
9516+
for argument in arguments {
9517+
self.infer_expression(argument);
94909518
}
9491-
Type::unknown()
9519+
report_invalid_argument_number_to_special_form(
9520+
&self.context,
9521+
subscript,
9522+
special_form,
9523+
num_arguments,
9524+
1,
9525+
);
9526+
if arguments_slice.is_tuple_expr() {
9527+
self.store_expression_type(arguments_slice, Type::unknown());
9528+
}
9529+
return Type::unknown();
94929530
}
9493-
_ => {
9494-
let argument_type = self.infer_expression(arguments_slice);
9495-
let bindings = argument_type.bindings(db);
9496-
9497-
// SAFETY: This is enforced by the constructor methods on `Bindings` even in
9498-
// the case of a non-callable union.
9499-
let callable_binding = bindings
9500-
.into_iter()
9501-
.next()
9502-
.expect("`Bindings` should have at least one `CallableBinding`");
9503-
9504-
let mut signature_iter = callable_binding.into_iter().map(|binding| {
9505-
if argument_type.is_bound_method() {
9506-
binding.signature.bind_self()
9507-
} else {
9508-
binding.signature.clone()
9509-
}
9510-
});
95119531

9512-
let Some(signature) = signature_iter.next() else {
9513-
if let Some(builder) = self
9514-
.context
9515-
.report_lint(&INVALID_TYPE_FORM, arguments_slice)
9516-
{
9517-
builder.into_diagnostic(format_args!(
9518-
"Expected the first argument to `{special_form}` \
9532+
let argument_type = self.infer_expression(&arguments[0]);
9533+
let bindings = argument_type.bindings(db);
9534+
9535+
// SAFETY: This is enforced by the constructor methods on `Bindings` even in
9536+
// the case of a non-callable union.
9537+
let callable_binding = bindings
9538+
.into_iter()
9539+
.next()
9540+
.expect("`Bindings` should have at least one `CallableBinding`");
9541+
9542+
let mut signature_iter = callable_binding.into_iter().map(|binding| {
9543+
if argument_type.is_bound_method() {
9544+
binding.signature.bind_self()
9545+
} else {
9546+
binding.signature.clone()
9547+
}
9548+
});
9549+
9550+
let Some(signature) = signature_iter.next() else {
9551+
if let Some(builder) = self
9552+
.context
9553+
.report_lint(&INVALID_TYPE_FORM, arguments_slice)
9554+
{
9555+
builder.into_diagnostic(format_args!(
9556+
"Expected the first argument to `{special_form}` \
95199557
to be a callable object, \
95209558
but got an object of type `{actual_type}`",
9521-
actual_type = argument_type.display(db)
9522-
));
9523-
}
9524-
return Type::unknown();
9525-
};
9559+
actual_type = argument_type.display(db)
9560+
));
9561+
}
9562+
if arguments_slice.is_tuple_expr() {
9563+
self.store_expression_type(arguments_slice, Type::unknown());
9564+
}
9565+
return Type::unknown();
9566+
};
95269567

9527-
let signature = CallableSignature::from_overloads(
9528-
std::iter::once(signature).chain(signature_iter),
9529-
);
9530-
Type::Callable(CallableType::new(db, signature, false))
9568+
let signature = CallableSignature::from_overloads(
9569+
std::iter::once(signature).chain(signature_iter),
9570+
);
9571+
let callable_type_of = Type::Callable(CallableType::new(db, signature, false));
9572+
if arguments_slice.is_tuple_expr() {
9573+
self.store_expression_type(arguments_slice, callable_type_of);
95319574
}
9532-
},
9575+
callable_type_of
9576+
}
95339577

95349578
SpecialFormType::ChainMap => self.infer_parameterized_legacy_typing_alias(
95359579
subscript,

0 commit comments

Comments
 (0)