|
1 |
| -use std::sync::Arc; |
| 1 | +use std::{cmp::max, sync::Arc}; |
2 | 2 |
|
3 | 3 | use common_error::DaftResult;
|
4 |
| -use daft_dsl::{ExprRef, expr::bound_expr::BoundExpr}; |
| 4 | +use daft_dsl::{ExprRef, expr::bound_expr::BoundExpr, is_partition_compatible}; |
5 | 5 | use daft_logical_plan::{
|
6 |
| - JoinStrategy, JoinType, |
| 6 | + ClusteringSpec, JoinStrategy, JoinType, |
7 | 7 | ops::Join,
|
8 | 8 | partitioning::{HashRepartitionConfig, RepartitionSpec},
|
9 | 9 | stats::ApproxStats,
|
10 | 10 | };
|
11 | 11 | use daft_schema::schema::SchemaRef;
|
12 | 12 |
|
13 |
| -use super::hash_join::gen_num_partitions; |
14 | 13 | use crate::pipeline_node::{
|
15 | 14 | DistributedPipelineNode, NodeID,
|
16 | 15 | join::{BroadcastJoinNode, HashJoinNode},
|
@@ -75,31 +74,76 @@ impl LogicalPlanToPipelineNodeTranslator {
|
75 | 74 | join_type: JoinType,
|
76 | 75 | output_schema: SchemaRef,
|
77 | 76 | ) -> DaftResult<Arc<dyn DistributedPipelineNode>> {
|
78 |
| - let num_partitions = gen_num_partitions( |
79 |
| - left.config().clustering_spec.as_ref(), |
80 |
| - right.config().clustering_spec.as_ref(), |
81 |
| - self.stage_config.config.as_ref(), |
82 |
| - ); |
| 77 | + let left_spec = left.config().clustering_spec.as_ref(); |
| 78 | + let right_spec = right.config().clustering_spec.as_ref(); |
83 | 79 |
|
84 |
| - let left = self.gen_shuffle_node( |
85 |
| - logical_node_id, |
86 |
| - RepartitionSpec::Hash(HashRepartitionConfig::new( |
87 |
| - Some(num_partitions), |
88 |
| - left_on.iter().map(|e| e.clone().into()).collect(), |
89 |
| - )), |
90 |
| - left.config().schema.clone(), |
91 |
| - left, |
92 |
| - )?; |
| 80 | + let is_left_hash_partitioned = matches!(left_spec, ClusteringSpec::Hash(..)) |
| 81 | + && is_partition_compatible( |
| 82 | + &left_spec.partition_by(), |
| 83 | + left_on.iter().map(|e| e.inner()), |
| 84 | + ); |
| 85 | + let is_right_hash_partitioned = matches!(right_spec, ClusteringSpec::Hash(..)) |
| 86 | + && is_partition_compatible( |
| 87 | + &right_spec.partition_by(), |
| 88 | + right_on.iter().map(|e| e.inner()), |
| 89 | + ); |
| 90 | + let num_left_partitions = left_spec.num_partitions(); |
| 91 | + let num_right_partitions = right_spec.num_partitions(); |
93 | 92 |
|
94 |
| - let right = self.gen_shuffle_node( |
95 |
| - logical_node_id, |
96 |
| - RepartitionSpec::Hash(HashRepartitionConfig::new( |
97 |
| - Some(num_partitions), |
98 |
| - right_on.iter().map(|e| e.clone().into()).collect(), |
99 |
| - )), |
100 |
| - right.config().schema.clone(), |
101 |
| - right, |
102 |
| - )?; |
| 93 | + let num_partitions = match ( |
| 94 | + is_left_hash_partitioned, |
| 95 | + is_right_hash_partitioned, |
| 96 | + num_left_partitions, |
| 97 | + num_right_partitions, |
| 98 | + ) { |
| 99 | + (true, true, a, b) | (false, false, a, b) => max(a, b), |
| 100 | + (_, _, 1, x) | (_, _, x, 1) => x, |
| 101 | + (true, false, a, b) |
| 102 | + if (a as f64) |
| 103 | + >= (b as f64) * self.stage_config.config.hash_join_partition_size_leniency => |
| 104 | + { |
| 105 | + a |
| 106 | + } |
| 107 | + (false, true, a, b) |
| 108 | + if (b as f64) |
| 109 | + >= (a as f64) * self.stage_config.config.hash_join_partition_size_leniency => |
| 110 | + { |
| 111 | + b |
| 112 | + } |
| 113 | + (_, _, a, b) => max(a, b), |
| 114 | + }; |
| 115 | + |
| 116 | + let left = if num_left_partitions != num_partitions |
| 117 | + || (num_partitions > 1 && !is_left_hash_partitioned) |
| 118 | + { |
| 119 | + self.gen_shuffle_node( |
| 120 | + logical_node_id, |
| 121 | + RepartitionSpec::Hash(HashRepartitionConfig::new( |
| 122 | + Some(num_partitions), |
| 123 | + left_on.iter().map(|e| e.clone().into()).collect(), |
| 124 | + )), |
| 125 | + left.config().schema.clone(), |
| 126 | + left, |
| 127 | + )? |
| 128 | + } else { |
| 129 | + left |
| 130 | + }; |
| 131 | + |
| 132 | + let right = if num_right_partitions != num_partitions |
| 133 | + || (num_partitions > 1 && !is_right_hash_partitioned) |
| 134 | + { |
| 135 | + self.gen_shuffle_node( |
| 136 | + logical_node_id, |
| 137 | + RepartitionSpec::Hash(HashRepartitionConfig::new( |
| 138 | + Some(num_partitions), |
| 139 | + right_on.iter().map(|e| e.clone().into()).collect(), |
| 140 | + )), |
| 141 | + right.config().schema.clone(), |
| 142 | + right, |
| 143 | + )? |
| 144 | + } else { |
| 145 | + right |
| 146 | + }; |
103 | 147 |
|
104 | 148 | Ok(HashJoinNode::new(
|
105 | 149 | self.get_next_pipeline_node_id(),
|
|
0 commit comments