Skip to content

Commit

Permalink
Move filtered SMJ Left Anti filtered join out of join_partial phase (
Browse files Browse the repository at this point in the history
…apache#13111)

* Move filtered SMJ Left Anti filtered join out of `join_partial` phase
  • Loading branch information
comphead authored Oct 26, 2024
1 parent 62b063c commit 146f16a
Show file tree
Hide file tree
Showing 3 changed files with 414 additions and 220 deletions.
6 changes: 2 additions & 4 deletions datafusion/core/tests/fuzz_cases/join_fuzz.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ use datafusion::physical_plan::joins::{
};
use datafusion::physical_plan::memory::MemoryExec;

use crate::fuzz_cases::join_fuzz::JoinTestType::NljHj;
use datafusion::prelude::{SessionConfig, SessionContext};
use test_utils::stagger_batch_with_seed;

Expand Down Expand Up @@ -223,17 +224,14 @@ async fn test_anti_join_1k() {
}

#[tokio::test]
// flaky for HjSmj case, giving 1 rows difference sometimes
// https://github.com/apache/datafusion/issues/11555
#[ignore]
async fn test_anti_join_1k_filtered() {
JoinFuzzTestCase::new(
make_staggered_batches(1000),
make_staggered_batches(1000),
JoinType::LeftAnti,
Some(Box::new(col_lt_col_filter)),
)
.run_test(&[JoinTestType::NljHj], false)
.run_test(&[JoinTestType::HjSmj, NljHj], false)
.await
}

Expand Down
245 changes: 227 additions & 18 deletions datafusion/physical-plan/src/joins/sort_merge_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,32 @@ fn get_corrected_filter_mask(

Some(corrected_mask.finish())
}
JoinType::LeftAnti => {
for i in 0..row_indices_length {
let last_index =
last_index_for_row(i, row_indices, batch_ids, row_indices_length);

if filter_mask.value(i) {
seen_true = true;
}

if last_index {
if !seen_true {
corrected_mask.append_value(true);
} else {
corrected_mask.append_null();
}

seen_true = false;
} else {
corrected_mask.append_null();
}
}

let null_matched = expected_size - corrected_mask.len();
corrected_mask.extend(vec![Some(true); null_matched]);
Some(corrected_mask.finish())
}
// Only outer joins needs to keep track of processed rows and apply corrected filter mask
_ => None,
}
Expand Down Expand Up @@ -835,15 +861,18 @@ impl Stream for SMJStream {
JoinType::Left
| JoinType::LeftSemi
| JoinType::Right
| JoinType::LeftAnti
)
{
self.freeze_all()?;

if !self.output_record_batches.batches.is_empty()
&& self.buffered_data.scanning_finished()
{
let out_batch = self.filter_joined_batch()?;
return Poll::Ready(Some(Ok(out_batch)));
let out_filtered_batch =
self.filter_joined_batch()?;
return Poll::Ready(Some(Ok(
out_filtered_batch,
)));
}
}

Expand Down Expand Up @@ -907,15 +936,17 @@ impl Stream for SMJStream {
// because target output batch size can be hit in the middle of
// filtering causing the filtering to be incomplete and causing
// correctness issues
let record_batch = if !(self.filter.is_some()
if self.filter.is_some()
&& matches!(
self.join_type,
JoinType::Left | JoinType::LeftSemi | JoinType::Right
)) {
record_batch
} else {
JoinType::Left
| JoinType::LeftSemi
| JoinType::Right
| JoinType::LeftAnti
)
{
continue;
};
}

