
Commit 7e76ade

feat: nan_value_counts support

Parent: efca9f0

1 file changed: crates/iceberg/src/writer/file_writer/parquet_writer.rs (+73 −10)
@@ -17,11 +17,13 @@
 
 //! The module contains the file writer for parquet file format.
 
+use std::collections::hash_map::Entry;
 use std::collections::HashMap;
 use std::sync::atomic::AtomicI64;
 use std::sync::Arc;
 
-use arrow_schema::SchemaRef as ArrowSchemaRef;
+use arrow_array::Float32Array;
+use arrow_schema::{DataType, SchemaRef as ArrowSchemaRef};
 use bytes::Bytes;
 use futures::future::BoxFuture;
 use itertools::Itertools;
@@ -101,6 +103,7 @@ impl<T: LocationGenerator, F: FileNameGenerator> FileWriterBuilder for ParquetWr
             written_size,
             current_row_num: 0,
             out_file,
+            nan_value_counts: HashMap::new(),
         })
     }
 }
@@ -216,6 +219,7 @@ pub struct ParquetWriter {
     writer: AsyncArrowWriter<AsyncFileWriter<TrackWriter>>,
     written_size: Arc<AtomicI64>,
     current_row_num: usize,
+    nan_value_counts: HashMap<i32, u64>,
 }
 
 /// Used to aggregate min and max value of each column.
@@ -312,6 +316,7 @@ impl ParquetWriter {
         metadata: FileMetaData,
         written_size: usize,
         file_path: String,
+        nan_value_counts: HashMap<i32, u64>,
     ) -> Result<DataFileBuilder> {
         let index_by_parquet_path = {
             let mut visitor = IndexByParquetPathName::new();
@@ -378,8 +383,8 @@ impl ParquetWriter {
             .null_value_counts(null_value_counts)
             .lower_bounds(lower_bounds)
             .upper_bounds(upper_bounds)
+            .nan_value_counts(nan_value_counts)
             // # TODO(#417)
-            // - nan_value_counts
             // - distinct_counts
             .key_metadata(metadata.footer_signing_key_metadata)
             .split_offsets(
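
For context, nan_value_counts is a per-field-id metrics map on the data file, parallel to value_counts and null_value_counts; together these maps let a reader reason about a float column without scanning it. A minimal sketch of one such check (the helper is hypothetical, not iceberg-rust API):

use std::collections::HashMap;

// Hypothetical helper: a column holds nothing but nulls and NaNs when those
// two counts together account for every value in the column.
fn is_all_null_or_nan(
    field_id: i32,
    value_counts: &HashMap<i32, u64>,
    null_value_counts: &HashMap<i32, u64>,
    nan_value_counts: &HashMap<i32, u64>,
) -> bool {
    let total = value_counts.get(&field_id).copied().unwrap_or(0);
    let nulls = null_value_counts.get(&field_id).copied().unwrap_or(0);
    let nans = nan_value_counts.get(&field_id).copied().unwrap_or(0);
    total > 0 && nulls + nans == total
}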
@@ -396,13 +401,45 @@
 impl FileWriter for ParquetWriter {
     async fn write(&mut self, batch: &arrow_array::RecordBatch) -> crate::Result<()> {
         self.current_row_num += batch.num_rows();
+
+        for (col, field) in batch
+            .columns()
+            .iter()
+            .zip(self.schema.as_struct().fields().iter())
+        {
+            let dt = col.data_type();
+
+            let nan_val_cnt: u64 = match dt {
+                DataType::Float32 => {
+                    let float_array = col.as_any().downcast_ref::<Float32Array>().unwrap();
+
+                    float_array
+                        .iter()
+                        .filter(|value| value.map_or(false, |v| v.is_nan()))
+                        .count() as u64
+                }
+                _ => 0,
+            };
+
+            match self.nan_value_counts.entry(field.id) {
+                Entry::Occupied(mut ele) => {
+                    let total_nan_val_cnt = ele.get() + nan_val_cnt;
+                    ele.insert(total_nan_val_cnt);
+                }
+                Entry::Vacant(v) => {
+                    v.insert(nan_val_cnt);
+                }
+            }
+        }
+
         self.writer.write(batch).await.map_err(|err| {
             Error::new(
                 ErrorKind::Unexpected,
                 "Failed to write using parquet writer.",
             )
             .with_source(err)
         })?;
+
         Ok(())
     }
 
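The Occupied/Vacant match above can be collapsed into the entry(...).or_insert(0) idiom. A minimal standalone sketch of the same per-batch accumulation (assuming a bare HashMap<i32, u64> and a Float32Array column, as in the diff; add_nan_count is illustrative, not part of the writer):

use std::collections::HashMap;

use arrow_array::Float32Array;

// Count the NaNs in one Float32 column and fold them into the running
// per-field totals, mirroring what write() does for each batch.
fn add_nan_count(counts: &mut HashMap<i32, u64>, field_id: i32, col: &Float32Array) {
    let nan_val_cnt = col
        .iter()
        .filter(|v| v.map_or(false, |v| v.is_nan()))
        .count() as u64;
    // One-line equivalent of the Occupied/Vacant match above.
    *counts.entry(field_id).or_insert(0) += nan_val_cnt;
}

Note that the loop only matches top-level DataType::Float32: Float64 columns, and floats nested inside structs, lists, or maps, fall into the `_ => 0` arm, so their NaN counts are reported as zero.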
@@ -418,6 +455,7 @@ impl FileWriter for ParquetWriter {
             metadata,
             written_size as usize,
             self.out_file.location().to_string(),
+            self.nan_value_counts,
         )?])
     }
 }
@@ -478,8 +516,8 @@ mod tests {
     use anyhow::Result;
     use arrow_array::types::Int64Type;
     use arrow_array::{
-        Array, ArrayRef, BooleanArray, Decimal128Array, Int32Array, Int64Array, ListArray,
-        RecordBatch, StructArray,
+        Array, ArrayRef, BooleanArray, Decimal128Array, Float32Array, Int32Array, Int64Array,
+        ListArray, RecordBatch, StructArray,
     };
     use arrow_schema::{DataType, SchemaRef as ArrowSchemaRef};
     use arrow_select::concat::concat_batches;
@@ -659,13 +697,27 @@ mod tests {
                 arrow_schema::Field::new("col", arrow_schema::DataType::Int64, true).with_metadata(
                     HashMap::from([(PARQUET_FIELD_ID_META_KEY.to_string(), "0".to_string())]),
                 ),
+                arrow_schema::Field::new("col1", arrow_schema::DataType::Float32, true)
+                    .with_metadata(HashMap::from([(
+                        PARQUET_FIELD_ID_META_KEY.to_string(),
+                        "1".to_string(),
+                    )])),
             ];
             Arc::new(arrow_schema::Schema::new(fields))
         };
         let col = Arc::new(Int64Array::from_iter_values(0..1024)) as ArrayRef;
         let null_col = Arc::new(Int64Array::new_null(1024)) as ArrayRef;
-        let to_write = RecordBatch::try_new(schema.clone(), vec![col]).unwrap();
-        let to_write_null = RecordBatch::try_new(schema.clone(), vec![null_col]).unwrap();
+        let float_col = Arc::new(Float32Array::from_iter_values((0..1024).map(|x| {
+            if x % 100 == 0 {
+                // There will be 11 NaNs as there are 1024 entries
+                f32::NAN
+            } else {
+                x as f32
+            }
+        }))) as ArrayRef;
+        let to_write = RecordBatch::try_new(schema.clone(), vec![col, float_col.clone()]).unwrap();
+        let to_write_null =
+            RecordBatch::try_new(schema.clone(), vec![null_col, float_col]).unwrap();
 
         // write data
         let mut pw = ParquetWriterBuilder::new(
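
A quick standalone check of the comment's arithmetic (not part of the test): the indices in 0..1024 divisible by 100 are 0, 100, ..., 1000, i.e. 11 NaNs per batch, and the float column is written in both batches:

let nans_per_batch = (0..1024).filter(|x| x % 100 == 0).count();
assert_eq!(nans_per_batch, 11); // 0, 100, ..., 1000
assert_eq!(2 * nans_per_batch, 22); // float_col goes into both to_write and to_write_null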
@@ -677,6 +729,7 @@ mod tests {
         )
         .build()
         .await?;
+
         pw.write(&to_write).await?;
         pw.write(&to_write_null).await?;
         let res = pw.close().await?;
@@ -693,16 +746,26 @@ mod tests {
 
         // check data file
         assert_eq!(data_file.record_count(), 2048);
-        assert_eq!(*data_file.value_counts(), HashMap::from([(0, 2048)]));
+        assert_eq!(
+            *data_file.value_counts(),
+            HashMap::from([(0, 2048), (1, 2048)])
+        );
         assert_eq!(
             *data_file.lower_bounds(),
-            HashMap::from([(0, Datum::long(0))])
+            HashMap::from([(0, Datum::long(0)), (1, Datum::float(1.0))])
         );
         assert_eq!(
             *data_file.upper_bounds(),
-            HashMap::from([(0, Datum::long(1023))])
+            HashMap::from([(0, Datum::long(1023)), (1, Datum::float(1023.0))])
+        );
+        assert_eq!(
+            *data_file.null_value_counts(),
+            HashMap::from([(0, 1024), (1, 0)])
+        );
+        assert_eq!(
+            *data_file.nan_value_counts(),
+            HashMap::from([(0, 0), (1, 22)]) // 22, because we wrote the float column twice
         );
-        assert_eq!(*data_file.null_value_counts(), HashMap::from([(0, 1024)]));
 
         // check the written file
         let expect_batch = concat_batches(&schema, vec![&to_write, &to_write_null]).unwrap();
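
One subtlety the bounds assertions encode: the lower bound for field 1 is Datum::float(1.0) rather than 0.0, because index 0 produced f32::NAN (0 % 100 == 0) and a NaN never becomes a min/max bound, so the smallest finite value written is 1.0. A standalone check of that arithmetic:

// Smallest non-NaN value in the generated float column.
let min = (0..1024)
    .map(|x| if x % 100 == 0 { f32::NAN } else { x as f32 })
    .filter(|v| !v.is_nan())
    .fold(f32::INFINITY, f32::min);
assert_eq!(min, 1.0);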
