Skip to content

Commit f57163a

Browse files
committed
[ty] Eagerly simplify 'True' and 'False' constraints
1 parent 6802c47 commit f57163a

File tree

6 files changed

+91
-31
lines changed

6 files changed

+91
-31
lines changed

crates/ty_python_semantic/resources/mdtest/narrow/while.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,17 @@ while x != 1:
5959

6060
x = next_item()
6161
```
62+
63+
## With `break` statements
64+
65+
```py
66+
def next_item() -> int | None:
67+
return 1
68+
69+
while True:
70+
x = next_item()
71+
if x is not None:
72+
break
73+
74+
reveal_type(x) # revealed: int
75+
```

crates/ty_python_semantic/src/semantic_index/builder.rs

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ use crate::semantic_index::place::{
3535
PlaceExprWithFlags, PlaceTableBuilder, Scope, ScopeId, ScopeKind, ScopedPlaceId,
3636
};
3737
use crate::semantic_index::predicate::{
38-
PatternPredicate, PatternPredicateKind, Predicate, PredicateNode, ScopedPredicateId,
39-
StarImportPlaceholderPredicate,
38+
PatternPredicate, PatternPredicateKind, Predicate, PredicateInner, PredicateNode,
39+
ScopedPredicateId, StarImportPlaceholderPredicate,
4040
};
4141
use crate::semantic_index::re_exports::exported_names;
4242
use crate::semantic_index::reachability_constraints::{
@@ -543,9 +543,14 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
543543

544544
fn build_predicate(&mut self, predicate_node: &ast::Expr) -> Predicate<'db> {
545545
let expression = self.add_standalone_expression(predicate_node);
546-
Predicate {
547-
node: PredicateNode::Expression(expression),
548-
is_positive: true,
546+
547+
if let Some(boolean_literal) = predicate_node.as_boolean_literal_expr() {
548+
Predicate::Always(boolean_literal.value)
549+
} else {
550+
Predicate::Inner(PredicateInner {
551+
node: PredicateNode::Expression(expression),
552+
is_positive: true,
553+
})
549554
}
550555
}
551556

@@ -617,6 +622,7 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
617622
let reachability_constraint = self
618623
.current_reachability_constraints_mut()
619624
.add_atom(predicate_id);
625+
620626
self.current_use_def_map_mut()
621627
.record_reachability_constraint(reachability_constraint);
622628
reachability_constraint
@@ -705,10 +711,10 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
705711
guard,
706712
countme::Count::default(),
707713
);
708-
let predicate = Predicate {
714+
let predicate = Predicate::Inner(PredicateInner {
709715
node: PredicateNode::Pattern(pattern_predicate),
710716
is_positive: true,
711-
};
717+
});
712718
self.record_narrowing_constraint(predicate);
713719
predicate
714720
}
@@ -1666,10 +1672,10 @@ impl<'ast> Visitor<'ast> for SemanticIndexBuilder<'_, 'ast> {
16661672
self.record_ambiguous_reachability();
16671673
self.visit_expr(guard);
16681674
let post_guard_eval = self.flow_snapshot();
1669-
let predicate = Predicate {
1675+
let predicate = Predicate::Inner(PredicateInner {
16701676
node: PredicateNode::Expression(guard_expr),
16711677
is_positive: true,
1672-
};
1678+
});
16731679
self.record_negated_narrowing_constraint(predicate);
16741680
let match_success_guard_failure = self.flow_snapshot();
16751681
self.flow_restore(post_guard_eval);

crates/ty_python_semantic/src/semantic_index/predicate.rs

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,29 @@ use crate::semantic_index::place::{FileScopeId, ScopeId, ScopedPlaceId};
2121
#[derive(Ord, PartialOrd, get_size2::GetSize)]
2222
pub(crate) struct ScopedPredicateId;
2323

24+
impl ScopedPredicateId {
25+
/// A special ID that is used for an "always true" predicate.
26+
pub(crate) const ALWAYS_TRUE: ScopedPredicateId =
27+
ScopedPredicateId(std::num::NonZero::new(0xffff_ffff).unwrap());
28+
29+
/// A special ID that is used for an "always false" predicate.
30+
pub(crate) const ALWAYS_FALSE: ScopedPredicateId =
31+
ScopedPredicateId(std::num::NonZero::new(0xffff_fffe).unwrap());
32+
}
33+
2434
// A collection of predicates for a given scope.
25-
pub(crate) type Predicates<'db> = IndexVec<ScopedPredicateId, Predicate<'db>>;
35+
pub(crate) type Predicates<'db> = IndexVec<ScopedPredicateId, PredicateInner<'db>>;
2636

2737
#[derive(Debug, Default)]
2838
pub(crate) struct PredicatesBuilder<'db> {
29-
predicates: IndexVec<ScopedPredicateId, Predicate<'db>>,
39+
predicates: IndexVec<ScopedPredicateId, PredicateInner<'db>>,
3040
}
3141

3242
impl<'db> PredicatesBuilder<'db> {
3343
/// Adds a predicate. Note that we do not deduplicate predicates. If you add a `Predicate`
3444
/// more than once, you will get distinct `ScopedPredicateId`s for each one. (This lets you
3545
/// model predicates that might evaluate to different values at different points of execution.)
36-
pub(crate) fn add_predicate(&mut self, predicate: Predicate<'db>) -> ScopedPredicateId {
46+
pub(crate) fn add_predicate(&mut self, predicate: PredicateInner<'db>) -> ScopedPredicateId {
3747
self.predicates.push(predicate)
3848
}
3949

@@ -44,16 +54,27 @@ impl<'db> PredicatesBuilder<'db> {
4454
}
4555

4656
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update, get_size2::GetSize)]
47-
pub(crate) struct Predicate<'db> {
57+
pub(crate) struct PredicateInner<'db> {
4858
pub(crate) node: PredicateNode<'db>,
4959
pub(crate) is_positive: bool,
5060
}
5161

62+
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, salsa::Update, get_size2::GetSize)]
63+
pub(crate) enum Predicate<'db> {
64+
Always(bool),
65+
Inner(PredicateInner<'db>),
66+
}
67+
5268
impl Predicate<'_> {
5369
pub(crate) fn negated(self) -> Self {
54-
Self {
55-
node: self.node,
56-
is_positive: !self.is_positive,
70+
match self {
71+
Predicate::Always(value) => Predicate::Always(!value),
72+
Predicate::Inner(PredicateInner { node, is_positive }) => {
73+
Predicate::Inner(PredicateInner {
74+
node,
75+
is_positive: !is_positive,
76+
})
77+
}
5778
}
5879
}
5980
}
@@ -171,9 +192,9 @@ impl<'db> StarImportPlaceholderPredicate<'db> {
171192

172193
impl<'db> From<StarImportPlaceholderPredicate<'db>> for Predicate<'db> {
173194
fn from(predicate: StarImportPlaceholderPredicate<'db>) -> Self {
174-
Predicate {
195+
Predicate::Inner(PredicateInner {
175196
node: PredicateNode::StarImportPlaceholder(predicate),
176197
is_positive: true,
177-
}
198+
})
178199
}
179200
}

crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,8 @@ use crate::place::{RequiresExplicitReExport, imported_symbol};
204204
use crate::semantic_index::expression::Expression;
205205
use crate::semantic_index::place_table;
206206
use crate::semantic_index::predicate::{
207-
PatternPredicate, PatternPredicateKind, Predicate, PredicateNode, Predicates, ScopedPredicateId,
207+
PatternPredicate, PatternPredicateKind, PredicateInner, PredicateNode, Predicates,
208+
ScopedPredicateId,
208209
};
209210
use crate::types::{Truthiness, Type, infer_expression_type};
210211

@@ -388,12 +389,18 @@ impl ReachabilityConstraintsBuilder {
388389
&mut self,
389390
predicate: ScopedPredicateId,
390391
) -> ScopedReachabilityConstraintId {
391-
self.add_interior(InteriorNode {
392-
atom: predicate,
393-
if_true: ALWAYS_TRUE,
394-
if_ambiguous: AMBIGUOUS,
395-
if_false: ALWAYS_FALSE,
396-
})
392+
if predicate == ScopedPredicateId::ALWAYS_FALSE {
393+
ScopedReachabilityConstraintId::ALWAYS_FALSE
394+
} else if predicate == ScopedPredicateId::ALWAYS_TRUE {
395+
ScopedReachabilityConstraintId::ALWAYS_TRUE
396+
} else {
397+
self.add_interior(InteriorNode {
398+
atom: predicate,
399+
if_true: ALWAYS_TRUE,
400+
if_ambiguous: AMBIGUOUS,
401+
if_false: ALWAYS_FALSE,
402+
})
403+
}
397404
}
398405

399406
/// Adds a new reachability constraint that is the ternary NOT of an existing one.
@@ -672,7 +679,7 @@ impl ReachabilityConstraints {
672679
}
673680
}
674681

675-
fn analyze_single(db: &dyn Db, predicate: &Predicate) -> Truthiness {
682+
fn analyze_single(db: &dyn Db, predicate: &PredicateInner) -> Truthiness {
676683
match predicate.node {
677684
PredicateNode::Expression(test_expr) => {
678685
let ty = infer_expression_type(db, test_expr);

crates/ty_python_semantic/src/semantic_index/use_def.rs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,8 @@ use crate::semantic_index::place::{
247247
FileScopeId, PlaceExpr, PlaceExprWithFlags, ScopeKind, ScopedPlaceId,
248248
};
249249
use crate::semantic_index::predicate::{
250-
Predicate, Predicates, PredicatesBuilder, ScopedPredicateId, StarImportPlaceholderPredicate,
250+
Predicate, PredicateInner, Predicates, PredicatesBuilder, ScopedPredicateId,
251+
StarImportPlaceholderPredicate,
251252
};
252253
use crate::semantic_index::reachability_constraints::{
253254
ReachabilityConstraints, ReachabilityConstraintsBuilder, ScopedReachabilityConstraintId,
@@ -610,7 +611,7 @@ pub(crate) struct ConstraintsIterator<'map, 'db> {
610611
}
611612

612613
impl<'db> Iterator for ConstraintsIterator<'_, 'db> {
613-
type Item = Predicate<'db>;
614+
type Item = PredicateInner<'db>;
614615

615616
fn next(&mut self) -> Option<Self::Item> {
616617
self.constraint_ids
@@ -806,10 +807,21 @@ impl<'db> UseDefMapBuilder<'db> {
806807
}
807808

808809
pub(super) fn add_predicate(&mut self, predicate: Predicate<'db>) -> ScopedPredicateId {
809-
self.predicates.add_predicate(predicate)
810+
match predicate {
811+
Predicate::Inner(predicate) => self.predicates.add_predicate(predicate),
812+
Predicate::Always(true) => ScopedPredicateId::ALWAYS_TRUE,
813+
Predicate::Always(false) => ScopedPredicateId::ALWAYS_FALSE,
814+
}
810815
}
811816

812817
pub(super) fn record_narrowing_constraint(&mut self, predicate: ScopedPredicateId) {
818+
if predicate == ScopedPredicateId::ALWAYS_TRUE
819+
|| predicate == ScopedPredicateId::ALWAYS_FALSE
820+
{
821+
// No need to record a narrowing constraint for `True` or `False`.
822+
return;
823+
}
824+
813825
let narrowing_constraint = predicate.into();
814826
for state in &mut self.place_states {
815827
state

crates/ty_python_semantic/src/types/narrow.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use crate::semantic_index::expression::Expression;
44
use crate::semantic_index::place::{PlaceExpr, PlaceTable, ScopeId, ScopedPlaceId};
55
use crate::semantic_index::place_table;
66
use crate::semantic_index::predicate::{
7-
PatternPredicate, PatternPredicateKind, Predicate, PredicateNode,
7+
PatternPredicate, PatternPredicateKind, PredicateInner, PredicateNode,
88
};
99
use crate::types::function::KnownFunction;
1010
use crate::types::infer::infer_same_file_expression_type;
@@ -42,7 +42,7 @@ use super::UnionType;
4242
/// constraint is applied to that symbol, so we'd just return `None`.
4343
pub(crate) fn infer_narrowing_constraint<'db>(
4444
db: &'db dyn Db,
45-
predicate: Predicate<'db>,
45+
predicate: PredicateInner<'db>,
4646
place: ScopedPlaceId,
4747
) -> Option<Type<'db>> {
4848
let constraints = match predicate.node {

0 commit comments

Comments
 (0)