Skip to content

Commit 6c90e2c

Browse files
Add JoinContext with JoinLeftData to TaskContext in HashJoinExec (#300)
* Add JoinContext with JoinLeftData to TaskContext in HashJoinExec
* Expose random state as const
* Re-export ahash::RandomState
* JoinContext default impl
* Add debug log when setting join left data
1 parent 2fb45f8 commit 6c90e2c

File tree

3 files changed

+39
-2
lines changed

3 files changed

+39
-2
lines changed

datafusion/physical-plan/src/joins/hash_join.rs

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,26 @@ use arrow_buffer::BooleanBuffer;
7777
use datafusion_expr::Operator;
7878
use datafusion_physical_expr_common::datum::compare_op_for_nested;
7979
use futures::{ready, Stream, StreamExt, TryStreamExt};
80+
use log::debug;
8081
use parking_lot::Mutex;
8182

83+
/// Publicly exposed [`RandomState`] built from four fixed, all-zero seeds, so
/// every user of this constant gets an identically seeded hasher state.
// NOTE(review): presumably this must match the state used to hash the
// build-side join keys so externally computed hashes are comparable — confirm.
pub const RANDOM_STATE: RandomState = RandomState::with_seeds(0, 0, 0, 0);
84+
85+
/// Shared holder for the build-side data ([`JoinLeftData`]) of a hash join.
///
/// `HashJoinExec` looks this up as a session-config extension and, if one is
/// present, stores the collected build side in it.
#[derive(Default)]
86+
pub struct JoinContext {
87+
/// Most recently stored build-side data, if any; guarded by a mutex so it
/// can be written and read through a shared reference.
build_state: Mutex<Option<Arc<JoinLeftData>>>,
88+
}
89+
90+
impl JoinContext {
91+
/// Stores `state` as the current build-side data.
///
/// Any previously stored value is replaced and dropped.
pub fn set_build_state(&self, state: Arc<JoinLeftData>) {
92+
self.build_state.lock().replace(state);
93+
}
94+
95+
/// Returns a clone of the stored build-side `Arc`, or `None` if nothing
/// has been stored.
pub fn get_build_state(&self) -> Option<Arc<JoinLeftData>> {
96+
self.build_state.lock().clone()
97+
}
98+
}
99+
82100
pub struct SharedJoinState {
83101
state_impl: Arc<dyn SharedJoinStateImpl>,
84102
}
@@ -128,7 +146,7 @@ pub trait SharedJoinStateImpl: Send + Sync + 'static {
128146
type SharedBitmapBuilder = Mutex<BooleanBufferBuilder>;
129147

130148
/// HashTable and input data for the left (build side) of a join
131-
struct JoinLeftData {
149+
pub struct JoinLeftData {
132150
/// The hash table with indices into `batch`
133151
hash_map: JoinHashMap,
134152
/// The input rows for the build side
@@ -165,6 +183,10 @@ impl JoinLeftData {
165183
}
166184
}
167185

186+
/// Returns `true` if the build-side hash map reports an entry for `hash`.
///
/// Delegates to `JoinHashMap::contains_hash`.
pub fn contains_hash(&self, hash: u64) -> bool {
187+
self.hash_map.contains_hash(hash)
188+
}
189+
168190
/// return a reference to the hash map
169191
fn hash_map(&self) -> &JoinHashMap {
170192
&self.hash_map
@@ -768,6 +790,7 @@ impl ExecutionPlan for HashJoinExec {
768790

769791
let distributed_state =
770792
context.session_config().get_extension::<SharedJoinState>();
793+
let join_context = context.session_config().get_extension::<JoinContext>();
771794

772795
let join_metrics = BuildProbeJoinMetrics::new(partition, &self.metrics);
773796
let left_fut = match self.mode {
@@ -855,6 +878,7 @@ impl ExecutionPlan for HashJoinExec {
855878
batch_size,
856879
hashes_buffer: vec![],
857880
right_side_ordered: self.right.output_ordering().is_some(),
881+
join_context,
858882
}))
859883
}
860884

@@ -1187,6 +1211,7 @@ struct HashJoinStream {
11871211
hashes_buffer: Vec<u64>,
11881212
/// Specifies whether the right side has an ordering to potentially preserve
11891213
right_side_ordered: bool,
1214+
join_context: Option<Arc<JoinContext>>,
11901215
}
11911216

11921217
impl RecordBatchStream for HashJoinStream {
@@ -1399,6 +1424,11 @@ impl HashJoinStream {
13991424
.get_shared(cx))?;
14001425
build_timer.done();
14011426

1427+
if let Some(ctx) = self.join_context.as_ref() {
1428+
debug!("setting join left data in join context");
1429+
ctx.set_build_state(Arc::clone(&left_data));
1430+
}
1431+
14021432
self.state = HashJoinStreamState::FetchProbeBatch;
14031433
self.build_side = BuildSide::Ready(BuildSideReadyState { left_data });
14041434

datafusion/physical-plan/src/joins/mod.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
2020
pub use cross_join::CrossJoinExec;
2121
pub use hash_join::{
22-
HashJoinExec, SharedJoinState, SharedJoinStateImpl, SharedProbeState,
22+
HashJoinExec, JoinContext, JoinLeftData, SharedJoinState, SharedJoinStateImpl,
23+
SharedProbeState, RANDOM_STATE,
2324
};
2425
pub use nested_loop_join::NestedLoopJoinExec;
2526
// Note: SortMergeJoin is not used in plans yet
@@ -33,6 +34,8 @@ mod stream_join_utils;
3334
mod symmetric_hash_join;
3435
pub mod utils;
3536

37+
/// Alias for [`ahash::RandomState`], exposed publicly from this module.
pub type RandomState = ahash::RandomState;
38+
3639
#[cfg(test)]
3740
pub mod test_utils;
3841

datafusion/physical-plan/src/joins/utils.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,10 @@ impl JoinHashMap {
139139
next: vec![0; capacity],
140140
}
141141
}
142+
143+
/// Returns `true` if the underlying map holds an entry whose stored hash
/// equals `hash`.
///
/// Only hash values are compared here, not join-key values.
// NOTE(review): a `true` result therefore looks like it can come from a hash
// collision rather than an equal key — callers should treat it as "may
// match"; confirm against intended use.
pub fn contains_hash(&self, hash: u64) -> bool {
144+
self.map.find(hash, |(h, _)| *h == hash).is_some()
145+
}
142146
}
143147

144148
// Type of offsets for obtaining indices from JoinHashMap.

0 commit comments

Comments
 (0)