|  | 
|  | 1 | +// Licensed to the Apache Software Foundation (ASF) under one | 
|  | 2 | +// or more contributor license agreements.  See the NOTICE file | 
|  | 3 | +// distributed with this work for additional information | 
|  | 4 | +// regarding copyright ownership.  The ASF licenses this file | 
|  | 5 | +// to you under the Apache License, Version 2.0 (the | 
|  | 6 | +// "License"); you may not use this file except in compliance | 
|  | 7 | +// with the License.  You may obtain a copy of the License at | 
|  | 8 | +// | 
|  | 9 | +//   http://www.apache.org/licenses/LICENSE-2.0 | 
|  | 10 | +// | 
|  | 11 | +// Unless required by applicable law or agreed to in writing, | 
|  | 12 | +// software distributed under the License is distributed on an | 
|  | 13 | +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | 
|  | 14 | +// KIND, either express or implied.  See the License for the | 
|  | 15 | +// specific language governing permissions and limitations | 
|  | 16 | +// under the License. | 
|  | 17 | + | 
|  | 18 | +//! Compare DistinctCount for string with naive HashSet and Short String Optimized HashSet | 
|  | 19 | +
 | 
|  | 20 | +use std::sync::Arc; | 
|  | 21 | + | 
|  | 22 | +use arrow::array::ArrayRef; | 
|  | 23 | +use arrow::record_batch::RecordBatch; | 
|  | 24 | +use arrow_array::{Array, GenericStringArray, OffsetSizeTrait, UInt32Array}; | 
|  | 25 | + | 
|  | 26 | +use arrow_array::cast::AsArray; | 
|  | 27 | +use datafusion::datasource::MemTable; | 
|  | 28 | +use rand::rngs::StdRng; | 
|  | 29 | +use rand::{thread_rng, Rng, SeedableRng}; | 
|  | 30 | +use std::collections::HashSet; | 
|  | 31 | +use tokio::task::JoinSet; | 
|  | 32 | + | 
|  | 33 | +use datafusion::prelude::{SessionConfig, SessionContext}; | 
|  | 34 | +use test_utils::stagger_batch; | 
|  | 35 | + | 
|  | 36 | +#[tokio::test(flavor = "multi_thread")] | 
|  | 37 | +async fn distinct_count_string_test() { | 
|  | 38 | +    // max length of generated strings | 
|  | 39 | +    let mut join_set = JoinSet::new(); | 
|  | 40 | +    let mut rng = thread_rng(); | 
|  | 41 | +    for null_pct in [0.0, 0.01, 0.1, 0.5] { | 
|  | 42 | +        for _ in 0..100 { | 
|  | 43 | +            let max_len = rng.gen_range(1..50); | 
|  | 44 | +            let num_strings = rng.gen_range(1..100); | 
|  | 45 | +            let num_distinct_strings = if num_strings > 1 { | 
|  | 46 | +                rng.gen_range(1..num_strings) | 
|  | 47 | +            } else { | 
|  | 48 | +                num_strings | 
|  | 49 | +            }; | 
|  | 50 | +            let generator = BatchGenerator { | 
|  | 51 | +                max_len, | 
|  | 52 | +                num_strings, | 
|  | 53 | +                num_distinct_strings, | 
|  | 54 | +                null_pct, | 
|  | 55 | +                rng: StdRng::from_seed(rng.gen()), | 
|  | 56 | +            }; | 
|  | 57 | +            join_set.spawn(async move { run_distinct_count_test(generator).await }); | 
|  | 58 | +        } | 
|  | 59 | +    } | 
|  | 60 | +    while let Some(join_handle) = join_set.join_next().await { | 
|  | 61 | +        // propagate errors | 
|  | 62 | +        join_handle.unwrap(); | 
|  | 63 | +    } | 
|  | 64 | +} | 
|  | 65 | + | 
|  | 66 | +/// Run COUNT DISTINCT using SQL and compare the result to computing the | 
|  | 67 | +/// distinct count using HashSet<String> | 
|  | 68 | +async fn run_distinct_count_test(mut generator: BatchGenerator) { | 
|  | 69 | +    let input = generator.make_input_batches(); | 
|  | 70 | + | 
|  | 71 | +    let schema = input[0].schema(); | 
|  | 72 | +    let session_config = SessionConfig::new().with_batch_size(50); | 
|  | 73 | +    let ctx = SessionContext::new_with_config(session_config); | 
|  | 74 | + | 
|  | 75 | +    // split input into two partitions | 
|  | 76 | +    let partition_len = input.len() / 2; | 
|  | 77 | +    let partitions = vec![ | 
|  | 78 | +        input[0..partition_len].to_vec(), | 
|  | 79 | +        input[partition_len..].to_vec(), | 
|  | 80 | +    ]; | 
|  | 81 | + | 
|  | 82 | +    let provider = MemTable::try_new(schema, partitions).unwrap(); | 
|  | 83 | +    ctx.register_table("t", Arc::new(provider)).unwrap(); | 
|  | 84 | +    // input has two columns, a and b.  The result is the number of distinct | 
|  | 85 | +    // values in each column. | 
|  | 86 | +    // | 
|  | 87 | +    // Note, we need  at least two count distinct aggregates to trigger the | 
|  | 88 | +    // count distinct aggregate. Otherwise, the optimizer will rewrite the | 
|  | 89 | +    // `COUNT(DISTINCT a)` to `COUNT(*) from (SELECT DISTINCT a FROM t)` | 
|  | 90 | +    let results = ctx | 
|  | 91 | +        .sql("SELECT COUNT(DISTINCT a), COUNT(DISTINCT b) FROM t") | 
|  | 92 | +        .await | 
|  | 93 | +        .unwrap() | 
|  | 94 | +        .collect() | 
|  | 95 | +        .await | 
|  | 96 | +        .unwrap(); | 
|  | 97 | + | 
|  | 98 | +    // get all the strings from the first column of the result (distinct a) | 
|  | 99 | +    let expected_a = extract_distinct_strings::<i32>(&input, 0).len(); | 
|  | 100 | +    let result_a = extract_i64(&results, 0); | 
|  | 101 | +    assert_eq!(expected_a, result_a); | 
|  | 102 | + | 
|  | 103 | +    // get all the strings from the second column of the result (distinct b( | 
|  | 104 | +    let expected_b = extract_distinct_strings::<i64>(&input, 1).len(); | 
|  | 105 | +    let result_b = extract_i64(&results, 1); | 
|  | 106 | +    assert_eq!(expected_b, result_b); | 
|  | 107 | +} | 
|  | 108 | + | 
|  | 109 | +/// Return all (non null) distinct strings from  column col_idx | 
|  | 110 | +fn extract_distinct_strings<O: OffsetSizeTrait>( | 
|  | 111 | +    results: &[RecordBatch], | 
|  | 112 | +    col_idx: usize, | 
|  | 113 | +) -> Vec<String> { | 
|  | 114 | +    results | 
|  | 115 | +        .iter() | 
|  | 116 | +        .flat_map(|batch| { | 
|  | 117 | +            let array = batch.column(col_idx).as_string::<O>(); | 
|  | 118 | +            // remove nulls via 'flatten' | 
|  | 119 | +            array.iter().flatten().map(|s| s.to_string()) | 
|  | 120 | +        }) | 
|  | 121 | +        .collect::<HashSet<_>>() | 
|  | 122 | +        .into_iter() | 
|  | 123 | +        .collect() | 
|  | 124 | +} | 
|  | 125 | + | 
|  | 126 | +// extract the value from the Int64 column in col_idx in batch and return | 
|  | 127 | +// it as a usize | 
|  | 128 | +fn extract_i64(results: &[RecordBatch], col_idx: usize) -> usize { | 
|  | 129 | +    assert_eq!(results.len(), 1); | 
|  | 130 | +    let array = results[0] | 
|  | 131 | +        .column(col_idx) | 
|  | 132 | +        .as_any() | 
|  | 133 | +        .downcast_ref::<arrow::array::Int64Array>() | 
|  | 134 | +        .unwrap(); | 
|  | 135 | +    assert_eq!(array.len(), 1); | 
|  | 136 | +    assert!(!array.is_null(0)); | 
|  | 137 | +    array.value(0).try_into().unwrap() | 
|  | 138 | +} | 
|  | 139 | + | 
|  | 140 | +struct BatchGenerator { | 
|  | 141 | +    //// The maximum length of the strings | 
|  | 142 | +    max_len: usize, | 
|  | 143 | +    /// the total number of strings in the output | 
|  | 144 | +    num_strings: usize, | 
|  | 145 | +    /// The number of distinct strings in the columns | 
|  | 146 | +    num_distinct_strings: usize, | 
|  | 147 | +    /// The percentage of nulls in the columns | 
|  | 148 | +    null_pct: f64, | 
|  | 149 | +    /// Random number generator | 
|  | 150 | +    rng: StdRng, | 
|  | 151 | +} | 
|  | 152 | + | 
|  | 153 | +impl BatchGenerator { | 
|  | 154 | +    /// Make batches of random strings with a random length columns "a" and "b": | 
|  | 155 | +    /// | 
|  | 156 | +    /// * "a" is a StringArray | 
|  | 157 | +    /// * "b" is a LargeStringArray | 
|  | 158 | +    fn make_input_batches(&mut self) -> Vec<RecordBatch> { | 
|  | 159 | +        // use a random number generator to pick a random sized output | 
|  | 160 | + | 
|  | 161 | +        let batch = RecordBatch::try_from_iter(vec![ | 
|  | 162 | +            ("a", self.gen_data::<i32>()), | 
|  | 163 | +            ("b", self.gen_data::<i64>()), | 
|  | 164 | +        ]) | 
|  | 165 | +        .unwrap(); | 
|  | 166 | + | 
|  | 167 | +        stagger_batch(batch) | 
|  | 168 | +    } | 
|  | 169 | + | 
|  | 170 | +    /// Creates a StringArray or LargeStringArray with random strings according | 
|  | 171 | +    /// to the parameters of the BatchGenerator | 
|  | 172 | +    fn gen_data<O: OffsetSizeTrait>(&mut self) -> ArrayRef { | 
|  | 173 | +        // table of strings from which to draw | 
|  | 174 | +        let distinct_strings: GenericStringArray<O> = (0..self.num_distinct_strings) | 
|  | 175 | +            .map(|_| Some(random_string(&mut self.rng, self.max_len))) | 
|  | 176 | +            .collect(); | 
|  | 177 | + | 
|  | 178 | +        // pick num_strings randomly from the distinct string table | 
|  | 179 | +        let indicies: UInt32Array = (0..self.num_strings) | 
|  | 180 | +            .map(|_| { | 
|  | 181 | +                if self.rng.gen::<f64>() < self.null_pct { | 
|  | 182 | +                    None | 
|  | 183 | +                } else if self.num_distinct_strings > 1 { | 
|  | 184 | +                    let range = 1..(self.num_distinct_strings as u32); | 
|  | 185 | +                    Some(self.rng.gen_range(range)) | 
|  | 186 | +                } else { | 
|  | 187 | +                    Some(0) | 
|  | 188 | +                } | 
|  | 189 | +            }) | 
|  | 190 | +            .collect(); | 
|  | 191 | + | 
|  | 192 | +        let options = None; | 
|  | 193 | +        arrow::compute::take(&distinct_strings, &indicies, options).unwrap() | 
|  | 194 | +    } | 
|  | 195 | +} | 
|  | 196 | + | 
|  | 197 | +/// Return a string of random characters of length 1..=max_len | 
|  | 198 | +fn random_string(rng: &mut StdRng, max_len: usize) -> String { | 
|  | 199 | +    // pick characters at random (not just ascii) | 
|  | 200 | +    match max_len { | 
|  | 201 | +        0 => "".to_string(), | 
|  | 202 | +        1 => String::from(rng.gen::<char>()), | 
|  | 203 | +        _ => { | 
|  | 204 | +            let len = rng.gen_range(1..=max_len); | 
|  | 205 | +            rng.sample_iter::<char, _>(rand::distributions::Standard) | 
|  | 206 | +                .take(len) | 
|  | 207 | +                .map(char::from) | 
|  | 208 | +                .collect::<String>() | 
|  | 209 | +        } | 
|  | 210 | +    } | 
|  | 211 | +} | 
0 commit comments