diff --git a/src/common/hashtable/src/lib.rs b/src/common/hashtable/src/lib.rs
index ab08991dd2cfa..4eaec69a5c105 100644
--- a/src/common/hashtable/src/lib.rs
+++ b/src/common/hashtable/src/lib.rs
@@ -26,6 +26,7 @@ mod lookup_hashtable;
 mod stack_hashtable;
 mod table0;
+mod simple_unsized_hashtable;
 #[allow(dead_code)]
 mod table1;
 mod table_empty;
@@ -70,12 +71,21 @@ pub type UnsizedHashMap<K, V> = unsized_hashtable::UnsizedHashtable<K, V>;
 pub type UnsizedHashMapIter<'a, K, V> = unsized_hashtable::UnsizedHashtableIter<'a, K, V>;
 pub type UnsizedHashMapIterMut<'a, K, V> = unsized_hashtable::UnsizedHashtableIterMut<'a, K, V>;
 pub type UnsizedHashSet<K> = unsized_hashtable::UnsizedHashtable<K, ()>;
-pub type UnsizedHashSetIter<'a, K> = unsized_hashtable::UnsizedHashtableIter<'a, K, ()>;
-pub type UnsizedHashSetIterMut<'a, K> = unsized_hashtable::UnsizedHashtableIterMut<'a, K, ()>;
 pub type UnsizedHashtableEntryRef<'a, K, V> = unsized_hashtable::UnsizedHashtableEntryRef<'a, K, V>;
 pub type UnsizedHashtableEntryMutRef<'a, K, V> =
     unsized_hashtable::UnsizedHashtableEntryMutRef<'a, K, V>;
+pub type SimpleUnsizedHashMap<K, V> = simple_unsized_hashtable::SimpleUnsizedHashtable<K, V>;
+pub type SimpleUnsizedHashMapIter<'a, K, V> =
+    simple_unsized_hashtable::SimpleUnsizedHashtableIter<'a, K, V>;
+pub type SimpleUnsizedHashMapIterMut<'a, K, V> =
+    simple_unsized_hashtable::SimpleUnsizedHashtableIterMut<'a, K, V>;
+pub type SimpleUnsizedHashSet<K> = simple_unsized_hashtable::SimpleUnsizedHashtable<K, ()>;
+pub type SimpleUnsizedHashtableEntryRef<'a, K, V> =
+    simple_unsized_hashtable::SimpleUnsizedHashtableEntryRef<'a, K, V>;
+pub type SimpleUnsizedHashtableEntryMutRef<'a, K, V> =
+    simple_unsized_hashtable::SimpleUnsizedHashtableEntryMutRef<'a, K, V>;
+
 pub type LookupHashMap<K, const CAPACITY: usize, V> = LookupHashtable<K, CAPACITY, V>;
 pub type LookupHashMapIter<'a, K, const CAPACITY: usize, V> = LookupTableIter<'a, CAPACITY, K, V>;
 pub type LookupHashMapIterMut<'a, K, const CAPACITY: usize, V> =
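The new aliases mirror the existing `UnsizedHashMap`/`UnsizedHashSet` family. As a rough usage sketch (not part of the patch; it assumes this workspace's `common_hashtable` crate as a dependency and the `HashtableLike` API added in the new file below), a caller builds the map once and upserts borrowed byte-slice keys:

```rust
use common_hashtable::HashtableLike;
use common_hashtable::SimpleUnsizedHashMap;

// Count occurrences of byte-slice keys. `insert` returns Ok(uninitialized
// slot) for a new key and Err(&mut existing value) for a key seen before.
fn count_keys(keys: &[&[u8]]) -> SimpleUnsizedHashMap<[u8], u64> {
    let mut map = SimpleUnsizedHashMap::<[u8], u64>::new();
    for key in keys {
        // SAFETY: non-empty keys are copied into the table's arena by
        // `insert_and_entry` (see below), so the borrow does not escape.
        match unsafe { map.insert(key) } {
            Ok(slot) => {
                slot.write(1);
            }
            Err(count) => *count += 1,
        }
    }
    map
}
```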
diff --git a/src/common/hashtable/src/simple_unsized_hashtable.rs b/src/common/hashtable/src/simple_unsized_hashtable.rs
new file mode 100644
index 0000000000000..c4eaaac482f8f
--- /dev/null
+++ b/src/common/hashtable/src/simple_unsized_hashtable.rs
@@ -0,0 +1,617 @@
+// Copyright 2021 Datafuse Labs.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::alloc::Allocator;
+use std::marker::PhantomData;
+use std::mem::MaybeUninit;
+
+use bumpalo::Bump;
+use common_base::mem_allocator::GlobalAllocator;
+use common_base::mem_allocator::MmapAllocator;
+
+use super::container::HeapContainer;
+use super::table0::Entry;
+use super::table0::Table0;
+use super::traits::EntryMutRefLike;
+use super::traits::EntryRefLike;
+use super::traits::HashtableLike;
+use super::traits::UnsizedKeyable;
+use crate::table0::Table0Iter;
+use crate::table0::Table0IterMut;
+use crate::table_empty::TableEmpty;
+use crate::table_empty::TableEmptyIter;
+use crate::table_empty::TableEmptyIterMut;
+use crate::unsized_hashtable::FallbackKey;
+
+/// The simple unsized hashtable stores unsized keys in an arena and works with `HashMethodSerializer`.
+/// Unlike `UnsizedHashtable`, it does not dispatch keys to adaptive sub-hashtables by key size,
+/// so it can be considered a minimal implementation of `UnsizedHashtable`.
+pub struct SimpleUnsizedHashtable<K, V, A = MmapAllocator<GlobalAllocator>>
+where
+    K: UnsizedKeyable + ?Sized,
+    A: Allocator + Clone,
+{
+    pub(crate) arena: Bump,
+    pub(crate) key_size: usize,
+    pub(crate) table_empty: TableEmpty<V, A>,
+    pub(crate) table: Table0<FallbackKey, V, HeapContainer<Entry<FallbackKey, V>, A>, A>,
+    pub(crate) _phantom: PhantomData<K>,
+}
+
+unsafe impl<K: UnsizedKeyable + ?Sized + Send, V: Send, A: Allocator + Clone + Send> Send
+    for SimpleUnsizedHashtable<K, V, A>
+{
+}
+
+unsafe impl<K: UnsizedKeyable + ?Sized + Sync, V: Sync, A: Allocator + Clone + Sync> Sync
+    for SimpleUnsizedHashtable<K, V, A>
+{
+}
+
+impl<K, V, A> SimpleUnsizedHashtable<K, V, A>
+where
+    K: UnsizedKeyable + ?Sized,
+    A: Allocator + Clone + Default,
+{
+    pub fn new() -> Self {
+        Self::with_capacity(128)
+    }
+}
+
+impl<K, V, A> Default for SimpleUnsizedHashtable<K, V, A>
+where
+    K: UnsizedKeyable + ?Sized,
+    A: Allocator + Clone + Default,
+{
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl<K, A> SimpleUnsizedHashtable<K, (), A>
+where
+    K: UnsizedKeyable + ?Sized,
+    A: Allocator + Clone + Default,
+{
+    #[inline(always)]
+    pub fn set_insert(&mut self, key: &K) -> Result<&mut MaybeUninit<()>, &mut ()> {
+        unsafe { self.insert_borrowing(key) }
+    }
+
+    #[inline(always)]
+    pub fn set_merge(&mut self, other: &Self) {
+        unsafe {
+            for _ in other.table_empty.iter() {
+                let _ = self.table_empty.insert();
+            }
+            self.table.set_merge(&other.table);
+        }
+    }
+}
+
+impl<K, V, A> SimpleUnsizedHashtable<K, V, A>
+where
+    K: UnsizedKeyable + ?Sized,
+    A: Allocator + Clone + Default,
+{
+    /// The bump for strings doesn't allocate memory by `A`.
+    pub fn with_capacity(capacity: usize) -> Self {
+        let allocator = A::default();
+        Self {
+            arena: Bump::new(),
+            key_size: 0,
+            table_empty: TableEmpty::new_in(allocator.clone()),
+            table: Table0::with_capacity_in(capacity, allocator),
+            _phantom: PhantomData,
+        }
+    }
+
+    #[inline(always)]
+    pub fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+
+    #[inline(always)]
+    pub fn len(&self) -> usize {
+        self.table_empty.len() + self.table.len()
+    }
+
+    #[inline(always)]
+    pub fn capacity(&self) -> usize {
+        self.table_empty.capacity() + self.table.capacity()
+    }
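A zero-length key cannot be represented by `FallbackKey` (a `NonNull` pointer plus hash in the main table), so every operation first dispatches on `key.len()`: empty keys hit the one-slot `TableEmpty`, everything else goes to `Table0`, and `len`/`capacity` above sum both parts. A minimal std-only sketch of that dispatch, with `HashMap` standing in for `Table0` (hypothetical names, not the crate's API):

```rust
use std::collections::HashMap;

// `empty` plays the role of `table_empty`; `rest` plays the role of `table`.
struct TwoPartMap<V> {
    empty: Option<V>,
    rest: HashMap<Vec<u8>, V>,
}

impl<V> TwoPartMap<V> {
    fn insert(&mut self, key: &[u8], value: V) -> Option<V> {
        match key.len() {
            0 => self.empty.replace(value), // the single empty-key slot
            _ => self.rest.insert(key.to_vec(), value),
        }
    }

    fn get(&self, key: &[u8]) -> Option<&V> {
        match key.len() {
            0 => self.empty.as_ref(),
            _ => self.rest.get(key),
        }
    }

    // Like the methods above, size is the sum of both parts.
    fn len(&self) -> usize {
        self.rest.len() + usize::from(self.empty.is_some())
    }
}
```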
+
+    /// # Safety
+    ///
+    /// * The uninitialized value of returned entry should be written immediately.
+    /// * The lifetime of key lives longer than the hashtable.
+    #[inline(always)]
+    pub unsafe fn insert_and_entry_borrowing(
+        &mut self,
+        key: *const K,
+    ) -> Result<
+        SimpleUnsizedHashtableEntryMutRef<'_, K, V>,
+        SimpleUnsizedHashtableEntryMutRef<'_, K, V>,
+    > {
+        let key = (*key).as_bytes();
+        match key.len() {
+            0 => self
+                .table_empty
+                .insert()
+                .map(|x| {
+                    SimpleUnsizedHashtableEntryMutRef(
+                        SimpleUnsizedHashtableEntryMutRefInner::TableEmpty(x, PhantomData),
+                    )
+                })
+                .map_err(|x| {
+                    SimpleUnsizedHashtableEntryMutRef(
+                        SimpleUnsizedHashtableEntryMutRefInner::TableEmpty(x, PhantomData),
+                    )
+                }),
+            _ => {
+                self.table.check_grow();
+                self.table
+                    .insert(FallbackKey::new(key))
+                    .map(|x| {
+                        self.key_size += key.len();
+                        SimpleUnsizedHashtableEntryMutRef(
+                            SimpleUnsizedHashtableEntryMutRefInner::Table(x),
+                        )
+                    })
+                    .map_err(|x| {
+                        SimpleUnsizedHashtableEntryMutRef(
+                            SimpleUnsizedHashtableEntryMutRefInner::Table(x),
+                        )
+                    })
+            }
+        }
+    }
+    /// # Safety
+    ///
+    /// * The uninitialized value of returned entry should be written immediately.
+    /// * The lifetime of key lives longer than the hashtable.
+    #[inline(always)]
+    pub unsafe fn insert_borrowing(&mut self, key: &K) -> Result<&mut MaybeUninit<V>, &mut V> {
+        match self.insert_and_entry_borrowing(key) {
+            Ok(e) => Ok(&mut *(e.get_mut_ptr() as *mut MaybeUninit<V>)),
+            Err(e) => Err(&mut *e.get_mut_ptr()),
+        }
+    }
+}
+
+pub struct SimpleUnsizedHashtableIter<'a, K, V>
+where K: UnsizedKeyable + ?Sized
+{
+    it_empty: Option<TableEmptyIter<'a, V>>,
+    it: Option<Table0Iter<'a, FallbackKey, V>>,
+    _phantom: PhantomData<&'a mut K>,
+}
+
+impl<'a, K, V> Iterator for SimpleUnsizedHashtableIter<'a, K, V>
+where K: UnsizedKeyable + ?Sized
+{
+    type Item = SimpleUnsizedHashtableEntryRef<'a, K, V>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if let Some(it) = self.it_empty.as_mut() {
+            if let Some(e) = it.next() {
+                return Some(SimpleUnsizedHashtableEntryRef(
+                    SimpleUnsizedHashtableEntryRefInner::TableEmpty(e, PhantomData),
+                ));
+            }
+            self.it_empty = None;
+        }
+        if let Some(it) = self.it.as_mut() {
+            if let Some(e) = it.next() {
+                return Some(SimpleUnsizedHashtableEntryRef(
+                    SimpleUnsizedHashtableEntryRefInner::Table(e),
+                ));
+            }
+            self.it = None;
+        }
+        None
+    }
+}
+
+pub struct SimpleUnsizedHashtableIterMut<'a, K, V>
+where K: UnsizedKeyable + ?Sized
+{
+    it_empty: Option<TableEmptyIterMut<'a, V>>,
+    it: Option<Table0IterMut<'a, FallbackKey, V>>,
+    _phantom: PhantomData<&'a mut K>,
+}
+
+impl<'a, K, V> Iterator for SimpleUnsizedHashtableIterMut<'a, K, V>
+where K: UnsizedKeyable + ?Sized
+{
+    type Item = SimpleUnsizedHashtableEntryMutRef<'a, K, V>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if let Some(it) = self.it_empty.as_mut() {
+            if let Some(e) = it.next() {
+                return Some(SimpleUnsizedHashtableEntryMutRef(
+                    SimpleUnsizedHashtableEntryMutRefInner::TableEmpty(e, PhantomData),
+                ));
+            }
+            self.it_empty = None;
+        }
+
+        if let Some(it) = self.it.as_mut() {
+            if let Some(e) = it.next() {
+                return Some(SimpleUnsizedHashtableEntryMutRef(
+                    SimpleUnsizedHashtableEntryMutRefInner::Table(e),
+                ));
+            }
+            self.it = None;
+        }
+        None
+    }
+}
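Both iterators drain `table_empty` first and only then the main table, setting each `Option` to `None` once exhausted so `next` stays fused. When no custom entry type is involved, the same ordering is just `Iterator::chain`, as in this std-only sketch:

```rust
// The empty-key entry (at most one) is yielded first, then the rest.
fn chained<'a, V>(
    empty: Option<&'a V>,
    rest: impl Iterator<Item = &'a V> + 'a,
) -> impl Iterator<Item = &'a V> + 'a {
    empty.into_iter().chain(rest)
}

fn main() {
    let rest = [1, 2, 3];
    let all: Vec<&i32> = chained(Some(&0), rest.iter()).collect();
    assert_eq!(all, [&0, &1, &2, &3]);
}
```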
+
+enum SimpleUnsizedHashtableEntryRefInner<'a, K: ?Sized, V> {
+    TableEmpty(&'a Entry<[u8; 0], V>, PhantomData<K>),
+    Table(&'a Entry<FallbackKey, V>),
+}
+
+impl<'a, K: ?Sized, V> Copy for SimpleUnsizedHashtableEntryRefInner<'a, K, V> {}
+
+impl<'a, K: ?Sized, V> Clone for SimpleUnsizedHashtableEntryRefInner<'a, K, V> {
+    fn clone(&self) -> Self {
+        use SimpleUnsizedHashtableEntryRefInner::*;
+        match self {
+            TableEmpty(a, b) => TableEmpty(a, *b),
+            Table(a) => Table(a),
+        }
+    }
+}
+
+impl<'a, K: ?Sized + UnsizedKeyable, V> SimpleUnsizedHashtableEntryRefInner<'a, K, V> {
+    fn key(self) -> &'a K {
+        use SimpleUnsizedHashtableEntryRefInner::*;
+        match self {
+            TableEmpty(_, _) => unsafe { UnsizedKeyable::from_bytes(&[]) },
+            Table(e) => unsafe {
+                UnsizedKeyable::from_bytes(e.key.assume_init().key.unwrap().as_ref())
+            },
+        }
+    }
+    fn get(self) -> &'a V {
+        use SimpleUnsizedHashtableEntryRefInner::*;
+        match self {
+            TableEmpty(e, _) => e.get(),
+            Table(e) => e.get(),
+        }
+    }
+    fn get_ptr(self) -> *const V {
+        use SimpleUnsizedHashtableEntryRefInner::*;
+        match self {
+            TableEmpty(e, _) => e.val.as_ptr(),
+            Table(e) => e.val.as_ptr(),
+        }
+    }
+}
+
+pub struct SimpleUnsizedHashtableEntryRef<'a, K: ?Sized, V>(
+    SimpleUnsizedHashtableEntryRefInner<'a, K, V>,
+);
+
+impl<'a, K: ?Sized, V> Copy for SimpleUnsizedHashtableEntryRef<'a, K, V> {}
+
+impl<'a, K: ?Sized, V> Clone for SimpleUnsizedHashtableEntryRef<'a, K, V> {
+    fn clone(&self) -> Self {
+        Self(self.0)
+    }
+}
+
+impl<'a, K: ?Sized + UnsizedKeyable, V> SimpleUnsizedHashtableEntryRef<'a, K, V> {
+    pub fn key(self) -> &'a K {
+        self.0.key()
+    }
+    pub fn get(self) -> &'a V {
+        self.0.get()
+    }
+    pub fn get_ptr(self) -> *const V {
+        self.0.get_ptr()
+    }
+}
+
+enum SimpleUnsizedHashtableEntryMutRefInner<'a, K: ?Sized, V> {
+    TableEmpty(&'a mut Entry<[u8; 0], V>, PhantomData<K>),
+    Table(&'a mut Entry<FallbackKey, V>),
+}
+
+impl<'a, K: ?Sized + UnsizedKeyable, V> SimpleUnsizedHashtableEntryMutRefInner<'a, K, V> {
+    fn key(&self) -> &'a K {
+        use SimpleUnsizedHashtableEntryMutRefInner::*;
+        match self {
+            TableEmpty(_, _) => unsafe { &*(UnsizedKeyable::from_bytes(&[]) as *const K) },
+            Table(e) => unsafe {
+                UnsizedKeyable::from_bytes(e.key.assume_init().key.unwrap().as_ref())
+            },
+        }
+    }
+    fn get(&self) -> &V {
+        use SimpleUnsizedHashtableEntryMutRefInner::*;
+        match self {
+            TableEmpty(e, _) => e.get(),
+            Table(e) => e.get(),
+        }
+    }
+    fn get_ptr(&self) -> *const V {
+        use SimpleUnsizedHashtableEntryMutRefInner::*;
+        match self {
+            TableEmpty(e, _) => e.val.as_ptr(),
+            Table(e) => e.val.as_ptr(),
+        }
+    }
+    fn get_mut(&mut self) -> &mut V {
+        use SimpleUnsizedHashtableEntryMutRefInner::*;
+        match self {
+            TableEmpty(e, _) => e.get_mut(),
+            Table(e) => e.get_mut(),
+        }
+    }
+    fn write(&mut self, val: V) {
+        use SimpleUnsizedHashtableEntryMutRefInner::*;
+        match self {
+            TableEmpty(e, _) => e.write(val),
+            Table(e) => e.write(val),
+        }
+    }
+}
+
+pub struct SimpleUnsizedHashtableEntryMutRef<'a, K: ?Sized, V>(
+    SimpleUnsizedHashtableEntryMutRefInner<'a, K, V>,
+);
+
+impl<'a, K: ?Sized + UnsizedKeyable, V> SimpleUnsizedHashtableEntryMutRef<'a, K, V> {
+    pub fn key(&self) -> &'a K {
+        self.0.key()
+    }
+    pub fn get(&self) -> &V {
+        self.0.get()
+    }
+    pub fn get_ptr(&self) -> *const V {
+        self.0.get_ptr()
+    }
+    pub fn get_mut_ptr(&self) -> *mut V {
+        self.get_ptr() as *mut V
+    }
+    pub fn get_mut(&mut self) -> &mut V {
+        self.0.get_mut()
+    }
+    pub fn write(&mut self, val: V) {
+        self.0.write(val)
+    }
+}
+
+impl<'a, K: UnsizedKeyable + ?Sized + 'a, V: 'a> EntryRefLike
+    for SimpleUnsizedHashtableEntryRef<'a, K, V>
+{
+    type KeyRef = &'a K;
+    type ValueRef = &'a V;
+
+    fn key(&self) -> Self::KeyRef {
+        (*self).key()
+    }
+    fn get(&self) -> Self::ValueRef {
+        (*self).get()
+    }
+}
+
+impl<'a, K: UnsizedKeyable + ?Sized + 'a, V: 'a> EntryMutRefLike
+    for SimpleUnsizedHashtableEntryMutRef<'a, K, V>
+{
+    type Key = K;
+    type Value = V;
+
+    fn key(&self) -> &Self::Key {
+        self.key()
+    }
+
+    fn get(&self) -> &Self::Value {
+        self.get()
+    }
+
+    fn get_mut(&mut self) -> &mut Self::Value {
+        self.get_mut()
+    }
+
+    fn write(&mut self, value: Self::Value) {
+        self.write(value);
+    }
+}
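The `Copy`able entry references above are thin enum wrappers, so one public handle can point into either sub-table while hiding which one owns the entry. A reduced sketch of the same dispatch (a hypothetical `SlotRef`, not the patch's type):

```rust
#[derive(Clone, Copy)]
enum SlotRef<'a, V> {
    Empty(&'a V),          // the single zero-length-key slot
    Main(&'a [u8], &'a V), // an entry of the main table
}

impl<'a, V> SlotRef<'a, V> {
    fn key(self) -> &'a [u8] {
        match self {
            SlotRef::Empty(_) => &[], // the empty table's only key is b""
            SlotRef::Main(k, _) => k,
        }
    }

    fn get(self) -> &'a V {
        match self {
            SlotRef::Empty(v) | SlotRef::Main(_, v) => v,
        }
    }
}
```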
+
+impl<V, A> HashtableLike for SimpleUnsizedHashtable<[u8], V, A>
+where A: Allocator + Clone + Default
+{
+    type Key = [u8];
+    type Value = V;
+
+    type EntryRef<'a> = SimpleUnsizedHashtableEntryRef<'a, [u8], V> where Self: 'a, V: 'a;
+    type EntryMutRef<'a> = SimpleUnsizedHashtableEntryMutRef<'a, [u8], V> where Self: 'a, V: 'a;
+
+    type Iterator<'a> = SimpleUnsizedHashtableIter<'a, [u8], V> where Self: 'a, V: 'a;
+    type IteratorMut<'a> = SimpleUnsizedHashtableIterMut<'a, [u8], V> where Self: 'a, V: 'a;
+
+    fn len(&self) -> usize {
+        self.len()
+    }
+
+    fn bytes_len(&self) -> usize {
+        std::mem::size_of::<Self>()
+            + self.arena.allocated_bytes()
+            + self.table_empty.heap_bytes()
+            + self.table.heap_bytes()
+    }
+
+    fn unsize_key_size(&self) -> Option<usize> {
+        Some(self.key_size)
+    }
+
+    fn entry(&self, key: &Self::Key) -> Option<Self::EntryRef<'_>> {
+        let key = key.as_bytes();
+        match key.len() {
+            0 => self.table_empty.get().map(|x| {
+                SimpleUnsizedHashtableEntryRef(SimpleUnsizedHashtableEntryRefInner::TableEmpty(
+                    x,
+                    PhantomData,
+                ))
+            }),
+            _ => unsafe {
+                self.table.get(&FallbackKey::new(key)).map(|x| {
+                    SimpleUnsizedHashtableEntryRef(SimpleUnsizedHashtableEntryRefInner::Table(x))
+                })
+            },
+        }
+    }
+
+    fn entry_mut(&mut self, key: &[u8]) -> Option<Self::EntryMutRef<'_>> {
+        match key.len() {
+            0 => self.table_empty.get_mut().map(|x| {
+                SimpleUnsizedHashtableEntryMutRef(
+                    SimpleUnsizedHashtableEntryMutRefInner::TableEmpty(x, PhantomData),
+                )
+            }),
+            _ => unsafe {
+                self.table.get_mut(&FallbackKey::new(key)).map(|x| {
+                    SimpleUnsizedHashtableEntryMutRef(
+                        SimpleUnsizedHashtableEntryMutRefInner::Table(x),
+                    )
+                })
+            },
+        }
+    }
+
+    fn get(&self, key: &Self::Key) -> Option<&Self::Value> {
+        self.entry(key).map(|e| e.get())
+    }
+
+    fn get_mut(&mut self, key: &Self::Key) -> Option<&mut Self::Value> {
+        self.entry_mut(key)
+            .map(|e| unsafe { &mut *(e.get_mut_ptr() as *mut V) })
+    }
+
+    unsafe fn insert(
+        &mut self,
+        key: &Self::Key,
+    ) -> Result<&mut MaybeUninit<Self::Value>, &mut Self::Value> {
+        match self.insert_and_entry(key) {
+            Ok(e) => Ok(&mut *(e.get_mut_ptr() as *mut MaybeUninit<V>)),
+            Err(e) => Err(&mut *e.get_mut_ptr()),
+        }
+    }
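`insert_and_entry` below is where owning inserts differ from the borrowing path: the key bytes are copied into the table's `Bump` arena and the entry is re-pointed at the copy, so the stored `FallbackKey` stays valid after the caller's buffer goes away. The idiom in isolation (a sketch assuming the `bumpalo` crate, which the file already imports):

```rust
use bumpalo::Bump;

// Copy transient bytes into the arena; the returned slice lives as long as
// the arena, which the hashtable owns.
fn intern<'a>(arena: &'a Bump, key: &[u8]) -> &'a [u8] {
    arena.alloc_slice_copy(key)
}

fn main() {
    let arena = Bump::new();
    let transient = vec![1u8, 2, 3];
    let stable = intern(&arena, &transient);
    drop(transient); // the arena copy is unaffected
    assert_eq!(stable, &[1, 2, 3]);
}
```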
+    #[inline(always)]
+    unsafe fn insert_and_entry(
+        &mut self,
+        key: &Self::Key,
+    ) -> Result<Self::EntryMutRef<'_>, Self::EntryMutRef<'_>> {
+        let key = key.as_bytes();
+        match key.len() {
+            0 => self
+                .table_empty
+                .insert()
+                .map(|x| {
+                    SimpleUnsizedHashtableEntryMutRef(
+                        SimpleUnsizedHashtableEntryMutRefInner::TableEmpty(x, PhantomData),
+                    )
+                })
+                .map_err(|x| {
+                    SimpleUnsizedHashtableEntryMutRef(
+                        SimpleUnsizedHashtableEntryMutRefInner::TableEmpty(x, PhantomData),
+                    )
+                }),
+
+            _ => {
+                self.table.check_grow();
+                match self.table.insert(FallbackKey::new(key)) {
+                    Ok(e) => {
+                        // Save the key into the arena so it is not dropped
+                        // while the table still points to it.
+                        let s = self.arena.alloc_slice_copy(key);
+                        e.set_key(FallbackKey::new_with_hash(s, e.key.assume_init_ref().hash));
+
+                        self.key_size += key.len();
+                        Ok(SimpleUnsizedHashtableEntryMutRef(
+                            SimpleUnsizedHashtableEntryMutRefInner::Table(e),
+                        ))
+                    }
+                    Err(e) => Err(SimpleUnsizedHashtableEntryMutRef(
+                        SimpleUnsizedHashtableEntryMutRefInner::Table(e),
+                    )),
+                }
+            }
+        }
+    }
+
+    #[inline(always)]
+    unsafe fn insert_and_entry_with_hash(
+        &mut self,
+        key: &Self::Key,
+        hash: u64,
+    ) -> Result<Self::EntryMutRef<'_>, Self::EntryMutRef<'_>> {
+        let key = key.as_bytes();
+        match key.len() {
+            0 => self
+                .table_empty
+                .insert()
+                .map(|x| {
+                    self.key_size += key.len();
+                    SimpleUnsizedHashtableEntryMutRef(
+                        SimpleUnsizedHashtableEntryMutRefInner::TableEmpty(x, PhantomData),
+                    )
+                })
+                .map_err(|x| {
+                    SimpleUnsizedHashtableEntryMutRef(
+                        SimpleUnsizedHashtableEntryMutRefInner::TableEmpty(x, PhantomData),
+                    )
+                }),
+            _ => {
+                self.table.check_grow();
+                match self
+                    .table
+                    .insert_with_hash(FallbackKey::new_with_hash(key, hash), hash)
+                {
+                    Ok(e) => {
+                        // Save the key into the arena so it is not dropped
+                        // while the table still points to it.
+                        let s = self.arena.alloc_slice_copy(key);
+                        e.set_key(FallbackKey::new_with_hash(s, hash));
+
+                        self.key_size += key.len();
+                        Ok(SimpleUnsizedHashtableEntryMutRef(
+                            SimpleUnsizedHashtableEntryMutRefInner::Table(e),
+                        ))
+                    }
+                    Err(e) => Err(SimpleUnsizedHashtableEntryMutRef(
+                        SimpleUnsizedHashtableEntryMutRefInner::Table(e),
+                    )),
+                }
+            }
+        }
+    }
+
+    fn iter(&self) -> Self::Iterator<'_> {
+        SimpleUnsizedHashtableIter {
+            it_empty: Some(self.table_empty.iter()),
+            it: Some(self.table.iter()),
+            _phantom: PhantomData,
+        }
+    }
+
+    fn clear(&mut self) {
+        self.table_empty.clear();
+        self.table.clear();
+        drop(std::mem::take(&mut self.arena));
+    }
+}
diff --git a/src/common/hashtable/src/unsized_hashtable.rs b/src/common/hashtable/src/unsized_hashtable.rs
index 897611b678810..5fc79484d30cf 100644
--- a/src/common/hashtable/src/unsized_hashtable.rs
+++ b/src/common/hashtable/src/unsized_hashtable.rs
@@ -642,19 +642,19 @@ unsafe impl<const N: usize> Keyable for InlineKey<N> {
 #[derive(Copy, Clone)]
 pub(crate) struct FallbackKey {
-    key: Option<NonNull<[u8]>>,
-    hash: u64,
+    pub(crate) key: Option<NonNull<[u8]>>,
+    pub(crate) hash: u64,
 }
 
 impl FallbackKey {
-    unsafe fn new(key: &[u8]) -> Self {
+    pub(crate) unsafe fn new(key: &[u8]) -> Self {
         Self {
             key: Some(NonNull::from(key)),
             hash: key.fast_hash(),
         }
     }
 
-    unsafe fn new_with_hash(key: &[u8], hash: u64) -> Self {
+    pub(crate) unsafe fn new_with_hash(key: &[u8], hash: u64) -> Self {
         Self {
             hash,
             key: Some(NonNull::from(key)),
diff --git a/src/query/expression/src/kernels/group_by.rs b/src/query/expression/src/kernels/group_by.rs
index 11d47163c7d6b..1f6ef3ce9ec1c 100644
--- a/src/query/expression/src/kernels/group_by.rs
+++ b/src/query/expression/src/kernels/group_by.rs
@@ -20,6 +20,7 @@ use super::group_by_hash::HashMethodKeysU64;
 use super::group_by_hash::HashMethodKeysU8;
 use super::group_by_hash::HashMethodKind;
 use super::group_by_hash::HashMethodSerializer;
+use super::group_by_hash::HashMethodSingleString;
 use crate::types::DataType;
 use crate::DataBlock;
 use crate::HashMethodKeysU128;
@@ -41,6 +42,15 @@ impl DataBlock {
     }
 
     pub fn choose_hash_method_with_types(hash_key_types: &[DataType]) -> Result<HashMethodKind> {
+        if hash_key_types.len() == 1 {
+            let typ = hash_key_types[0].clone();
+            if matches!(typ, DataType::String | DataType::Variant) {
+                return Ok(HashMethodKind::SingleString(
+                    HashMethodSingleString::default(),
+                ));
+            }
+        }
+
         let mut group_key_len = 0;
         for typ in hash_key_types {
             let not_null_type =
typ.remove_nullable(); diff --git a/src/query/expression/src/kernels/group_by_hash.rs b/src/query/expression/src/kernels/group_by_hash.rs index 4f909f3412430..1993faf371c06 100644 --- a/src/query/expression/src/kernels/group_by_hash.rs +++ b/src/query/expression/src/kernels/group_by_hash.rs @@ -81,6 +81,7 @@ pub type HashMethodKeysU512 = HashMethodFixedKeys; #[derive(Clone, Debug)] pub enum HashMethodKind { Serializer(HashMethodSerializer), + SingleString(HashMethodSingleString), KeysU8(HashMethodKeysU8), KeysU16(HashMethodKeysU16), KeysU32(HashMethodKeysU32), @@ -90,22 +91,48 @@ pub enum HashMethodKind { KeysU512(HashMethodKeysU512), } +#[macro_export] +macro_rules! with_hash_method { + ( | $t:tt | $($tail:tt)* ) => { + match_template::match_template! { + $t = [Serializer, SingleString, KeysU8, KeysU16, + KeysU32, KeysU64, KeysU128, KeysU256, KeysU512], + $($tail)* + } + } +} + +#[macro_export] +macro_rules! with_mappedhash_method { + ( | $t:tt | $($tail:tt)* ) => { + match_template::match_template! { + $t = [ + Serializer => HashMethodSerializer, + SingleString => HashMethodSingleString, + KeysU8 => HashMethodKeysU8, + KeysU16 => HashMethodKeysU16, + KeysU32 => HashMethodKeysU32, + KeysU64 => HashMethodKeysU64, + KeysU128 => HashMethodKeysU128, + KeysU256 => HashMethodKeysU256, + KeysU512 => HashMethodKeysU512 + ], + $($tail)* + } + } +} + impl HashMethodKind { pub fn name(&self) -> String { - match self { - HashMethodKind::Serializer(v) => v.name(), - HashMethodKind::KeysU8(v) => v.name(), - HashMethodKind::KeysU16(v) => v.name(), - HashMethodKind::KeysU32(v) => v.name(), - HashMethodKind::KeysU64(v) => v.name(), - HashMethodKind::KeysU128(v) => v.name(), - HashMethodKind::KeysU256(v) => v.name(), - HashMethodKind::KeysU512(v) => v.name(), - } + with_hash_method!(|T| match self { + HashMethodKind::T(v) => v.name(), + }) } + pub fn data_type(&self) -> DataType { match self { HashMethodKind::Serializer(_) => DataType::String, + HashMethodKind::SingleString(_) => DataType::String, HashMethodKind::KeysU8(_) => DataType::Number(NumberDataType::UInt8), HashMethodKind::KeysU16(_) => DataType::Number(NumberDataType::UInt16), HashMethodKind::KeysU32(_) => DataType::Number(NumberDataType::UInt32), @@ -117,6 +144,34 @@ impl HashMethodKind { } } +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct HashMethodSingleString {} + +impl HashMethod for HashMethodSingleString { + type HashKey = [u8]; + + type HashKeyIter<'a> = StringIterator<'a>; + + fn name(&self) -> String { + "SingleString".to_string() + } + + fn build_keys_state( + &self, + group_columns: &[(Column, DataType)], + _rows: usize, + ) -> Result { + Ok(KeysState::Column(group_columns[0].0.clone())) + } + + fn build_keys_iter<'a>(&self, key_state: &'a KeysState) -> Result> { + match key_state { + KeysState::Column(Column::String(col)) => Ok(col.iter()), + _ => unreachable!(), + } + } +} + #[derive(Debug, Clone, Default, PartialEq, Eq)] pub struct HashMethodSerializer {} @@ -134,10 +189,6 @@ impl HashMethod for HashMethodSerializer { group_columns: &[(Column, DataType)], rows: usize, ) -> Result { - if group_columns.len() == 1 && group_columns[0].1.is_string() { - return Ok(KeysState::Column(group_columns[0].0.clone())); - } - let approx_size = group_columns.len() * rows * 8; let mut builder = StringColumnBuilder::with_capacity(rows, approx_size); diff --git a/src/query/functions/src/scalars/comparison.rs b/src/query/functions/src/scalars/comparison.rs index 8df7e8f3c768d..717acd49f95c8 100644 --- 
a/src/query/functions/src/scalars/comparison.rs +++ b/src/query/functions/src/scalars/comparison.rs @@ -609,8 +609,17 @@ fn vectorize_like( } (ValueRef::Column(arg1), ValueRef::Scalar(arg2)) => { let arg1_iter = StringType::iter_column(&arg1); - let mut builder = MutableBitmap::with_capacity(arg1.len()); + let pattern_type = check_pattern_type(arg2, false); + // faster path for memmem to have a single instance of Finder + if pattern_type == PatternType::SurroundByPercent && arg2.len() > 2 { + let finder = memmem::Finder::new(&arg2[1..arg2.len() - 1]); + let it = arg1_iter.map(|arg1| finder.find(arg1).is_some()); + let bitmap = BooleanType::column_from_iter(it, &[]); + return Value::Column(bitmap); + } + + let mut builder = MutableBitmap::with_capacity(arg1.len()); for arg1 in arg1_iter { builder.push(func(arg1, arg2, ctx, pattern_type)); } diff --git a/src/query/service/src/pipelines/pipeline_builder.rs b/src/query/service/src/pipelines/pipeline_builder.rs index 8d5a3c3058d41..3fea8b1153bab 100644 --- a/src/query/service/src/pipelines/pipeline_builder.rs +++ b/src/query/service/src/pipelines/pipeline_builder.rs @@ -323,6 +323,7 @@ impl PipelineBuilder { // aggregate.output_schema()?, &aggregate.group_by, &aggregate.agg_funcs, + None, )?; self.main_pipeline.add_transform(|input, output| { @@ -343,6 +344,7 @@ impl PipelineBuilder { // aggregate.output_schema()?, &aggregate.group_by, &aggregate.agg_funcs, + aggregate.limit, )?; if self.ctx.get_cluster().is_empty() @@ -367,6 +369,7 @@ impl PipelineBuilder { input_schema: DataSchemaRef, group_by: &[IndexType], agg_funcs: &[AggregateFunctionDesc], + limit: Option, ) -> Result> { let mut agg_args = Vec::with_capacity(agg_funcs.len()); let (group_by, group_data_types) = group_by @@ -403,6 +406,7 @@ impl PipelineBuilder { &group_by, &aggs, &agg_args, + limit, )?; Ok(params) diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/aggregator_final.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/aggregator_final.rs index cfc0677b0a862..ea0ed57e60edf 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/aggregator_final.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/aggregator_final.rs @@ -11,32 +11,3 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
- -use common_expression::HashMethodKeysU128; -use common_expression::HashMethodKeysU16; -use common_expression::HashMethodKeysU256; -use common_expression::HashMethodKeysU32; -use common_expression::HashMethodKeysU512; -use common_expression::HashMethodKeysU64; -use common_expression::HashMethodKeysU8; -use common_expression::HashMethodSerializer; - -use crate::pipelines::processors::transforms::aggregator::aggregator_final_parallel::ParallelFinalAggregator; - -pub type KeysU8FinalAggregator = - ParallelFinalAggregator; -pub type KeysU16FinalAggregator = - ParallelFinalAggregator; -pub type KeysU32FinalAggregator = - ParallelFinalAggregator; -pub type KeysU64FinalAggregator = - ParallelFinalAggregator; -pub type KeysU128FinalAggregator = - ParallelFinalAggregator; -pub type KeysU256FinalAggregator = - ParallelFinalAggregator; -pub type KeysU512FinalAggregator = - ParallelFinalAggregator; - -pub type SerializerFinalAggregator = - ParallelFinalAggregator; diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/aggregator_final_parallel.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/aggregator_final_parallel.rs index 453673561a304..115f457811e6e 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/aggregator_final_parallel.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/aggregator_final_parallel.rs @@ -26,7 +26,6 @@ use common_expression::ColumnBuilder; use common_expression::DataBlock; use common_expression::HashMethod; use common_functions::aggregates::StateAddr; -use common_functions::aggregates::StateAddrs; use common_hashtable::HashtableEntryMutRefLike; use common_hashtable::HashtableEntryRefLike; use common_hashtable::HashtableLike; @@ -144,6 +143,7 @@ where Method: HashMethod + PolymorphicKeysHelper + Send + 'static params: Arc, hash_table: Method::HashTable, + reach_limit: bool, // used for deserialization only, so we can reuse it during the loop temp_place: Option, } @@ -156,7 +156,7 @@ where Method: HashMethod + PolymorphicKeysHelper + Send + 'static let hash_table = method.create_hash_table()?; let temp_place = match params.aggregate_functions.is_empty() { true => None, - false => params.alloc_layout(&mut area), + false => Some(params.alloc_layout(&mut area)), }; Ok(Self { @@ -164,6 +164,7 @@ where Method: HashMethod + PolymorphicKeysHelper + Send + 'static method, params, hash_table, + reach_limit: false, temp_place, }) } @@ -188,6 +189,12 @@ where Method: HashMethod + PolymorphicKeysHelper + Send + 'static for key in keys_iter.iter() { let _ = self.hash_table.insert_and_entry(key); } + + if let Some(limit) = self.params.limit { + if self.hash_table.len() >= limit { + break; + } + } } } else { // first state places of current block @@ -212,13 +219,13 @@ where Method: HashMethod + PolymorphicKeysHelper + Send + 'static let aggregate_functions = &self.params.aggregate_functions; let offsets_aggregate_states = &self.params.offsets_aggregate_states; if let Some(temp_place) = self.temp_place { - for (row, place) in places.iter().enumerate() { + for (row, place) in places.iter() { for (idx, aggregate_function) in aggregate_functions.iter().enumerate() { let final_place = place.next(offsets_aggregate_states[idx]); let state_place = temp_place.next(offsets_aggregate_states[idx]); let mut data = - unsafe { states_binary_columns[idx].index_unchecked(row) }; + unsafe { states_binary_columns[idx].index_unchecked(*row) }; aggregate_function.deserialize(state_place, &mut data)?; 
aggregate_function.merge(final_place, state_place)?; } @@ -281,29 +288,49 @@ where Method: HashMethod + PolymorphicKeysHelper + Send + 'static let group_columns = group_columns_builder.finish()?; columns.extend_from_slice(&group_columns); + Ok(vec![DataBlock::new_from_columns(columns)]) } } /// Allocate aggregation function state for each key(the same key can always get the same state) #[inline(always)] - fn lookup_state(&mut self, keys_iter: &Method::KeysColumnIter) -> StateAddrs { + fn lookup_state(&mut self, keys_iter: &Method::KeysColumnIter) -> Vec<(usize, StateAddr)> { let iter = keys_iter.iter(); let (len, _) = iter.size_hint(); let mut places = Vec::with_capacity(len); + let mut current_len = self.hash_table.len(); unsafe { - for key in iter { + for (row, key) in iter.enumerate() { + if self.reach_limit { + let entry = self.hash_table.entry(key); + match entry { + Some(entry) => { + let place = Into::::into(*entry.get()); + places.push((row, place)); + } + None => continue, + } + } + match self.hash_table.insert_and_entry(key) { Ok(mut entry) => { - if let Some(place) = self.params.alloc_layout(&mut self.area) { - places.push(place); - *entry.get_mut() = place.addr(); + let place = self.params.alloc_layout(&mut self.area); + places.push((row, place)); + + *entry.get_mut() = place.addr(); + + if let Some(limit) = self.params.limit { + current_len += 1; + if current_len >= limit { + self.reach_limit = true; + } } } Err(entry) => { let place = Into::::into(*entry.get()); - places.push(place); + places.push((row, place)); } } } diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/aggregator_params.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/aggregator_params.rs index 9f3c41b7f30c1..98b8cf7bd3583 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/aggregator_params.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/aggregator_params.rs @@ -41,6 +41,9 @@ pub struct AggregatorParams { // If there is no aggregate function, layout is None pub layout: Option, pub offsets_aggregate_states: Vec, + + // Limit is push down to AggregatorTransform + pub limit: Option, } impl AggregatorParams { @@ -50,6 +53,7 @@ impl AggregatorParams { group_columns: &[usize], agg_funcs: &[AggregateFunctionRef], agg_args: &[Vec], + limit: Option, ) -> Result> { let mut states_offsets: Vec = Vec::with_capacity(agg_funcs.len()); let mut states_layout = None; @@ -66,10 +70,11 @@ impl AggregatorParams { aggregate_functions_arguments: agg_args.to_vec(), layout: states_layout, offsets_aggregate_states: states_offsets, + limit, })) } - pub fn alloc_layout(&self, area: &mut Area) -> Option { + pub fn alloc_layout(&self, area: &mut Area) -> StateAddr { let layout = self.layout.unwrap(); let place = Into::::into(area.alloc_layout(layout)); @@ -78,7 +83,7 @@ impl AggregatorParams { let aggr_state_place = place.next(aggr_state); self.aggregate_functions[idx].init_state(aggr_state_place); } - Some(place) + place } pub fn has_distinct_combinator(&self) -> bool { diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/aggregator_partial.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/aggregator_partial.rs index 6fbaacb50d006..6b877f9b4ba94 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/aggregator_partial.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/aggregator_partial.rs @@ -20,14 +20,6 @@ use common_expression::BlockEntry; use 
common_expression::Column; use common_expression::DataBlock; use common_expression::HashMethod; -use common_expression::HashMethodKeysU128; -use common_expression::HashMethodKeysU16; -use common_expression::HashMethodKeysU256; -use common_expression::HashMethodKeysU32; -use common_expression::HashMethodKeysU512; -use common_expression::HashMethodKeysU64; -use common_expression::HashMethodKeysU8; -use common_expression::HashMethodSerializer; use common_functions::aggregates::StateAddr; use common_functions::aggregates::StateAddrs; use common_hashtable::HashtableEntryMutRefLike; @@ -41,24 +33,6 @@ use crate::pipelines::processors::transforms::group_by::PolymorphicKeysHelper; use crate::pipelines::processors::transforms::transform_aggregator::Aggregator; use crate::pipelines::processors::AggregatorParams; -pub type Keys8Grouper = PartialAggregator; -pub type Keys16Grouper = PartialAggregator; -pub type Keys32Grouper = PartialAggregator; -pub type Keys64Grouper = PartialAggregator; -pub type Keys128Grouper = PartialAggregator; -pub type Keys256Grouper = PartialAggregator; -pub type Keys512Grouper = PartialAggregator; -pub type KeysSerializerGrouper = PartialAggregator; - -pub type Keys8Aggregator = PartialAggregator; -pub type Keys16Aggregator = PartialAggregator; -pub type Keys32Aggregator = PartialAggregator; -pub type Keys64Aggregator = PartialAggregator; -pub type Keys128Aggregator = PartialAggregator; -pub type Keys256Aggregator = PartialAggregator; -pub type Keys512Aggregator = PartialAggregator; -pub type KeysSerializerAggregator = PartialAggregator; - pub struct PartialAggregator where Method: HashMethod + PolymorphicKeysHelper { @@ -107,10 +81,9 @@ impl + S for key in keys_iter { match hashtable.insert_and_entry(key) { Ok(mut entry) => { - if let Some(place) = params.alloc_layout(area) { - places.push(place); - *entry.get_mut() = place.addr(); - } + let place = params.alloc_layout(area); + places.push(place); + *entry.get_mut() = place.addr(); } Err(entry) => { let place = Into::::into(*entry.get()); diff --git a/src/query/service/src/pipelines/processors/transforms/aggregator/mod.rs b/src/query/service/src/pipelines/processors/transforms/aggregator/mod.rs index af10beaca56f7..0a74b193b18df 100644 --- a/src/query/service/src/pipelines/processors/transforms/aggregator/mod.rs +++ b/src/query/service/src/pipelines/processors/transforms/aggregator/mod.rs @@ -23,33 +23,10 @@ mod utils; pub use aggregate_info::AggregateInfo; pub use aggregate_info::OverflowInfo; -pub use aggregator_final::KeysU128FinalAggregator; -pub use aggregator_final::KeysU16FinalAggregator; -pub use aggregator_final::KeysU256FinalAggregator; -pub use aggregator_final::KeysU32FinalAggregator; -pub use aggregator_final::KeysU512FinalAggregator; -pub use aggregator_final::KeysU64FinalAggregator; -pub use aggregator_final::KeysU8FinalAggregator; -pub use aggregator_final::SerializerFinalAggregator; pub use aggregator_final_parallel::BucketAggregator; +pub use aggregator_final_parallel::ParallelFinalAggregator; pub use aggregator_params::AggregatorParams; pub use aggregator_params::AggregatorTransformParams; -pub use aggregator_partial::Keys128Aggregator; -pub use aggregator_partial::Keys128Grouper; -pub use aggregator_partial::Keys16Aggregator; -pub use aggregator_partial::Keys16Grouper; -pub use aggregator_partial::Keys256Aggregator; -pub use aggregator_partial::Keys256Grouper; -pub use aggregator_partial::Keys32Aggregator; -pub use aggregator_partial::Keys32Grouper; -pub use aggregator_partial::Keys512Aggregator; -pub 
use aggregator_partial::Keys512Grouper; -pub use aggregator_partial::Keys64Aggregator; -pub use aggregator_partial::Keys64Grouper; -pub use aggregator_partial::Keys8Aggregator; -pub use aggregator_partial::Keys8Grouper; -pub use aggregator_partial::KeysSerializerAggregator; -pub use aggregator_partial::KeysSerializerGrouper; pub use aggregator_partial::PartialAggregator; pub use aggregator_single_key::FinalSingleStateAggregator; pub use aggregator_single_key::PartialSingleStateAggregator; diff --git a/src/query/service/src/pipelines/processors/transforms/group_by/aggregator_polymorphic_keys.rs b/src/query/service/src/pipelines/processors/transforms/group_by/aggregator_polymorphic_keys.rs index 41af91c692c7c..f51a6acf6232e 100644 --- a/src/query/service/src/pipelines/processors/transforms/group_by/aggregator_polymorphic_keys.rs +++ b/src/query/service/src/pipelines/processors/transforms/group_by/aggregator_polymorphic_keys.rs @@ -27,11 +27,13 @@ use common_expression::HashMethodKeysU128; use common_expression::HashMethodKeysU256; use common_expression::HashMethodKeysU512; use common_expression::HashMethodSerializer; +use common_expression::HashMethodSingleString; use common_expression::KeysState; use common_hashtable::FastHash; use common_hashtable::HashMap; use common_hashtable::HashtableLike; use common_hashtable::LookupHashMap; +use common_hashtable::SimpleUnsizedHashMap; use common_hashtable::TwoLevelHashMap; use common_hashtable::UnsizedHashMap; use primitive_types::U256; @@ -392,7 +394,7 @@ impl PolymorphicKeysHelper for HashMethodKeysU512 { } } -impl PolymorphicKeysHelper for HashMethodSerializer { +impl PolymorphicKeysHelper for HashMethodSingleString { const SUPPORT_TWO_LEVEL: bool = true; type HashTable = UnsizedHashMap<[u8], usize>; @@ -432,6 +434,46 @@ impl PolymorphicKeysHelper for HashMethodSerializer { } } +impl PolymorphicKeysHelper for HashMethodSerializer { + const SUPPORT_TWO_LEVEL: bool = true; + + type HashTable = SimpleUnsizedHashMap<[u8], usize>; + + fn create_hash_table(&self) -> Result { + Ok(SimpleUnsizedHashMap::new()) + } + + type ColumnBuilder<'a> = SerializedKeysColumnBuilder<'a>; + fn keys_column_builder( + &self, + capacity: usize, + value_capacity: usize, + ) -> SerializedKeysColumnBuilder<'_> { + SerializedKeysColumnBuilder::create(capacity, value_capacity) + } + + type KeysColumnIter = SerializedKeysColumnIter; + fn keys_iter_from_column(&self, column: &Column) -> Result { + SerializedKeysColumnIter::create(column.as_string().ok_or_else(|| { + ErrorCode::IllegalDataType("Illegal data type for SerializedKeysColumnIter".to_string()) + })?) 
+ } + + type GroupColumnsBuilder<'a> = SerializedKeysGroupColumnsBuilder<'a>; + fn group_columns_builder( + &self, + capacity: usize, + data_capacity: usize, + params: &AggregatorParams, + ) -> SerializedKeysGroupColumnsBuilder<'_> { + SerializedKeysGroupColumnsBuilder::create(capacity, data_capacity, params) + } + + fn get_hash(&self, v: &[u8]) -> u64 { + v.fast_hash() + } +} + #[derive(Clone)] pub struct TwoLevelHashMethod { method: Method, diff --git a/src/query/service/src/pipelines/processors/transforms/hash_join/hash_join_state_impl.rs b/src/query/service/src/pipelines/processors/transforms/hash_join/hash_join_state_impl.rs index 0753849193106..11c3e408377f7 100644 --- a/src/query/service/src/pipelines/processors/transforms/hash_join/hash_join_state_impl.rs +++ b/src/query/service/src/pipelines/processors/transforms/hash_join/hash_join_state_impl.rs @@ -21,6 +21,7 @@ use common_exception::ErrorCode; use common_exception::Result; use common_expression::DataBlock; use common_expression::HashMethod; +use common_hashtable::HashtableLike; use super::ProbeState; use crate::pipelines::processors::transforms::hash_join::desc::MarkerKind; @@ -122,6 +123,34 @@ impl HashJoinState for JoinHashTable { } }}; } + + macro_rules! insert_string_key { + ($table: expr, $markers: expr, $method: expr, $chunk: expr, $columns: expr, $chunk_index: expr, ) => {{ + let keys_state = $method.build_keys_state(&$columns, $chunk.num_rows())?; + let build_keys_iter = $method.build_keys_iter(&keys_state)?; + + for (row_index, key) in build_keys_iter.enumerate().take($chunk.num_rows()) { + let ptr = RowPtr { + chunk_index: $chunk_index, + row_index, + marker: $markers[row_index], + }; + if self.hash_join_desc.join_type == JoinType::LeftMark { + let mut self_row_ptrs = self.row_ptrs.write(); + self_row_ptrs.push(ptr); + } + match unsafe { $table.insert(key) } { + Ok(entity) => { + entity.write(vec![ptr]); + } + Err(entity) => { + entity.push(ptr); + } + } + } + }}; + } + { let buffer = self.row_space.buffer.write().unwrap(); if !buffer.is_empty() { @@ -163,94 +192,34 @@ impl HashJoinState for JoinHashTable { vec![None; chunk.num_rows()] } }; + match (*self.hash_table.write()).borrow_mut() { - HashTable::SerializerHashTable(table) => { - let mut build_cols_ref = Vec::with_capacity(chunk.cols.len()); - for build_col in chunk.cols.iter() { - build_cols_ref.push(build_col.clone()); - } - let keys_state = table - .hash_method - .build_keys_state(&build_cols_ref, chunk.num_rows())?; - chunk.keys_state = Some(keys_state); - let build_keys_iter = table - .hash_method - .build_keys_iter(chunk.keys_state.as_ref().unwrap())?; - for (row_index, key) in build_keys_iter.enumerate().take(chunk.num_rows()) { - let ptr = RowPtr { - chunk_index, - row_index, - marker: markers[row_index], - }; - if self.hash_join_desc.join_type == JoinType::LeftMark { - let mut self_row_ptrs = self.row_ptrs.write(); - self_row_ptrs.push(ptr); - } - match unsafe { table.hash_table.insert_borrowing(key) } { - Ok(entity) => { - entity.write(vec![ptr]); - } - Err(entity) => { - entity.push(ptr); - } - } - } - } - HashTable::KeyU8HashTable(table) => insert_key! { - &mut table.hash_table, - &markers, - &table.hash_method, - chunk, - columns, - chunk_index, + HashTable::Serializer(table) => insert_string_key! { + &mut table.hash_table, &markers, &table.hash_method,chunk,columns,chunk_index, + }, + HashTable::SingleString(table) => insert_string_key! 
{ + &mut table.hash_table, &markers, &table.hash_method,chunk,columns,chunk_index, + }, + HashTable::KeysU8(table) => insert_key! { + &mut table.hash_table, &markers, &table.hash_method,chunk,columns,chunk_index, }, - HashTable::KeyU16HashTable(table) => insert_key! { - &mut table.hash_table, - &markers, - &table.hash_method, - chunk, - columns, - chunk_index, + HashTable::KeysU16(table) => insert_key! { + &mut table.hash_table, &markers, &table.hash_method,chunk,columns,chunk_index, }, - HashTable::KeyU32HashTable(table) => insert_key! { - &mut table.hash_table, - &markers, - &table.hash_method, - chunk, - columns, - chunk_index, + HashTable::KeysU32(table) => insert_key! { + &mut table.hash_table, &markers, &table.hash_method,chunk,columns,chunk_index, }, - HashTable::KeyU64HashTable(table) => insert_key! { - &mut table.hash_table, - &markers, - &table.hash_method, - chunk, - columns, - chunk_index, + HashTable::KeysU64(table) => insert_key! { + &mut table.hash_table, &markers, &table.hash_method,chunk,columns,chunk_index, }, - HashTable::KeyU128HashTable(table) => insert_key! { - &mut table.hash_table, - &markers, - &table.hash_method, - chunk, - columns, - chunk_index, + HashTable::KeysU128(table) => insert_key! { + &mut table.hash_table, &markers, &table.hash_method,chunk,columns,chunk_index, }, - HashTable::KeyU256HashTable(table) => insert_key! { - &mut table.hash_table, - &markers, - &table.hash_method, - chunk, - columns, - chunk_index, + HashTable::KeysU256(table) => insert_key! { + &mut table.hash_table, &markers, &table.hash_method,chunk,columns,chunk_index, }, - HashTable::KeyU512HashTable(table) => insert_key! { - &mut table.hash_table, - &markers, - &table.hash_method, - chunk, - columns, - chunk_index, + HashTable::KeysU512(table) => insert_key! 
{ + &mut table.hash_table, &markers, &table.hash_method,chunk,columns,chunk_index, }, } } diff --git a/src/query/service/src/pipelines/processors/transforms/hash_join/join_hash_table.rs b/src/query/service/src/pipelines/processors/transforms/hash_join/join_hash_table.rs index fa1f8bd4b8b56..8289a825db336 100644 --- a/src/query/service/src/pipelines/processors/transforms/hash_join/join_hash_table.rs +++ b/src/query/service/src/pipelines/processors/transforms/hash_join/join_hash_table.rs @@ -22,6 +22,7 @@ use common_base::base::tokio::sync::Notify; use common_exception::ErrorCode; use common_exception::Result; use common_expression::arrow::and_validities; +use common_expression::with_hash_method; use common_expression::DataBlock; use common_expression::DataSchemaRef; use common_expression::Evaluator; @@ -29,10 +30,12 @@ use common_expression::HashMethod; use common_expression::HashMethodFixedKeys; use common_expression::HashMethodKind; use common_expression::HashMethodSerializer; +use common_expression::HashMethodSingleString; use common_expression::RemoteExpr; use common_functions::scalars::BUILTIN_FUNCTIONS; use common_hashtable::HashMap; use common_hashtable::HashtableKeyable; +use common_hashtable::SimpleUnsizedHashMap; use common_hashtable::UnsizedHashMap; use common_sql::plans::JoinType; use parking_lot::RwLock; @@ -49,24 +52,30 @@ use crate::sessions::QueryContext; use crate::sessions::TableContext; pub struct SerializerHashTable { - pub(crate) hash_table: UnsizedHashMap<[u8], Vec>, + pub(crate) hash_table: SimpleUnsizedHashMap<[u8], Vec>, pub(crate) hash_method: HashMethodSerializer, } +pub struct SingleStringHashTable { + pub(crate) hash_table: UnsizedHashMap<[u8], Vec>, + pub(crate) hash_method: HashMethodSingleString, +} + pub struct FixedKeyHashTable { pub(crate) hash_table: HashMap>, pub(crate) hash_method: HashMethodFixedKeys, } pub enum HashTable { - SerializerHashTable(SerializerHashTable), - KeyU8HashTable(FixedKeyHashTable), - KeyU16HashTable(FixedKeyHashTable), - KeyU32HashTable(FixedKeyHashTable), - KeyU64HashTable(FixedKeyHashTable), - KeyU128HashTable(FixedKeyHashTable), - KeyU256HashTable(FixedKeyHashTable), - KeyU512HashTable(FixedKeyHashTable), + Serializer(SerializerHashTable), + SingleString(SingleStringHashTable), + KeysU8(FixedKeyHashTable), + KeysU16(FixedKeyHashTable), + KeysU32(FixedKeyHashTable), + KeysU64(FixedKeyHashTable), + KeysU128(FixedKeyHashTable), + KeysU256(FixedKeyHashTable), + KeysU512(FixedKeyHashTable), } pub struct JoinHashTable { @@ -100,17 +109,27 @@ impl JoinHashTable { Ok(match method { HashMethodKind::Serializer(_) => Arc::new(JoinHashTable::try_create( ctx, - HashTable::SerializerHashTable(SerializerHashTable { - hash_table: UnsizedHashMap::<[u8], Vec>::new(), + HashTable::Serializer(SerializerHashTable { + hash_table: SimpleUnsizedHashMap::<[u8], Vec>::new(), hash_method: HashMethodSerializer::default(), }), build_schema, probe_schema, hash_join_desc, )?), + HashMethodKind::SingleString(_) => Arc::new(JoinHashTable::try_create( + ctx, + HashTable::SingleString(SingleStringHashTable { + hash_table: UnsizedHashMap::<[u8], Vec>::new(), + hash_method: HashMethodSingleString::default(), + }), + build_schema, + probe_schema, + hash_join_desc, + )?), HashMethodKind::KeysU8(hash_method) => Arc::new(JoinHashTable::try_create( ctx, - HashTable::KeyU8HashTable(FixedKeyHashTable { + HashTable::KeysU8(FixedKeyHashTable { hash_table: HashMap::>::new(), hash_method, }), @@ -120,7 +139,7 @@ impl JoinHashTable { )?), HashMethodKind::KeysU16(hash_method) 
=> Arc::new(JoinHashTable::try_create( ctx, - HashTable::KeyU16HashTable(FixedKeyHashTable { + HashTable::KeysU16(FixedKeyHashTable { hash_table: HashMap::>::new(), hash_method, }), @@ -130,7 +149,7 @@ impl JoinHashTable { )?), HashMethodKind::KeysU32(hash_method) => Arc::new(JoinHashTable::try_create( ctx, - HashTable::KeyU32HashTable(FixedKeyHashTable { + HashTable::KeysU32(FixedKeyHashTable { hash_table: HashMap::>::new(), hash_method, }), @@ -140,7 +159,7 @@ impl JoinHashTable { )?), HashMethodKind::KeysU64(hash_method) => Arc::new(JoinHashTable::try_create( ctx, - HashTable::KeyU64HashTable(FixedKeyHashTable { + HashTable::KeysU64(FixedKeyHashTable { hash_table: HashMap::>::new(), hash_method, }), @@ -150,7 +169,7 @@ impl JoinHashTable { )?), HashMethodKind::KeysU128(hash_method) => Arc::new(JoinHashTable::try_create( ctx, - HashTable::KeyU128HashTable(FixedKeyHashTable { + HashTable::KeysU128(FixedKeyHashTable { hash_table: HashMap::>::new(), hash_method, }), @@ -160,7 +179,7 @@ impl JoinHashTable { )?), HashMethodKind::KeysU256(hash_method) => Arc::new(JoinHashTable::try_create( ctx, - HashTable::KeyU256HashTable(FixedKeyHashTable { + HashTable::KeysU256(FixedKeyHashTable { hash_table: HashMap::>::new(), hash_method, }), @@ -170,7 +189,7 @@ impl JoinHashTable { )?), HashMethodKind::KeysU512(hash_method) => Arc::new(JoinHashTable::try_create( ctx, - HashTable::KeyU512HashTable(FixedKeyHashTable { + HashTable::KeysU512(FixedKeyHashTable { hash_table: HashMap::>::new(), hash_method, }), @@ -280,64 +299,14 @@ impl JoinHashTable { probe_state.valids = valids; } - match &*self.hash_table.read() { - HashTable::SerializerHashTable(table) => { + with_hash_method!(|T| match &*self.hash_table.read() { + HashTable::T(table) => { let keys_state = table .hash_method .build_keys_state(&probe_keys, input.num_rows())?; let keys_iter = table.hash_method.build_keys_iter(&keys_state)?; - self.result_blocks(&table.hash_table, probe_state, keys_iter, &input) } - HashTable::KeyU8HashTable(table) => { - let keys_state = table - .hash_method - .build_keys_state(&probe_keys, input.num_rows())?; - let keys_iter = table.hash_method.build_keys_iter(&keys_state)?; - self.result_blocks(&table.hash_table, probe_state, keys_iter, &input) - } - HashTable::KeyU16HashTable(table) => { - let keys_state = table - .hash_method - .build_keys_state(&probe_keys, input.num_rows())?; - let keys_iter = table.hash_method.build_keys_iter(&keys_state)?; - self.result_blocks(&table.hash_table, probe_state, keys_iter, &input) - } - HashTable::KeyU32HashTable(table) => { - let keys_state = table - .hash_method - .build_keys_state(&probe_keys, input.num_rows())?; - let keys_iter = table.hash_method.build_keys_iter(&keys_state)?; - self.result_blocks(&table.hash_table, probe_state, keys_iter, &input) - } - HashTable::KeyU64HashTable(table) => { - let keys_state = table - .hash_method - .build_keys_state(&probe_keys, input.num_rows())?; - let keys_iter = table.hash_method.build_keys_iter(&keys_state)?; - self.result_blocks(&table.hash_table, probe_state, keys_iter, &input) - } - HashTable::KeyU128HashTable(table) => { - let keys_state = table - .hash_method - .build_keys_state(&probe_keys, input.num_rows())?; - let keys_iter = table.hash_method.build_keys_iter(&keys_state)?; - self.result_blocks(&table.hash_table, probe_state, keys_iter, &input) - } - HashTable::KeyU256HashTable(table) => { - let keys_state = table - .hash_method - .build_keys_state(&probe_keys, input.num_rows())?; - let keys_iter = 
table.hash_method.build_keys_iter(&keys_state)?; - self.result_blocks(&table.hash_table, probe_state, keys_iter, &input) - } - HashTable::KeyU512HashTable(table) => { - let keys_state = table - .hash_method - .build_keys_state(&probe_keys, input.num_rows())?; - let keys_iter = table.hash_method.build_keys_iter(&keys_state)?; - self.result_blocks(&table.hash_table, probe_state, keys_iter, &input) - } - } + }) } } diff --git a/src/query/service/src/pipelines/processors/transforms/transform_aggregator.rs b/src/query/service/src/pipelines/processors/transforms/transform_aggregator.rs index 0a58203bf949f..69cf595b7ba61 100644 --- a/src/query/service/src/pipelines/processors/transforms/transform_aggregator.rs +++ b/src/query/service/src/pipelines/processors/transforms/transform_aggregator.rs @@ -18,8 +18,7 @@ use std::sync::Arc; use common_catalog::table_context::TableContext; use common_exception::ErrorCode; use common_exception::Result; -use common_expression::DataBlock; -use common_expression::HashMethodKind; +use common_expression::*; use crate::pipelines::processors::port::InputPort; use crate::pipelines::processors::port::OutputPort; @@ -49,90 +48,21 @@ impl TransformAggregator { } match aggregator_params.aggregate_functions.is_empty() { - true => match transform_params.method.clone() { - HashMethodKind::KeysU8(method) => AggregatorTransform::create( + true => with_mappedhash_method!(|T| match transform_params.method.clone() { + HashMethodKind::T(method) => AggregatorTransform::create( ctx.clone(), transform_params, - KeysU8FinalAggregator::::create(ctx, method, aggregator_params)?, + ParallelFinalAggregator::::create(ctx, method, aggregator_params)?, ), - HashMethodKind::KeysU16(method) => AggregatorTransform::create( - ctx.clone(), - transform_params, - KeysU16FinalAggregator::::create(ctx, method, aggregator_params)?, - ), - HashMethodKind::KeysU32(method) => AggregatorTransform::create( - ctx.clone(), - transform_params, - KeysU32FinalAggregator::::create(ctx, method, aggregator_params)?, - ), - HashMethodKind::KeysU64(method) => AggregatorTransform::create( - ctx.clone(), - transform_params, - KeysU64FinalAggregator::::create(ctx, method, aggregator_params)?, - ), - HashMethodKind::Serializer(method) => AggregatorTransform::create( - ctx.clone(), - transform_params, - SerializerFinalAggregator::::create(ctx, method, aggregator_params)?, - ), - HashMethodKind::KeysU128(method) => AggregatorTransform::create( - ctx.clone(), - transform_params, - KeysU128FinalAggregator::::create(ctx, method, aggregator_params)?, - ), - HashMethodKind::KeysU256(method) => AggregatorTransform::create( - ctx.clone(), - transform_params, - KeysU256FinalAggregator::::create(ctx, method, aggregator_params)?, - ), - HashMethodKind::KeysU512(method) => AggregatorTransform::create( - ctx.clone(), - transform_params, - KeysU512FinalAggregator::::create(ctx, method, aggregator_params)?, - ), - }, - false => match transform_params.method.clone() { - HashMethodKind::KeysU8(method) => AggregatorTransform::create( - ctx.clone(), - transform_params, - KeysU8FinalAggregator::::create(ctx, method, aggregator_params)?, - ), - HashMethodKind::KeysU16(method) => AggregatorTransform::create( - ctx.clone(), - transform_params, - KeysU16FinalAggregator::::create(ctx, method, aggregator_params)?, - ), - HashMethodKind::KeysU32(method) => AggregatorTransform::create( - ctx.clone(), - transform_params, - KeysU32FinalAggregator::::create(ctx, method, aggregator_params)?, - ), - HashMethodKind::KeysU64(method) => 
AggregatorTransform::create( - ctx.clone(), - transform_params, - KeysU64FinalAggregator::::create(ctx, method, aggregator_params)?, - ), - HashMethodKind::Serializer(method) => AggregatorTransform::create( - ctx.clone(), - transform_params, - SerializerFinalAggregator::::create(ctx, method, aggregator_params)?, - ), - HashMethodKind::KeysU128(method) => AggregatorTransform::create( - ctx.clone(), - transform_params, - KeysU128FinalAggregator::::create(ctx, method, aggregator_params)?, - ), - HashMethodKind::KeysU256(method) => AggregatorTransform::create( - ctx.clone(), - transform_params, - KeysU256FinalAggregator::::create(ctx, method, aggregator_params)?, - ), - HashMethodKind::KeysU512(method) => AggregatorTransform::create( + }), + + false => with_mappedhash_method!(|T| match transform_params.method.clone() { + HashMethodKind::T(method) => AggregatorTransform::create( ctx.clone(), transform_params, - KeysU512FinalAggregator::::create(ctx, method, aggregator_params)?, + ParallelFinalAggregator::::create(ctx, method, aggregator_params)?, ), - }, + }), } } @@ -152,90 +82,20 @@ impl TransformAggregator { } match aggregator_params.aggregate_functions.is_empty() { - true => match transform_params.method.clone() { - HashMethodKind::KeysU8(method) => AggregatorTransform::create( - ctx, - transform_params, - Keys8Grouper::create(method, aggregator_params)?, - ), - HashMethodKind::KeysU16(method) => AggregatorTransform::create( - ctx, - transform_params, - Keys16Grouper::create(method, aggregator_params)?, - ), - HashMethodKind::KeysU32(method) => AggregatorTransform::create( - ctx, - transform_params, - Keys32Grouper::create(method, aggregator_params)?, - ), - HashMethodKind::KeysU64(method) => AggregatorTransform::create( - ctx, - transform_params, - Keys64Grouper::create(method, aggregator_params)?, - ), - HashMethodKind::KeysU128(method) => AggregatorTransform::create( - ctx, - transform_params, - Keys128Grouper::create(method, aggregator_params)?, - ), - HashMethodKind::KeysU256(method) => AggregatorTransform::create( - ctx, - transform_params, - Keys256Grouper::create(method, aggregator_params)?, - ), - HashMethodKind::KeysU512(method) => AggregatorTransform::create( - ctx, - transform_params, - Keys512Grouper::create(method, aggregator_params)?, - ), - HashMethodKind::Serializer(method) => AggregatorTransform::create( - ctx, - transform_params, - KeysSerializerGrouper::create(method, aggregator_params)?, - ), - }, - false => match transform_params.method.clone() { - HashMethodKind::KeysU8(method) => AggregatorTransform::create( - ctx, - transform_params, - Keys8Aggregator::create(method, aggregator_params)?, - ), - HashMethodKind::KeysU16(method) => AggregatorTransform::create( - ctx, - transform_params, - Keys16Aggregator::create(method, aggregator_params)?, - ), - HashMethodKind::KeysU32(method) => AggregatorTransform::create( - ctx, - transform_params, - Keys32Aggregator::create(method, aggregator_params)?, - ), - HashMethodKind::KeysU64(method) => AggregatorTransform::create( - ctx, - transform_params, - Keys64Aggregator::create(method, aggregator_params)?, - ), - HashMethodKind::KeysU128(method) => AggregatorTransform::create( - ctx, - transform_params, - Keys128Aggregator::create(method, aggregator_params)?, - ), - HashMethodKind::KeysU256(method) => AggregatorTransform::create( - ctx, - transform_params, - Keys256Aggregator::create(method, aggregator_params)?, - ), - HashMethodKind::KeysU512(method) => AggregatorTransform::create( + true => with_mappedhash_method!(|T| match 
diff --git a/src/query/service/src/pipelines/processors/transforms/transform_convert_grouping.rs b/src/query/service/src/pipelines/processors/transforms/transform_convert_grouping.rs
index d0f5b11e10b89..2d9e1eb1947fc 100644
--- a/src/query/service/src/pipelines/processors/transforms/transform_convert_grouping.rs
+++ b/src/query/service/src/pipelines/processors/transforms/transform_convert_grouping.rs
@@ -18,6 +18,7 @@ use std::collections::HashMap;
 use std::sync::Arc;
 
 use common_exception::Result;
+use common_expression::with_hash_method;
 use common_expression::BlockMetaInfo;
 use common_expression::BlockMetaInfoPtr;
 use common_expression::DataBlock;
@@ -450,16 +451,9 @@ pub fn efficiently_memory_final_aggregator(
     let sample_block = DataBlock::empty_with_schema(schema_before_group_by);
     let method = DataBlock::choose_hash_method(&sample_block, group_cols)?;
 
-    match method {
-        HashMethodKind::KeysU8(v) => build_convert_grouping(v, pipeline, params.clone()),
-        HashMethodKind::KeysU16(v) => build_convert_grouping(v, pipeline, params.clone()),
-        HashMethodKind::KeysU32(v) => build_convert_grouping(v, pipeline, params.clone()),
-        HashMethodKind::KeysU64(v) => build_convert_grouping(v, pipeline, params.clone()),
-        HashMethodKind::KeysU128(v) => build_convert_grouping(v, pipeline, params.clone()),
-        HashMethodKind::KeysU256(v) => build_convert_grouping(v, pipeline, params.clone()),
-        HashMethodKind::KeysU512(v) => build_convert_grouping(v, pipeline, params.clone()),
-        HashMethodKind::Serializer(v) => build_convert_grouping(v, pipeline, params.clone()),
-    }
+    with_hash_method!(|T| match method {
+        HashMethodKind::T(v) => build_convert_grouping(v, pipeline, params.clone()),
+    })
 }
 
 struct MergeBucketTransform<Method: HashMethod + PolymorphicKeysHelper<Method> + Send + 'static> {
diff --git a/src/query/sql/src/executor/format.rs b/src/query/sql/src/executor/format.rs
index 06efc3a965138..c92be01e2c107 100644
--- a/src/query/sql/src/executor/format.rs
+++ b/src/query/sql/src/executor/format.rs
@@ -318,6 +318,11 @@ fn aggregate_final_to_format_tree(
         FormatTreeNode::new(format!("aggregate functions: [{agg_funcs}]")),
     ];
 
+    if let Some(limit) = &plan.limit {
+        let items = FormatTreeNode::new(format!("limit: {limit}"));
+        children.push(items);
+    }
+
     if let Some(info) = &plan.stat_info {
         let items = plan_stats_info_to_format_tree(info);
         children.extend(items);
diff --git a/src/query/sql/src/executor/physical_plan.rs b/src/query/sql/src/executor/physical_plan.rs
index ddb309ff30e05..2620aab868486 100644
--- a/src/query/sql/src/executor/physical_plan.rs
+++ b/src/query/sql/src/executor/physical_plan.rs
@@ -120,7 +120,6 @@ pub struct AggregatePartial {
     pub input: Box<PhysicalPlan>,
     pub group_by: Vec<IndexType>,
     pub agg_funcs: Vec<AggregateFunctionDesc>,
-    /// Only used for explain
     pub stat_info: Option<PlanStatsInfo>,
 }
 
@@ -161,6 +160,7 @@ pub struct AggregateFinal {
     pub agg_funcs: Vec<AggregateFunctionDesc>,
     pub before_group_by_schema: DataSchemaRef,
 
+    pub limit: Option<usize>,
     /// Only used for explain
     pub stat_info: Option<PlanStatsInfo>,
 }
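`AggregateFinal` now carries the pushed-down row cap, while `AggregatePartial` stays limit-free: partial aggregation must still see all of its input. The diff only threads the field through the plan; one way an executor can exploit it, for instance when the aggregate computes no aggregate functions (a `SELECT DISTINCT ... LIMIT n` with no `ORDER BY`), is to stop consuming input once enough distinct groups exist. A minimal sketch of that idea, not databend's actual operator:

```rust
use std::collections::HashSet;

/// Deduplicate group keys, stopping early once `limit` groups have been seen.
/// With no ORDER BY above this operator, any `limit` groups are a valid answer.
fn distinct_with_limit(keys: impl IntoIterator<Item = u64>, limit: Option<usize>) -> Vec<u64> {
    let mut seen = HashSet::new();
    let mut out = Vec::new();
    for k in keys {
        if seen.insert(k) {
            out.push(k);
        }
        // Early exit: the remaining input cannot change the groups already emitted.
        if limit.map_or(false, |n| out.len() >= n) {
            break;
        }
    }
    out
}

fn main() {
    let keys = [3u64, 3, 1, 2, 2, 5, 4];
    assert_eq!(distinct_with_limit(keys, Some(2)), vec![3, 1]);
    assert_eq!(distinct_with_limit(keys, None), vec![3, 1, 2, 5, 4]);
}
```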
diff --git a/src/query/sql/src/executor/physical_plan_builder.rs b/src/query/sql/src/executor/physical_plan_builder.rs
index ab94054875769..b008785d7db65 100644
--- a/src/query/sql/src/executor/physical_plan_builder.rs
+++ b/src/query/sql/src/executor/physical_plan_builder.rs
@@ -306,6 +306,7 @@ impl PhysicalPlanBuilder {
         let input = self.build(s_expr.child(0)?).await?;
         let input_schema = input.output_schema()?;
         let group_items = agg.group_items.iter().map(|v| v.index).collect::<Vec<_>>();
+
         let result = match &agg.mode {
             AggregateMode::Partial => {
                 let agg_funcs: Vec<AggregateFunctionDesc> = agg.aggregate_functions.iter().map(|v| {
@@ -443,8 +444,9 @@ impl PhysicalPlanBuilder {
                 }).collect::<Result<_>>()?;
 
                 match input {
-                    PhysicalPlan::AggregatePartial(ref agg) => {
-                        let before_group_by_schema = agg.input.output_schema()?;
+                    PhysicalPlan::AggregatePartial(ref partial) => {
+                        let before_group_by_schema = partial.input.output_schema()?;
+                        let limit = agg.limit;
 
                         PhysicalPlan::AggregateFinal(AggregateFinal {
                             input: Box::new(input),
                             group_by: group_items,
@@ -452,14 +454,17 @@ impl PhysicalPlanBuilder {
                             before_group_by_schema,
 
                             stat_info: Some(stat_info),
+                            limit,
                         })
                     }
                     PhysicalPlan::Exchange(PhysicalExchange {
-                        input: box PhysicalPlan::AggregatePartial(ref agg),
+                        input: box PhysicalPlan::AggregatePartial(ref partial),
                         ..
                     }) => {
-                        let before_group_by_schema = agg.input.output_schema()?;
+                        let before_group_by_schema = partial.input.output_schema()?;
+                        let limit = agg.limit;
+
                         PhysicalPlan::AggregateFinal(AggregateFinal {
                             input: Box::new(input),
                             group_by: group_items,
@@ -467,6 +472,7 @@ impl PhysicalPlanBuilder {
                             before_group_by_schema,
 
                             stat_info: Some(stat_info),
+                            limit,
                         })
                     }
diff --git a/src/query/sql/src/executor/physical_plan_visitor.rs b/src/query/sql/src/executor/physical_plan_visitor.rs
index b8d1b1ffc4de2..a13ae88855c07 100644
--- a/src/query/sql/src/executor/physical_plan_visitor.rs
+++ b/src/query/sql/src/executor/physical_plan_visitor.rs
@@ -105,6 +105,7 @@ pub trait PhysicalPlanReplacer {
             group_by: plan.group_by.clone(),
             agg_funcs: plan.agg_funcs.clone(),
             stat_info: plan.stat_info.clone(),
+            limit: plan.limit,
         }))
     }
diff --git a/src/query/sql/src/planner/binder/aggregate.rs b/src/query/sql/src/planner/binder/aggregate.rs
index fb7db658b9fca..4adc587336167 100644
--- a/src/query/sql/src/planner/binder/aggregate.rs
+++ b/src/query/sql/src/planner/binder/aggregate.rs
@@ -306,6 +306,7 @@ impl<'a> Binder {
             group_items: bind_context.aggregate_info.group_items.clone(),
             aggregate_functions: bind_context.aggregate_info.aggregate_functions.clone(),
             from_distinct: false,
+            limit: None,
         };
 
         new_expr = SExpr::create_unary(aggregate_plan.into(), new_expr);
diff --git a/src/query/sql/src/planner/binder/distinct.rs b/src/query/sql/src/planner/binder/distinct.rs
index 9ac18d36336d8..315be7a465549 100644
--- a/src/query/sql/src/planner/binder/distinct.rs
+++ b/src/query/sql/src/planner/binder/distinct.rs
@@ -75,6 +75,7 @@ impl Binder {
             group_items,
             aggregate_functions: vec![],
             from_distinct: true,
+            limit: None,
         };
 
         Ok(SExpr::create_unary(distinct_plan.into(), new_expr))
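From here the patch fans out to every place a logical `Aggregate` is constructed or rebuilt: the binder starts with `limit: None` (for both GROUP BY and the DISTINCT-as-aggregate lowering just above), and the rewriters that follow must copy an existing limit forward. Because these plan nodes are built with exhaustive struct literals, the new field turns each missed site into a compile error rather than a silently dropped optimization. A toy illustration (types simplified, not the real plan structs):

```rust
/// Simplified stand-in for the logical Aggregate operator.
#[derive(Clone, Debug, PartialEq)]
struct Aggregate {
    group_items: Vec<String>,
    from_distinct: bool,
    limit: Option<usize>, // the field added by this patch
}

/// A column-pruning rewrite that rebuilds the node. Omitting `limit` from the
/// struct literal would not compile, which is exactly the safety this buys.
fn prune_columns(agg: &Aggregate, used: Vec<String>) -> Aggregate {
    Aggregate {
        group_items: used,
        from_distinct: agg.from_distinct,
        limit: agg.limit, // carry the cap through the rewrite
    }
}

fn main() {
    let agg = Aggregate {
        group_items: vec!["k".into(), "v".into()],
        from_distinct: false,
        limit: Some(10),
    };
    let pruned = prune_columns(&agg, vec!["k".into()]);
    assert_eq!(pruned.limit, Some(10));
}
```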
diff --git a/src/query/sql/src/planner/optimizer/heuristic/decorrelate.rs b/src/query/sql/src/planner/optimizer/heuristic/decorrelate.rs
index 77d23aa5cc659..df23b03d5050e 100644
--- a/src/query/sql/src/planner/optimizer/heuristic/decorrelate.rs
+++ b/src/query/sql/src/planner/optimizer/heuristic/decorrelate.rs
@@ -604,6 +604,7 @@ impl SubqueryRewriter {
                     group_items,
                     aggregate_functions: agg_items,
                     from_distinct: aggregate.from_distinct,
+                    limit: aggregate.limit,
                 }
                 .into(),
                 flatten_plan,
diff --git a/src/query/sql/src/planner/optimizer/heuristic/heuristic.rs b/src/query/sql/src/planner/optimizer/heuristic/heuristic.rs
index 0d6f9b72af05a..b2d960421db77 100644
--- a/src/query/sql/src/planner/optimizer/heuristic/heuristic.rs
+++ b/src/query/sql/src/planner/optimizer/heuristic/heuristic.rs
@@ -41,6 +41,7 @@ pub static DEFAULT_REWRITE_RULES: Lazy<Vec<RuleID>> = Lazy::new(|| {
     RuleID::PushDownLimitUnion,
     RuleID::RulePushDownLimitExpression,
     RuleID::PushDownLimitSort,
+    RuleID::PushDownLimitAggregate,
     RuleID::PushDownLimitOuterJoin,
     RuleID::PushDownLimitScan,
     RuleID::PushDownSortScan,
@@ -97,6 +98,7 @@ impl HeuristicOptimizer {
         let pre_optimized = self.pre_optimize(s_expr)?;
         let optimized = self.optimize_expression(&pre_optimized)?;
         let post_optimized = self.post_optimize(optimized)?;
+
         Ok(post_optimized)
     }
 
@@ -129,7 +131,6 @@ impl HeuristicOptimizer {
                 }
             }
         }
-
         Ok(s_expr.clone())
     }
 }
diff --git a/src/query/sql/src/planner/optimizer/heuristic/prune_unused_columns.rs b/src/query/sql/src/planner/optimizer/heuristic/prune_unused_columns.rs
index faf4fec83ee59..1792425133c6e 100644
--- a/src/query/sql/src/planner/optimizer/heuristic/prune_unused_columns.rs
+++ b/src/query/sql/src/planner/optimizer/heuristic/prune_unused_columns.rs
@@ -185,6 +185,7 @@ impl UnusedColumnPruner {
                 aggregate_functions: used,
                 from_distinct: p.from_distinct,
                 mode: p.mode,
+                limit: p.limit,
             }),
             Self::keep_required_columns(expr.child(0)?, required)?,
         ))
diff --git a/src/query/sql/src/planner/optimizer/heuristic/subquery_rewriter.rs b/src/query/sql/src/planner/optimizer/heuristic/subquery_rewriter.rs
index ce21a00150117..58b94ff3314df 100644
--- a/src/query/sql/src/planner/optimizer/heuristic/subquery_rewriter.rs
+++ b/src/query/sql/src/planner/optimizer/heuristic/subquery_rewriter.rs
@@ -406,6 +406,7 @@ impl SubqueryRewriter {
             }],
             from_distinct: false,
             mode: AggregateMode::Initial,
+            limit: None,
         };
 
         let compare = ComparisonExpr {
diff --git a/src/query/sql/src/planner/optimizer/rule/factory.rs b/src/query/sql/src/planner/optimizer/rule/factory.rs
index 073d91b5f4c01..5cc126a34db29 100644
--- a/src/query/sql/src/planner/optimizer/rule/factory.rs
+++ b/src/query/sql/src/planner/optimizer/rule/factory.rs
@@ -20,6 +20,7 @@ use super::rewrite::RuleNormalizeDisjunctiveFilter;
 use super::rewrite::RuleNormalizeScalarFilter;
 use super::rewrite::RulePushDownFilterEvalScalar;
 use super::rewrite::RulePushDownFilterJoin;
+use super::rewrite::RulePushDownLimitAggregate;
 use super::rewrite::RulePushDownLimitExpression;
 use super::transform::RuleCommuteJoin;
 use super::transform::RuleLeftAssociateJoin;
@@ -62,6 +63,7 @@ impl RuleFactory {
             RuleID::PushDownLimitOuterJoin => Ok(Box::new(RulePushDownLimitOuterJoin::new())),
             RuleID::RulePushDownLimitExpression => Ok(Box::new(RulePushDownLimitExpression::new())),
             RuleID::PushDownLimitSort => Ok(Box::new(RulePushDownLimitSort::new())),
+            RuleID::PushDownLimitAggregate => Ok(Box::new(RulePushDownLimitAggregate::new())),
             RuleID::EliminateFilter => Ok(Box::new(RuleEliminateFilter::new())),
             RuleID::MergeEvalScalar => Ok(Box::new(RuleMergeEvalScalar::new())),
             RuleID::MergeFilter => Ok(Box::new(RuleMergeFilter::new())),
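Registering the new rule takes three mechanical steps spread across this patch: a slot in `DEFAULT_REWRITE_RULES` (above, next to the other limit push-downs; list order is the order the heuristic optimizer tries rules), the factory arm just above, and a `RuleID` variant later in the patch. A stripped-down model of the registry pattern, not the actual databend API:

```rust
/// Stripped-down model of the optimizer's rule registry.
trait Rule {
    fn name(&self) -> &'static str;
}

struct PushDownLimitAggregate;

impl Rule for PushDownLimitAggregate {
    fn name(&self) -> &'static str {
        "PushDownLimitAggregate"
    }
}

#[derive(Clone, Copy, PartialEq, Debug)]
enum RuleId {
    PushDownLimitSort,
    PushDownLimitAggregate,
}

/// Factory arm: every ID listed in the default pipeline must map to a
/// constructor, so a registered-but-unimplemented rule fails fast.
fn create_rule(id: RuleId) -> Option<Box<dyn Rule>> {
    match id {
        RuleId::PushDownLimitAggregate => Some(Box::new(PushDownLimitAggregate)),
        RuleId::PushDownLimitSort => None, // elided in this sketch
    }
}

fn main() {
    // The optimizer walks this list in order, so the aggregate push-down sits
    // with the other limit rules.
    let default_rules = [RuleId::PushDownLimitSort, RuleId::PushDownLimitAggregate];
    assert_eq!(default_rules[1], RuleId::PushDownLimitAggregate);

    let rule = create_rule(RuleId::PushDownLimitAggregate).unwrap();
    assert_eq!(rule.name(), "PushDownLimitAggregate");
}
```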
diff --git a/src/query/sql/src/planner/optimizer/rule/rewrite/mod.rs b/src/query/sql/src/planner/optimizer/rule/rewrite/mod.rs
index 79c9d9de170d4..bbf3fcfea342d 100644
--- a/src/query/sql/src/planner/optimizer/rule/rewrite/mod.rs
+++ b/src/query/sql/src/planner/optimizer/rule/rewrite/mod.rs
@@ -23,6 +23,7 @@ mod rule_push_down_filter_eval_scalar;
 mod rule_push_down_filter_join;
 mod rule_push_down_filter_scan;
 mod rule_push_down_filter_union;
+mod rule_push_down_limit_aggregate;
 mod rule_push_down_limit_expression;
 mod rule_push_down_limit_join;
 mod rule_push_down_limit_scan;
@@ -43,6 +44,7 @@ pub use rule_push_down_filter_join::try_push_down_filter_join;
 pub use rule_push_down_filter_join::RulePushDownFilterJoin;
 pub use rule_push_down_filter_scan::RulePushDownFilterScan;
 pub use rule_push_down_filter_union::RulePushDownFilterUnion;
+pub use rule_push_down_limit_aggregate::RulePushDownLimitAggregate;
 pub use rule_push_down_limit_expression::RulePushDownLimitExpression;
 pub use rule_push_down_limit_join::RulePushDownLimitOuterJoin;
 pub use rule_push_down_limit_scan::RulePushDownLimitScan;
diff --git a/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_limit_aggregate.rs b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_limit_aggregate.rs
new file mode 100644
index 0000000000000..76f0c2a3b8d5b
--- /dev/null
+++ b/src/query/sql/src/planner/optimizer/rule/rewrite/rule_push_down_limit_aggregate.rs
@@ -0,0 +1,91 @@
+// Copyright 2022 Datafuse Labs.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::cmp;
+
+use crate::optimizer::rule::Rule;
+use crate::optimizer::rule::TransformResult;
+use crate::optimizer::RuleID;
+use crate::optimizer::SExpr;
+use crate::plans::Aggregate;
+use crate::plans::Limit;
+use crate::plans::PatternPlan;
+use crate::plans::RelOp;
+use crate::plans::RelOp::Aggregate as OpAggregate;
+use crate::plans::RelOp::Pattern;
+use crate::plans::RelOperator;
+
+/// Input:  Limit
+///           \
+///            Aggregate
+///             \
+///              *
+///
+/// Output: Limit
+///           \
+///            Aggregate(with limit pushed down)
+///             \
+///              *
+pub struct RulePushDownLimitAggregate {
+    id: RuleID,
+    pattern: SExpr,
+}
+
+impl RulePushDownLimitAggregate {
+    pub fn new() -> Self {
+        Self {
+            id: RuleID::PushDownLimitAggregate,
+            pattern: SExpr::create_unary(
+                PatternPlan {
+                    plan_type: RelOp::Limit,
+                }
+                .into(),
+                SExpr::create_unary(
+                    PatternPlan {
+                        plan_type: OpAggregate,
+                    }
+                    .into(),
+                    SExpr::create_leaf(PatternPlan { plan_type: Pattern }.into()),
+                ),
+            ),
+        }
+    }
+}
+
+impl Rule for RulePushDownLimitAggregate {
+    fn id(&self) -> RuleID {
+        self.id
+    }
+
+    fn apply(&self, s_expr: &SExpr, state: &mut TransformResult) -> common_exception::Result<()> {
+        let limit: Limit = s_expr.plan().clone().try_into()?;
+        if let Some(mut count) = limit.limit {
+            count += limit.offset;
+            let agg = s_expr.child(0)?;
+            let mut agg_limit: Aggregate = agg.plan().clone().try_into()?;
+
+            agg_limit.limit = Some(agg_limit.limit.map_or(count, |c| cmp::max(c, count)));
+            let agg = SExpr::create_unary(RelOperator::Aggregate(agg_limit), agg.child(0)?.clone());
+
+            let mut result = SExpr::create_unary(limit.into(), agg);
+            result.set_applied_rule(&self.id);
+            state.add_result(result);
+        }
+        Ok(())
+    }
+
+    fn pattern(&self) -> &SExpr {
+        &self.pattern
+    }
+}
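The heart of the rule is the limit arithmetic in `apply`: `LIMIT n OFFSET m` needs the first `n + m` groups from the aggregate below it, and if an earlier rewrite already attached a limit, the larger requirement must win so no consumer is starved. A worked model of just that computation:

```rust
use std::cmp;

/// Mirror of the rule's arithmetic: returns the aggregate's new limit given
/// the Limit operator above it and any limit already on the aggregate.
fn pushed_aggregate_limit(
    limit: Option<usize>,
    offset: usize,
    agg_limit: Option<usize>,
) -> Option<usize> {
    match limit {
        // A bare OFFSET with no count gives no bound to push; leave it alone.
        None => agg_limit,
        Some(n) => {
            let count = n + offset;
            // Never shrink an existing limit: take the max of both demands.
            Some(agg_limit.map_or(count, |c| cmp::max(c, count)))
        }
    }
}

fn main() {
    // LIMIT 10 OFFSET 5 over an unconstrained aggregate: keep 15 groups.
    assert_eq!(pushed_aggregate_limit(Some(10), 5, None), Some(15));
    // An existing larger demand (20) survives.
    assert_eq!(pushed_aggregate_limit(Some(10), 5, Some(20)), Some(20));
    // A smaller existing demand is widened to 15.
    assert_eq!(pushed_aggregate_limit(Some(10), 5, Some(8)), Some(15));
}
```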
diff --git a/src/query/sql/src/planner/optimizer/rule/rule.rs b/src/query/sql/src/planner/optimizer/rule/rule.rs
index 01e7f66977dff..ec95bd95f51f9 100644
--- a/src/query/sql/src/planner/optimizer/rule/rule.rs
+++ b/src/query/sql/src/planner/optimizer/rule/rule.rs
@@ -43,6 +43,7 @@ pub enum RuleID {
     PushDownLimitOuterJoin,
     RulePushDownLimitExpression,
     PushDownLimitSort,
+    PushDownLimitAggregate,
     PushDownLimitScan,
     PushDownSortScan,
     EliminateEvalScalar,
@@ -73,6 +74,7 @@ impl Display for RuleID {
             RuleID::PushDownLimitOuterJoin => write!(f, "PushDownLimitOuterJoin"),
             RuleID::RulePushDownLimitExpression => write!(f, "PushDownLimitExpression"),
             RuleID::PushDownLimitSort => write!(f, "PushDownLimitSort"),
+            RuleID::PushDownLimitAggregate => write!(f, "PushDownLimitAggregate"),
             RuleID::PushDownLimitScan => write!(f, "PushDownLimitScan"),
             RuleID::PushDownSortScan => write!(f, "PushDownSortScan"),
             RuleID::EliminateEvalScalar => write!(f, "EliminateEvalScalar"),
diff --git a/src/query/sql/src/planner/plans/aggregate.rs b/src/query/sql/src/planner/plans/aggregate.rs
index cc16e2a80e35f..3627013f0b058 100644
--- a/src/query/sql/src/planner/plans/aggregate.rs
+++ b/src/query/sql/src/planner/plans/aggregate.rs
@@ -47,6 +47,7 @@ pub struct Aggregate {
     pub aggregate_functions: Vec<ScalarItem>,
     // True if the plan is generated from DISTINCT; otherwise it is a normal aggregate.
     pub from_distinct: bool,
+    pub limit: Option<usize>,
 }
 
 impl Aggregate {