@@ -27,7 +27,8 @@ use arrow::{downcast_dictionary_array, downcast_primitive_array};
2727use arrow_buffer:: i256;
2828
2929use crate :: cast:: {
30- as_boolean_array, as_generic_binary_array, as_primitive_array, as_string_array,
30+ as_boolean_array, as_generic_binary_array, as_large_list_array, as_list_array,
31+ as_primitive_array, as_string_array, as_struct_array,
3132} ;
3233use crate :: error:: { DataFusionError , Result , _internal_err} ;
3334
@@ -207,6 +208,35 @@ fn hash_dictionary<K: ArrowDictionaryKeyType>(
207208 Ok ( ( ) )
208209}
209210
211+ fn hash_struct_array (
212+ array : & StructArray ,
213+ random_state : & RandomState ,
214+ hashes_buffer : & mut [ u64 ] ,
215+ ) -> Result < ( ) > {
216+ let nulls = array. nulls ( ) ;
217+ let num_columns = array. num_columns ( ) ;
218+
219+ // Skip null columns
220+ let valid_indices: Vec < usize > = if let Some ( nulls) = nulls {
221+ nulls. valid_indices ( ) . collect ( )
222+ } else {
223+ ( 0 ..num_columns) . collect ( )
224+ } ;
225+
226+ // Create hashes for each row that combines the hashes over all the column at that row.
227+ // array.len() is the number of rows.
228+ let mut values_hashes = vec ! [ 0u64 ; array. len( ) ] ;
229+ create_hashes ( array. columns ( ) , random_state, & mut values_hashes) ?;
230+
231+ // Skip the null columns, nulls should get hash value 0.
232+ for i in valid_indices {
233+ let hash = & mut hashes_buffer[ i] ;
234+ * hash = combine_hashes ( * hash, values_hashes[ i] ) ;
235+ }
236+
237+ Ok ( ( ) )
238+ }
239+
210240fn hash_list_array < OffsetSize > (
211241 array : & GenericListArray < OffsetSize > ,
212242 random_state : & RandomState ,
@@ -327,12 +357,16 @@ pub fn create_hashes<'a>(
327357 array => hash_dictionary( array, random_state, hashes_buffer, rehash) ?,
328358 _ => unreachable!( )
329359 }
360+ DataType :: Struct ( _) => {
361+ let array = as_struct_array( array) ?;
362+ hash_struct_array( array, random_state, hashes_buffer) ?;
363+ }
330364 DataType :: List ( _) => {
331- let array = as_list_array( array) ;
365+ let array = as_list_array( array) ? ;
332366 hash_list_array( array, random_state, hashes_buffer) ?;
333367 }
334368 DataType :: LargeList ( _) => {
335- let array = as_large_list_array( array) ;
369+ let array = as_large_list_array( array) ? ;
336370 hash_list_array( array, random_state, hashes_buffer) ?;
337371 }
338372 _ => {
@@ -515,6 +549,58 @@ mod tests {
515549 assert_eq ! ( hashes[ 2 ] , hashes[ 3 ] ) ;
516550 }
517551
552+ #[ test]
553+ // Tests actual values of hashes, which are different if forcing collisions
554+ #[ cfg( not( feature = "force_hash_collisions" ) ) ]
555+ fn create_hashes_for_struct_arrays ( ) {
556+ use arrow_buffer:: Buffer ;
557+
558+ let boolarr = Arc :: new ( BooleanArray :: from ( vec ! [
559+ false , false , true , true , true , true ,
560+ ] ) ) ;
561+ let i32arr = Arc :: new ( Int32Array :: from ( vec ! [ 10 , 10 , 20 , 20 , 30 , 31 ] ) ) ;
562+
563+ let struct_array = StructArray :: from ( (
564+ vec ! [
565+ (
566+ Arc :: new( Field :: new( "bool" , DataType :: Boolean , false ) ) ,
567+ boolarr. clone( ) as ArrayRef ,
568+ ) ,
569+ (
570+ Arc :: new( Field :: new( "i32" , DataType :: Int32 , false ) ) ,
571+ i32arr. clone( ) as ArrayRef ,
572+ ) ,
573+ (
574+ Arc :: new( Field :: new( "i32" , DataType :: Int32 , false ) ) ,
575+ i32arr. clone( ) as ArrayRef ,
576+ ) ,
577+ (
578+ Arc :: new( Field :: new( "bool" , DataType :: Boolean , false ) ) ,
579+ boolarr. clone( ) as ArrayRef ,
580+ ) ,
581+ ] ,
582+ Buffer :: from ( & [ 0b001011 ] ) ,
583+ ) ) ;
584+
585+ assert ! ( struct_array. is_valid( 0 ) ) ;
586+ assert ! ( struct_array. is_valid( 1 ) ) ;
587+ assert ! ( struct_array. is_null( 2 ) ) ;
588+ assert ! ( struct_array. is_valid( 3 ) ) ;
589+ assert ! ( struct_array. is_null( 4 ) ) ;
590+ assert ! ( struct_array. is_null( 5 ) ) ;
591+
592+ let array = Arc :: new ( struct_array) as ArrayRef ;
593+
594+ let random_state = RandomState :: with_seeds ( 0 , 0 , 0 , 0 ) ;
595+ let mut hashes = vec ! [ 0 ; array. len( ) ] ;
596+ create_hashes ( & [ array] , & random_state, & mut hashes) . unwrap ( ) ;
597+ assert_eq ! ( hashes[ 0 ] , hashes[ 1 ] ) ;
598+ // same value but the third row ( hashes[2] ) is null
599+ assert_ne ! ( hashes[ 2 ] , hashes[ 3 ] ) ;
600+ // different values but both are null
601+ assert_eq ! ( hashes[ 4 ] , hashes[ 5 ] ) ;
602+ }
603+
518604 #[ test]
519605 // Tests actual values of hashes, which are different if forcing collisions
520606 #[ cfg( not( feature = "force_hash_collisions" ) ) ]
0 commit comments