Skip to content

Commit 9ac39ce

Browse files
authored
[ty] Ban protocols from inheriting from non-protocol generic classes (#19941)
1 parent f4d8826 commit 9ac39ce

File tree

2 files changed

+15
-10
lines changed

2 files changed

+15
-10
lines changed

crates/ty_python_semantic/resources/mdtest/protocols.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,14 @@ class AlsoInvalid(MyProtocol, OtherProtocol, NotAProtocol, Protocol): ...
150150

151151
# revealed: tuple[<class 'AlsoInvalid'>, <class 'MyProtocol'>, <class 'OtherProtocol'>, <class 'NotAProtocol'>, typing.Protocol, typing.Generic, <class 'object'>]
152152
reveal_type(AlsoInvalid.__mro__)
153+
154+
class NotAGenericProtocol[T]: ...
155+
156+
# error: [invalid-protocol] "Protocol class `StillInvalid` cannot inherit from non-protocol class `NotAGenericProtocol`"
157+
class StillInvalid(NotAGenericProtocol[int], Protocol): ...
158+
159+
# revealed: tuple[<class 'StillInvalid'>, <class 'NotAGenericProtocol[int]'>, typing.Protocol, typing.Generic, <class 'object'>]
160+
reveal_type(StillInvalid.__mro__)
153161
```
154162

155163
But two exceptions to this rule are `object` and `Generic`:

crates/ty_python_semantic/src/types/infer.rs

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,13 +1117,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
11171117
// - Check for inheritance from a `@final` classes
11181118
// - If the class is a protocol class: check for inheritance from a non-protocol class
11191119
for (i, base_class) in class.explicit_bases(self.db()).iter().enumerate() {
1120-
if let Some((class, solid_base)) = base_class
1121-
.to_class_type(self.db())
1122-
.and_then(|class| Some((class, class.nearest_solid_base(self.db())?)))
1123-
{
1124-
solid_bases.insert(solid_base, i, class.class_literal(self.db()).0);
1125-
}
1126-
11271120
let base_class = match base_class {
11281121
Type::SpecialForm(SpecialFormType::Generic) => {
11291122
if let Some(builder) = self
@@ -1155,13 +1148,17 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
11551148
);
11561149
continue;
11571150
}
1158-
Type::ClassLiteral(class) => class,
1159-
// dynamic/unknown bases are never `@final`
1151+
Type::ClassLiteral(class) => ClassType::NonGeneric(*class),
1152+
Type::GenericAlias(class) => ClassType::Generic(*class),
11601153
_ => continue,
11611154
};
11621155

1156+
if let Some(solid_base) = base_class.nearest_solid_base(self.db()) {
1157+
solid_bases.insert(solid_base, i, base_class.class_literal(self.db()).0);
1158+
}
1159+
11631160
if is_protocol
1164-
&& !(base_class.is_protocol(self.db())
1161+
&& !(base_class.class_literal(self.db()).0.is_protocol(self.db())
11651162
|| base_class.is_known(self.db(), KnownClass::Object))
11661163
{
11671164
if let Some(builder) = self

0 commit comments

Comments
 (0)