Skip to content

Commit af0e8a9

Browse files
jayzhan211alamb
andauthored
Optimize COUNT( DISTINCT ...) for strings (up to 9x faster) (#8849)
* chkp Signed-off-by: jayzhan211 <[email protected]> * chkp Signed-off-by: jayzhan211 <[email protected]> * draft Signed-off-by: jayzhan211 <[email protected]> * iter done Signed-off-by: jayzhan211 <[email protected]> * short string test Signed-off-by: jayzhan211 <[email protected]> * add test Signed-off-by: jayzhan211 <[email protected]> * remove unused Signed-off-by: jayzhan211 <[email protected]> * to_string directly Signed-off-by: jayzhan211 <[email protected]> * rewrite evaluate Signed-off-by: jayzhan211 <[email protected]> * return Vec<String> Signed-off-by: jayzhan211 <[email protected]> * fmt Signed-off-by: jayzhan211 <[email protected]> * add more queries Signed-off-by: jayzhan211 <[email protected]> * add group by query and rewrite evalute with state() Signed-off-by: jayzhan211 <[email protected]> * move evaluate back Signed-off-by: jayzhan211 <[email protected]> * upd test Signed-off-by: jayzhan211 <[email protected]> * add row sort Signed-off-by: jayzhan211 <[email protected]> * Update benchmarks/queries/clickbench/README.md * Rework set to avoid copies * Simplify offset construction * fmt * Improve comments * Improve comments * add fuzz test Signed-off-by: jayzhan211 <[email protected]> * Add support for LargeStringArray * refine fuzz test * Add tests for size accounting * Split into new module * Remove use of Mutex * revert changes * Use reference rather than owned ArrayRef --------- Signed-off-by: jayzhan211 <[email protected]> Co-authored-by: Andrew Lamb <[email protected]>
1 parent a57e270 commit af0e8a9

File tree

9 files changed

+792
-21
lines changed

9 files changed

+792
-21
lines changed

benchmarks/queries/clickbench/README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DIST
2929
FROM hits;
3030
```
3131

32-
3332
### Q1: Data Exploration
3433

3534
**Question**: "How many distinct "hit color", "browser country" and "language" are there in the dataset?"
@@ -42,7 +41,7 @@ SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTI
4241
FROM hits;
4342
```
4443

45-
### Q2: Top 10 anaylsis
44+
### Q2: Top 10 analysis
4645

4746
**Question**: "Find the top 10 "browser country" by number of distinct "social network"s,
4847
including the distinct counts of "hit color", "browser language",

datafusion-cli/Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Compare DistinctCount for string with naive HashSet and Short String Optimized HashSet
19+
20+
use std::sync::Arc;
21+
22+
use arrow::array::ArrayRef;
23+
use arrow::record_batch::RecordBatch;
24+
use arrow_array::{Array, GenericStringArray, OffsetSizeTrait, UInt32Array};
25+
26+
use arrow_array::cast::AsArray;
27+
use datafusion::datasource::MemTable;
28+
use rand::rngs::StdRng;
29+
use rand::{thread_rng, Rng, SeedableRng};
30+
use std::collections::HashSet;
31+
use tokio::task::JoinSet;
32+
33+
use datafusion::prelude::{SessionConfig, SessionContext};
34+
use test_utils::stagger_batch;
35+
36+
#[tokio::test(flavor = "multi_thread")]
37+
async fn distinct_count_string_test() {
38+
// max length of generated strings
39+
let mut join_set = JoinSet::new();
40+
let mut rng = thread_rng();
41+
for null_pct in [0.0, 0.01, 0.1, 0.5] {
42+
for _ in 0..100 {
43+
let max_len = rng.gen_range(1..50);
44+
let num_strings = rng.gen_range(1..100);
45+
let num_distinct_strings = if num_strings > 1 {
46+
rng.gen_range(1..num_strings)
47+
} else {
48+
num_strings
49+
};
50+
let generator = BatchGenerator {
51+
max_len,
52+
num_strings,
53+
num_distinct_strings,
54+
null_pct,
55+
rng: StdRng::from_seed(rng.gen()),
56+
};
57+
join_set.spawn(async move { run_distinct_count_test(generator).await });
58+
}
59+
}
60+
while let Some(join_handle) = join_set.join_next().await {
61+
// propagate errors
62+
join_handle.unwrap();
63+
}
64+
}
65+
66+
/// Run COUNT DISTINCT using SQL and compare the result to computing the
67+
/// distinct count using HashSet<String>
68+
async fn run_distinct_count_test(mut generator: BatchGenerator) {
69+
let input = generator.make_input_batches();
70+
71+
let schema = input[0].schema();
72+
let session_config = SessionConfig::new().with_batch_size(50);
73+
let ctx = SessionContext::new_with_config(session_config);
74+
75+
// split input into two partitions
76+
let partition_len = input.len() / 2;
77+
let partitions = vec![
78+
input[0..partition_len].to_vec(),
79+
input[partition_len..].to_vec(),
80+
];
81+
82+
let provider = MemTable::try_new(schema, partitions).unwrap();
83+
ctx.register_table("t", Arc::new(provider)).unwrap();
84+
// input has two columns, a and b. The result is the number of distinct
85+
// values in each column.
86+
//
87+
// Note, we need at least two count distinct aggregates to trigger the
88+
// count distinct aggregate. Otherwise, the optimizer will rewrite the
89+
// `COUNT(DISTINCT a)` to `COUNT(*) from (SELECT DISTINCT a FROM t)`
90+
let results = ctx
91+
.sql("SELECT COUNT(DISTINCT a), COUNT(DISTINCT b) FROM t")
92+
.await
93+
.unwrap()
94+
.collect()
95+
.await
96+
.unwrap();
97+
98+
// get all the strings from the first column of the result (distinct a)
99+
let expected_a = extract_distinct_strings::<i32>(&input, 0).len();
100+
let result_a = extract_i64(&results, 0);
101+
assert_eq!(expected_a, result_a);
102+
103+
// get all the strings from the second column of the result (distinct b(
104+
let expected_b = extract_distinct_strings::<i64>(&input, 1).len();
105+
let result_b = extract_i64(&results, 1);
106+
assert_eq!(expected_b, result_b);
107+
}
108+
109+
/// Return all (non null) distinct strings from column col_idx
110+
fn extract_distinct_strings<O: OffsetSizeTrait>(
111+
results: &[RecordBatch],
112+
col_idx: usize,
113+
) -> Vec<String> {
114+
results
115+
.iter()
116+
.flat_map(|batch| {
117+
let array = batch.column(col_idx).as_string::<O>();
118+
// remove nulls via 'flatten'
119+
array.iter().flatten().map(|s| s.to_string())
120+
})
121+
.collect::<HashSet<_>>()
122+
.into_iter()
123+
.collect()
124+
}
125+
126+
// extract the value from the Int64 column in col_idx in batch and return
127+
// it as a usize
128+
fn extract_i64(results: &[RecordBatch], col_idx: usize) -> usize {
129+
assert_eq!(results.len(), 1);
130+
let array = results[0]
131+
.column(col_idx)
132+
.as_any()
133+
.downcast_ref::<arrow::array::Int64Array>()
134+
.unwrap();
135+
assert_eq!(array.len(), 1);
136+
assert!(!array.is_null(0));
137+
array.value(0).try_into().unwrap()
138+
}
139+
140+
struct BatchGenerator {
141+
//// The maximum length of the strings
142+
max_len: usize,
143+
/// the total number of strings in the output
144+
num_strings: usize,
145+
/// The number of distinct strings in the columns
146+
num_distinct_strings: usize,
147+
/// The percentage of nulls in the columns
148+
null_pct: f64,
149+
/// Random number generator
150+
rng: StdRng,
151+
}
152+
153+
impl BatchGenerator {
154+
/// Make batches of random strings with a random length columns "a" and "b":
155+
///
156+
/// * "a" is a StringArray
157+
/// * "b" is a LargeStringArray
158+
fn make_input_batches(&mut self) -> Vec<RecordBatch> {
159+
// use a random number generator to pick a random sized output
160+
161+
let batch = RecordBatch::try_from_iter(vec![
162+
("a", self.gen_data::<i32>()),
163+
("b", self.gen_data::<i64>()),
164+
])
165+
.unwrap();
166+
167+
stagger_batch(batch)
168+
}
169+
170+
/// Creates a StringArray or LargeStringArray with random strings according
171+
/// to the parameters of the BatchGenerator
172+
fn gen_data<O: OffsetSizeTrait>(&mut self) -> ArrayRef {
173+
// table of strings from which to draw
174+
let distinct_strings: GenericStringArray<O> = (0..self.num_distinct_strings)
175+
.map(|_| Some(random_string(&mut self.rng, self.max_len)))
176+
.collect();
177+
178+
// pick num_strings randomly from the distinct string table
179+
let indicies: UInt32Array = (0..self.num_strings)
180+
.map(|_| {
181+
if self.rng.gen::<f64>() < self.null_pct {
182+
None
183+
} else if self.num_distinct_strings > 1 {
184+
let range = 1..(self.num_distinct_strings as u32);
185+
Some(self.rng.gen_range(range))
186+
} else {
187+
Some(0)
188+
}
189+
})
190+
.collect();
191+
192+
let options = None;
193+
arrow::compute::take(&distinct_strings, &indicies, options).unwrap()
194+
}
195+
}
196+
197+
/// Return a string of random characters of length 1..=max_len
198+
fn random_string(rng: &mut StdRng, max_len: usize) -> String {
199+
// pick characters at random (not just ascii)
200+
match max_len {
201+
0 => "".to_string(),
202+
1 => String::from(rng.gen::<char>()),
203+
_ => {
204+
let len = rng.gen_range(1..=max_len);
205+
rng.sample_iter::<char, _>(rand::distributions::Standard)
206+
.take(len)
207+
.map(char::from)
208+
.collect::<String>()
209+
}
210+
}
211+
}

datafusion/core/tests/fuzz_cases/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
// under the License.
1717

1818
mod aggregate_fuzz;
19+
mod distinct_count_string_fuzz;
1920
mod join_fuzz;
2021
mod merge_fuzz;
2122
mod sort_fuzz;

datafusion/physical-expr/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ blake2 = { version = "^0.10.2", optional = true }
5454
blake3 = { version = "1.0", optional = true }
5555
chrono = { workspace = true }
5656
datafusion-common = { workspace = true }
57+
datafusion-execution = { workspace = true }
5758
datafusion-expr = { workspace = true }
5859
half = { version = "2.1", default-features = false }
5960
hashbrown = { version = "0.14", features = ["raw"] }

datafusion/physical-expr/src/aggregate/count_distinct.rs renamed to datafusion/physical-expr/src/aggregate/count_distinct/mod.rs

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,34 +15,37 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::datatypes::{DataType, Field, TimeUnit};
19-
use arrow_array::types::{
20-
ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
21-
Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
22-
Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
23-
TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
24-
TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
25-
};
26-
use arrow_array::PrimitiveArray;
18+
mod strings;
2719

2820
use std::any::Any;
2921
use std::cmp::Eq;
22+
use std::collections::HashSet;
3023
use std::fmt::Debug;
3124
use std::hash::Hash;
3225
use std::sync::Arc;
3326

3427
use ahash::RandomState;
3528
use arrow::array::{Array, ArrayRef};
36-
use std::collections::HashSet;
29+
use arrow::datatypes::{DataType, Field, TimeUnit};
30+
use arrow_array::types::{
31+
ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
32+
Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
33+
Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
34+
TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
35+
TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
36+
};
37+
use arrow_array::PrimitiveArray;
3738

38-
use crate::aggregate::utils::{down_cast_any_ref, Hashable};
39-
use crate::expressions::format_state_name;
40-
use crate::{AggregateExpr, PhysicalExpr};
4139
use datafusion_common::cast::{as_list_array, as_primitive_array};
4240
use datafusion_common::utils::array_into_list_array;
4341
use datafusion_common::{Result, ScalarValue};
4442
use datafusion_expr::Accumulator;
4543

44+
use crate::aggregate::count_distinct::strings::StringDistinctCountAccumulator;
45+
use crate::aggregate::utils::{down_cast_any_ref, Hashable};
46+
use crate::expressions::format_state_name;
47+
use crate::{AggregateExpr, PhysicalExpr};
48+
4649
type DistinctScalarValues = ScalarValue;
4750

4851
/// Expression for a COUNT(DISTINCT) aggregation.
@@ -61,10 +64,10 @@ impl DistinctCount {
6164
pub fn new(
6265
input_data_type: DataType,
6366
expr: Arc<dyn PhysicalExpr>,
64-
name: String,
67+
name: impl Into<String>,
6568
) -> Self {
6669
Self {
67-
name,
70+
name: name.into(),
6871
state_data_type: input_data_type,
6972
expr,
7073
}
@@ -152,6 +155,9 @@ impl AggregateExpr for DistinctCount {
152155
Float32 => float_distinct_count_accumulator!(Float32Type),
153156
Float64 => float_distinct_count_accumulator!(Float64Type),
154157

158+
Utf8 => Ok(Box::new(StringDistinctCountAccumulator::<i32>::new())),
159+
LargeUtf8 => Ok(Box::new(StringDistinctCountAccumulator::<i64>::new())),
160+
155161
_ => Ok(Box::new(DistinctCountAccumulator {
156162
values: HashSet::default(),
157163
state_data_type: self.state_data_type.clone(),
@@ -244,7 +250,7 @@ impl Accumulator for DistinctCountAccumulator {
244250
assert_eq!(states.len(), 1, "array_agg states must be singleton!");
245251
let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&states[0])?;
246252
for scalars in scalar_vec.into_iter() {
247-
self.values.extend(scalars)
253+
self.values.extend(scalars);
248254
}
249255
Ok(())
250256
}
@@ -440,9 +446,6 @@ where
440446

441447
#[cfg(test)]
442448
mod tests {
443-
use crate::expressions::NoOp;
444-
445-
use super::*;
446449
use arrow::array::{
447450
ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
448451
Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
@@ -454,10 +457,15 @@ mod tests {
454457
};
455458
use arrow_array::Decimal256Array;
456459
use arrow_buffer::i256;
460+
457461
use datafusion_common::cast::{as_boolean_array, as_list_array, as_primitive_array};
458462
use datafusion_common::internal_err;
459463
use datafusion_common::DataFusionError;
460464

465+
use crate::expressions::NoOp;
466+
467+
use super::*;
468+
461469
macro_rules! state_to_vec_primitive {
462470
($LIST:expr, $DATA_TYPE:ident) => {{
463471
let arr = ScalarValue::raw_data($LIST).unwrap();

0 commit comments

Comments
 (0)