Skip to content

Commit 10ccc4a

Browse files
authored
feat(flotilla): no shuffle for hash join if conditions are met (#5135)
## Changes Made Adds logic to elide shuffles for either side of a hash join when input num partitions != num partitions for join or input is not hash partitioned on the join key. Same logic as https://github.com/Eventual-Inc/Daft/blob/main/src/daft-physical-plan/src/physical_planner/translate.rs#L1441 ## Related Issues Closes #5127 ## Checklist - [ ] Documented in API Docs (if applicable) - [ ] Documented in User Guide (if applicable) - [ ] If adding a new documentation page, doc is added to `docs/mkdocs.yml` navigation - [ ] Documentation builds and is formatted properly (tag @/ccmao1130 for docs review)
1 parent 889a3dd commit 10ccc4a

File tree

2 files changed

+73
-60
lines changed

2 files changed

+73
-60
lines changed

src/daft-distributed/src/pipeline_node/join/hash_join.rs

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
1-
use std::{cmp::max, sync::Arc};
1+
use std::sync::Arc;
22

3-
use common_daft_config::DaftExecutionConfig;
43
use common_display::{DisplayLevel, tree::TreeDisplay};
54
use daft_dsl::expr::bound_expr::BoundExpr;
65
use daft_local_plan::LocalPhysicalPlan;
7-
use daft_logical_plan::{
8-
ClusteringSpec, JoinType, partitioning::HashClusteringConfig, stats::StatsState,
9-
};
6+
use daft_logical_plan::{JoinType, partitioning::HashClusteringConfig, stats::StatsState};
107
use daft_schema::schema::SchemaRef;
118
use futures::StreamExt;
129

@@ -197,31 +194,3 @@ impl DistributedPipelineNode for HashJoinNode {
197194
self
198195
}
199196
}
200-
201-
pub(crate) fn gen_num_partitions(
202-
left_spec: &ClusteringSpec,
203-
right_spec: &ClusteringSpec,
204-
cfg: &DaftExecutionConfig,
205-
) -> usize {
206-
let is_left_hash_partitioned = matches!(left_spec, ClusteringSpec::Hash(_));
207-
let is_right_hash_partitioned = matches!(right_spec, ClusteringSpec::Hash(_));
208-
let num_left_partitions = left_spec.num_partitions();
209-
let num_right_partitions = right_spec.num_partitions();
210-
211-
match (
212-
is_left_hash_partitioned,
213-
is_right_hash_partitioned,
214-
num_left_partitions,
215-
num_right_partitions,
216-
) {
217-
(true, true, a, b) | (false, false, a, b) => max(a, b),
218-
(_, _, 1, x) | (_, _, x, 1) => x,
219-
(true, false, a, b) if (a as f64) >= (b as f64) * cfg.hash_join_partition_size_leniency => {
220-
a
221-
}
222-
(false, true, a, b) if (b as f64) >= (a as f64) * cfg.hash_join_partition_size_leniency => {
223-
b
224-
}
225-
(_, _, a, b) => max(a, b),
226-
}
227-
}

src/daft-distributed/src/pipeline_node/join/translate_join.rs

Lines changed: 71 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
1-
use std::sync::Arc;
1+
use std::{cmp::max, sync::Arc};
22

33
use common_error::DaftResult;
4-
use daft_dsl::{ExprRef, expr::bound_expr::BoundExpr};
4+
use daft_dsl::{ExprRef, expr::bound_expr::BoundExpr, is_partition_compatible};
55
use daft_logical_plan::{
6-
JoinStrategy, JoinType,
6+
ClusteringSpec, JoinStrategy, JoinType,
77
ops::Join,
88
partitioning::{HashRepartitionConfig, RepartitionSpec},
99
stats::ApproxStats,
1010
};
1111
use daft_schema::schema::SchemaRef;
1212

