@@ -17,11 +17,13 @@

 //! The module contains the file writer for parquet file format.

+use std::collections::hash_map::Entry;
 use std::collections::HashMap;
 use std::sync::atomic::AtomicI64;
 use std::sync::Arc;

-use arrow_schema::SchemaRef as ArrowSchemaRef;
+use arrow_array::Float32Array;
+use arrow_schema::{DataType, SchemaRef as ArrowSchemaRef};
 use bytes::Bytes;
 use futures::future::BoxFuture;
 use itertools::Itertools;
@@ -101,6 +103,7 @@ impl<T: LocationGenerator, F: FileNameGenerator> FileWriterBuilder for ParquetWr
             written_size,
             current_row_num: 0,
             out_file,
+            nan_value_counts: HashMap::new(),
         })
     }
 }
@@ -216,6 +219,7 @@ pub struct ParquetWriter {
     writer: AsyncArrowWriter<AsyncFileWriter<TrackWriter>>,
     written_size: Arc<AtomicI64>,
     current_row_num: usize,
+    nan_value_counts: HashMap<i32, u64>,
 }

 /// Used to aggregate min and max value of each column.
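Aside (not part of the diff): the new map is keyed by Iceberg field id rather than by column position, so it lines up with the other per-column metrics such as null_value_counts. A minimal sketch of how counts accumulate across repeated write() calls, using the standard HashMap entry API:

    use std::collections::HashMap;

    let mut nan_value_counts: HashMap<i32, u64> = HashMap::new();
    *nan_value_counts.entry(1).or_insert(0) += 11; // first batch
    *nan_value_counts.entry(1).or_insert(0) += 11; // second batch
    assert_eq!(nan_value_counts[&1], 22);

(`entry(...).or_insert(0)` is equivalent to the explicit Occupied/Vacant match that write() uses below.)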
@@ -312,6 +316,7 @@ impl ParquetWriter {
         metadata: FileMetaData,
         written_size: usize,
         file_path: String,
+        nan_value_counts: HashMap<i32, u64>,
     ) -> Result<DataFileBuilder> {
         let index_by_parquet_path = {
             let mut visitor = IndexByParquetPathName::new();
@@ -378,8 +383,8 @@ impl ParquetWriter {
             .null_value_counts(null_value_counts)
             .lower_bounds(lower_bounds)
             .upper_bounds(upper_bounds)
+            .nan_value_counts(nan_value_counts)
             // # TODO(#417)
-            // - nan_value_counts
             // - distinct_counts
             .key_metadata(metadata.footer_signing_key_metadata)
             .split_offsets(
@@ -396,13 +401,45 @@ impl ParquetWriter {
 impl FileWriter for ParquetWriter {
     async fn write(&mut self, batch: &arrow_array::RecordBatch) -> crate::Result<()> {
         self.current_row_num += batch.num_rows();
+
+        for (col, field) in batch
+            .columns()
+            .iter()
+            .zip(self.schema.as_struct().fields().iter())
+        {
+            let dt = col.data_type();
+
+            let nan_val_cnt: u64 = match dt {
+                DataType::Float32 => {
+                    let float_array = col.as_any().downcast_ref::<Float32Array>().unwrap();
+
+                    float_array
+                        .iter()
+                        .filter(|value| value.map_or(false, |v| v.is_nan()))
+                        .count() as u64
+                }
+                _ => 0,
+            };
+
+            match self.nan_value_counts.entry(field.id) {
+                Entry::Occupied(mut ele) => {
+                    let total_nan_val_cnt = ele.get() + nan_val_cnt;
+                    ele.insert(total_nan_val_cnt);
+                }
+                Entry::Vacant(v) => {
+                    v.insert(nan_val_cnt);
+                }
+            }
+        }
+
         self.writer.write(batch).await.map_err(|err| {
             Error::new(
                 ErrorKind::Unexpected,
                 "Failed to write using parquet writer.",
             )
             .with_source(err)
         })?;
+
         Ok(())
     }

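For reference, the per-column tally in write() can be read as a standalone helper. A sketch under the assumption that a Float64 arm would mirror the Float32 one (the diff as written only matches Float32, so other float widths currently count as 0):

    use arrow_array::{Array, ArrayRef, Float32Array, Float64Array};
    use arrow_schema::DataType;

    // Count NaN values in one Arrow column; non-float columns report 0.
    // The downcasts are safe because each arm only runs for its own DataType.
    fn count_nans(col: &ArrayRef) -> u64 {
        match col.data_type() {
            DataType::Float32 => {
                let arr = col.as_any().downcast_ref::<Float32Array>().unwrap();
                arr.iter().filter(|v| v.map_or(false, |v| v.is_nan())).count() as u64
            }
            DataType::Float64 => {
                let arr = col.as_any().downcast_ref::<Float64Array>().unwrap();
                arr.iter().filter(|v| v.map_or(false, |v| v.is_nan())).count() as u64
            }
            _ => 0,
        }
    }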
@@ -418,6 +455,7 @@ impl FileWriter for ParquetWriter {
             metadata,
             written_size as usize,
             self.out_file.location().to_string(),
+            self.nan_value_counts,
         )?])
     }
 }
@@ -478,8 +516,8 @@ mod tests {
     use anyhow::Result;
     use arrow_array::types::Int64Type;
     use arrow_array::{
-        Array, ArrayRef, BooleanArray, Decimal128Array, Int32Array, Int64Array, ListArray,
-        RecordBatch, StructArray,
+        Array, ArrayRef, BooleanArray, Decimal128Array, Float32Array, Int32Array, Int64Array,
+        ListArray, RecordBatch, StructArray,
     };
     use arrow_schema::{DataType, SchemaRef as ArrowSchemaRef};
     use arrow_select::concat::concat_batches;
@@ -659,13 +697,27 @@ mod tests {
                 arrow_schema::Field::new("col", arrow_schema::DataType::Int64, true).with_metadata(
                     HashMap::from([(PARQUET_FIELD_ID_META_KEY.to_string(), "0".to_string())]),
                 ),
+                arrow_schema::Field::new("col1", arrow_schema::DataType::Float32, true)
+                    .with_metadata(HashMap::from([(
+                        PARQUET_FIELD_ID_META_KEY.to_string(),
+                        "1".to_string(),
+                    )])),
             ];
             Arc::new(arrow_schema::Schema::new(fields))
         };
         let col = Arc::new(Int64Array::from_iter_values(0..1024)) as ArrayRef;
         let null_col = Arc::new(Int64Array::new_null(1024)) as ArrayRef;
-        let to_write = RecordBatch::try_new(schema.clone(), vec![col]).unwrap();
-        let to_write_null = RecordBatch::try_new(schema.clone(), vec![null_col]).unwrap();
+        let float_col = Arc::new(Float32Array::from_iter_values((0..1024).map(|x| {
+            if x % 100 == 0 {
+                // 11 NaNs per batch: x = 0, 100, ..., 1000 within 0..1024
+                f32::NAN
+            } else {
+                x as f32
+            }
+        }))) as ArrayRef;
+        let to_write = RecordBatch::try_new(schema.clone(), vec![col, float_col.clone()]).unwrap();
+        let to_write_null =
+            RecordBatch::try_new(schema.clone(), vec![null_col, float_col]).unwrap();

         // write data
         let mut pw = ParquetWriterBuilder::new(
@@ -677,6 +729,7 @@ mod tests {
         )
         .build()
         .await?;
+
         pw.write(&to_write).await?;
         pw.write(&to_write_null).await?;
         let res = pw.close().await?;
@@ -693,16 +746,26 @@ mod tests {

         // check data file
         assert_eq!(data_file.record_count(), 2048);
-        assert_eq!(*data_file.value_counts(), HashMap::from([(0, 2048)]));
+        assert_eq!(
+            *data_file.value_counts(),
+            HashMap::from([(0, 2048), (1, 2048)])
+        );
         assert_eq!(
             *data_file.lower_bounds(),
-            HashMap::from([(0, Datum::long(0))])
+            HashMap::from([(0, Datum::long(0)), (1, Datum::float(1.0))])
         );
         assert_eq!(
             *data_file.upper_bounds(),
-            HashMap::from([(0, Datum::long(1023))])
+            HashMap::from([(0, Datum::long(1023)), (1, Datum::float(1023.0))])
+        );
+        assert_eq!(
+            *data_file.null_value_counts(),
+            HashMap::from([(0, 1024), (1, 0)])
+        );
+        assert_eq!(
+            *data_file.nan_value_counts(),
+            HashMap::from([(0, 0), (1, 22)]) // 22: the float column was written twice, 11 NaNs each
         );
-        assert_eq!(*data_file.null_value_counts(), HashMap::from([(0, 1024)]));

         // check the written file
         let expect_batch = concat_batches(&schema, vec![&to_write, &to_write_null]).unwrap();
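To ground the expected numbers: x % 100 == 0 holds for x = 0, 100, ..., 1000 within 0..1024, so each batch contributes 11 NaNs, and the float column appears in both to_write and to_write_null. A standalone check (not part of the diff):

    let per_batch = (0..1024).filter(|x| x % 100 == 0).count() as u64;
    assert_eq!(per_batch, 11);
    assert_eq!(2 * per_batch, 22); // matches nan_value_counts for field id 1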