return Poll::Ready(Some(Ok(record_batch)));
}
Expand All @@ -929,7 +960,10 @@ impl Stream for SMJStream {
if self.filter.is_some()
&& matches!(
self.join_type,
JoinType::Left | JoinType::LeftSemi | JoinType::Right
JoinType::Left
| JoinType::LeftSemi
| JoinType::Right
| JoinType::LeftAnti
)
{
let out = self.filter_joined_batch()?;
Expand Down Expand Up @@ -1273,11 +1307,7 @@ impl SMJStream {
};

if matches!(self.join_type, JoinType::LeftAnti) && self.filter.is_some() {
join_streamed = !self
.streamed_batch
.join_filter_matched_idxs
.contains(&(self.streamed_batch.idx as u64))
&& !self.streamed_joined;
join_streamed = !self.streamed_joined;
join_buffered = join_streamed;
}
}
Expand Down Expand Up @@ -1519,7 +1549,10 @@ impl SMJStream {
// Push the filtered batch which contains rows passing join filter to the output
if matches!(
self.join_type,
JoinType::Left | JoinType::LeftSemi | JoinType::Right
JoinType::Left
| JoinType::LeftSemi
| JoinType::Right
| JoinType::LeftAnti
) {
self.output_record_batches
.batches
Expand Down Expand Up @@ -1654,7 +1687,10 @@ impl SMJStream {
if !(self.filter.is_some()
&& matches!(
self.join_type,
JoinType::Left | JoinType::LeftSemi | JoinType::Right
JoinType::Left
| JoinType::LeftSemi
| JoinType::Right
| JoinType::LeftAnti
))
{
self.output_record_batches.batches.clear();
Expand Down Expand Up @@ -1727,7 +1763,7 @@ impl SMJStream {
&self.schema,
&[filtered_record_batch, null_joined_streamed_batch],
)?;
} else if matches!(self.join_type, JoinType::LeftSemi) {
} else if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) {
let output_column_indices = (0..streamed_columns_length).collect::<Vec<_>>();
filtered_record_batch =
filtered_record_batch.project(&output_column_indices)?;
Expand Down Expand Up @@ -3349,6 +3385,7 @@ mod tests {
batch_ids: vec![],
};

// Insert already prejoined non-filtered rows
batches.batches.push(RecordBatch::try_new(
Arc::clone(&schema),
vec![
Expand Down Expand Up @@ -3835,6 +3872,178 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn test_left_anti_join_filtered_mask() -> Result<()> {
let mut joined_batches = build_joined_record_batches()?;
let schema = joined_batches.batches.first().unwrap().schema();

let output = concat_batches(&schema, &joined_batches.batches)?;
let out_mask = joined_batches.filter_mask.finish();
let out_indices = joined_batches.row_indices.finish();

assert_eq!(
get_corrected_filter_mask(
LeftAnti,
&UInt64Array::from(vec![0]),
&[0usize],
&BooleanArray::from(vec![true]),
1
)
.unwrap(),
BooleanArray::from(vec![None])
);

assert_eq!(
get_corrected_filter_mask(
LeftAnti,
&UInt64Array::from(vec![0]),
&[0usize],
&BooleanArray::from(vec![false]),
1
)
.unwrap(),
BooleanArray::from(vec![Some(true)])
);

assert_eq!(
get_corrected_filter_mask(
LeftAnti,
&UInt64Array::from(vec![0, 0]),
&[0usize; 2],
&BooleanArray::from(vec![true, true]),
2
)
.unwrap(),
BooleanArray::from(vec![None, None])
);

assert_eq!(
get_corrected_filter_mask(
LeftAnti,
&UInt64Array::from(vec![0, 0, 0]),
&[0usize; 3],
&BooleanArray::from(vec![true, true, true]),
3
)
.unwrap(),
BooleanArray::from(vec![None, None, None])
);

assert_eq!(
get_corrected_filter_mask(
LeftAnti,
&UInt64Array::from(vec![0, 0, 0]),
&[0usize; 3],
&BooleanArray::from(vec![true, false, true]),
3
)
.unwrap(),
BooleanArray::from(vec![None, None, None])
);

assert_eq!(
get_corrected_filter_mask(
LeftAnti,
&UInt64Array::from(vec![0, 0, 0]),
&[0usize; 3],
&BooleanArray::from(vec![false, false, true]),
3
)
.unwrap(),
BooleanArray::from(vec![None, None, None])
);

assert_eq!(
get_corrected_filter_mask(
LeftAnti,
&UInt64Array::from(vec![0, 0, 0]),
&[0usize; 3],
&BooleanArray::from(vec![false, true, true]),
3
)
.unwrap(),
BooleanArray::from(vec![None, None, None])
);

assert_eq!(
get_corrected_filter_mask(
LeftAnti,
&UInt64Array::from(vec![0, 0, 0]),
&[0usize; 3],
&BooleanArray::from(vec![false, false, false]),
3
)
.unwrap(),
BooleanArray::from(vec![None, None, Some(true)])
);

let corrected_mask = get_corrected_filter_mask(
LeftAnti,
&out_indices,
&joined_batches.batch_ids,
&out_mask,
output.num_rows(),
)
.unwrap();

assert_eq!(
corrected_mask,
BooleanArray::from(vec![
None,
None,
None,
None,
None,
Some(true),
None,
Some(true)
])
);

let filtered_rb = filter_record_batch(&output, &corrected_mask)?;

assert_batches_eq!(
&[
"+---+----+---+----+",
"| a | b | x | y |",
"+---+----+---+----+",
"| 1 | 13 | 1 | 12 |",
"| 1 | 14 | 1 | 11 |",
"+---+----+---+----+",
],
&[filtered_rb]
);

// output null rows
let null_mask = arrow::compute::not(&corrected_mask)?;
assert_eq!(
null_mask,
BooleanArray::from(vec![
None,
None,
None,
None,
None,
Some(false),
None,
Some(false),
])
);

let null_joined_batch = filter_record_batch(&output, &null_mask)?;

assert_batches_eq!(
&[
"+---+---+---+---+",
"| a | b | x | y |",
"+---+---+---+---+",
"+---+---+---+---+",
],
&[null_joined_batch]
);
Ok(())
}

/// Returns the column names on the schema
fn columns(schema: &Schema) -> Vec<String> {
schema.fields().iter().map(|f| f.name().clone()).collect()
Expand Down
Loading

0 comments on commit 146f16a

Please sign in to comment.