13-
use super::hash_join::gen_num_partitions;
1413
use crate::pipeline_node::{
1514
DistributedPipelineNode, NodeID,
1615
join::{BroadcastJoinNode, HashJoinNode},
@@ -75,31 +74,76 @@ impl LogicalPlanToPipelineNodeTranslator {
7574
join_type: JoinType,
7675
output_schema: SchemaRef,
7776
) -> DaftResult<Arc<dyn DistributedPipelineNode>> {
78-
let num_partitions = gen_num_partitions(
79-
left.config().clustering_spec.as_ref(),
80-
right.config().clustering_spec.as_ref(),
81-
self.stage_config.config.as_ref(),
82-
);
77+
let left_spec = left.config().clustering_spec.as_ref();
78+
let right_spec = right.config().clustering_spec.as_ref();
8379

84-
let left = self.gen_shuffle_node(
85-
logical_node_id,
86-
RepartitionSpec::Hash(HashRepartitionConfig::new(
87-
Some(num_partitions),
88-
left_on.iter().map(|e| e.clone().into()).collect(),
89-
)),
90-
left.config().schema.clone(),
91-
left,
92-
)?;
80+
let is_left_hash_partitioned = matches!(left_spec, ClusteringSpec::Hash(..))
81+
&& is_partition_compatible(
82+
&left_spec.partition_by(),
83+
left_on.iter().map(|e| e.inner()),
84+
);
85+
let is_right_hash_partitioned = matches!(right_spec, ClusteringSpec::Hash(..))
86+
&& is_partition_compatible(
87+
&right_spec.partition_by(),
88+
right_on.iter().map(|e| e.inner()),
89+
);
90+
let num_left_partitions = left_spec.num_partitions();
91+
let num_right_partitions = right_spec.num_partitions();
9392

94-
let right = self.gen_shuffle_node(
95-
logical_node_id,
96-
RepartitionSpec::Hash(HashRepartitionConfig::new(
97-
Some(num_partitions),
98-
right_on.iter().map(|e| e.clone().into()).collect(),
99-
)),
100-
right.config().schema.clone(),
101-
right,
102-
)?;
93+
let num_partitions = match (
94+
is_left_hash_partitioned,
95+
is_right_hash_partitioned,
96+
num_left_partitions,
97+
num_right_partitions,
98+
) {
99+
(true, true, a, b) | (false, false, a, b) => max(a, b),
100+
(_, _, 1, x) | (_, _, x, 1) => x,
101+
(true, false, a, b)
102+
if (a as f64)
103+
>= (b as f64) * self.stage_config.config.hash_join_partition_size_leniency =>
104+
{
105+
a
106+
}
107+
(false, true, a, b)
108+
if (b as f64)
109+
>= (a as f64) * self.stage_config.config.hash_join_partition_size_leniency =>
110+
{
111+
b
112+
}
113+
(_, _, a, b) => max(a, b),
114+
};
115+
116+
let left = if num_left_partitions != num_partitions
117+
|| (num_partitions > 1 && !is_left_hash_partitioned)
118+
{
119+
self.gen_shuffle_node(
120+
logical_node_id,
121+
RepartitionSpec::Hash(HashRepartitionConfig::new(
122+
Some(num_partitions),
123+
left_on.iter().map(|e| e.clone().into()).collect(),
124+
)),
125+
left.config().schema.clone(),
126+
left,
127+
)?
128+
} else {
129+
left
130+
};
131+
132+
let right = if num_right_partitions != num_partitions
133+
|| (num_partitions > 1 && !is_right_hash_partitioned)
134+
{
135+
self.gen_shuffle_node(
136+
logical_node_id,
137+
RepartitionSpec::Hash(HashRepartitionConfig::new(
138+
Some(num_partitions),
139+
right_on.iter().map(|e| e.clone().into()).collect(),
140+
)),
141+
right.config().schema.clone(),
142+
right,
143+
)?
144+
} else {
145+
right
146+
};
103147

104148
Ok(HashJoinNode::new(
105149
self.get_next_pipeline_node_id(),

0 commit comments

Comments
 (0)