Skip to content

Commit ee58ac3

Browse files
committed
add test case for spill compression
1 parent ff261d4 commit ee58ac3

File tree

4 files changed

+125
-8
lines changed

4 files changed

+125
-8
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ arrow-flight = { version = "55.1.0", features = [
9797
"flight-sql-experimental",
9898
] }
9999
arrow-ipc = { version = "55.0.0", default-features = false, features = [
100-
"lz4",
100+
"lz4", "zstd",
101101
] }
102102
arrow-ord = { version = "55.0.0", default-features = false }
103103
arrow-schema = { version = "55.0.0", default-features = false }

datafusion/common/src/config.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,7 @@ config_namespace! {
280280
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
281281
pub enum SpillCompression {
282282
Zstd,
283-
Lz4_frame,
283+
Lz4Frame,
284284
Uncompressed,
285285
}
286286

@@ -290,7 +290,7 @@ impl FromStr for SpillCompression {
290290
fn from_str(s: &str) -> Result<Self, Self::Err> {
291291
match s.to_ascii_lowercase().as_str() {
292292
"zstd" => Ok(Self::Zstd),
293-
"lz4_frame" => Ok(Self::Lz4_frame),
293+
"lz4_frame" => Ok(Self::Lz4Frame),
294294
"uncompressed" | "" => Ok(Self::Uncompressed),
295295
other => Err(DataFusionError::Execution(format!(
296296
"Invalid Spill file compression type: {other}. Expected one of: zstd, lz4_frame, uncompressed"
@@ -313,8 +313,8 @@ impl ConfigField for SpillCompression {
313313
impl Display for SpillCompression {
314314
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
315315
let str = match self {
316-
Self::Zstd => "Zstd",
317-
Self::Lz4_frame => "Lz4",
316+
Self::Zstd => "zstd",
317+
Self::Lz4Frame => "lz4_frame",
318318
Self::Uncompressed => "",
319319
};
320320
write!(f, "{str}")
@@ -325,7 +325,7 @@ impl From<SpillCompression> for Option<CompressionType> {
325325
fn from(c: SpillCompression) -> Self {
326326
match c {
327327
SpillCompression::Zstd => Some(CompressionType::ZSTD),
328-
SpillCompression::Lz4_frame => Some(CompressionType::LZ4_FRAME),
328+
SpillCompression::Lz4Frame => Some(CompressionType::LZ4_FRAME),
329329
SpillCompression::Uncompressed => None,
330330
}
331331
}

datafusion/physical-plan/benches/spill_io.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,8 @@ fn bench_spill_io(c: &mut Criterion) {
8383
Field::new("c2", DataType::Date32, true),
8484
Field::new("c3", DataType::Decimal128(11, 2), true),
8585
]));
86-
let spill_manager = SpillManager::new(env, metrics, schema, SpillCompression::Uncompressed);
86+
let spill_manager =
87+
SpillManager::new(env, metrics, schema, SpillCompression::Uncompressed);
8788

8889
let mut group = c.benchmark_group("spill_io");
8990
let rt = Runtime::new().unwrap();

datafusion/physical-plan/src/spill/mod.rs

Lines changed: 117 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ use std::task::{Context, Poll};
3131
use arrow::array::ArrayData;
3232
use arrow::datatypes::{Schema, SchemaRef};
3333
use arrow::ipc::writer::IpcWriteOptions;
34+
use arrow::ipc::MetadataVersion;
3435
use arrow::ipc::{reader::StreamReader, writer::StreamWriter};
35-
use arrow::ipc::{CompressionType, MetadataVersion};
3636
use arrow::record_batch::RecordBatch;
3737

3838
use datafusion_common::config::SpillCompression;
@@ -349,6 +349,7 @@ mod tests {
349349
use crate::metrics::SpillMetrics;
350350
use crate::spill::spill_manager::SpillManager;
351351
use crate::test::build_table_i32;
352+
use arrow::array::ArrayRef;
352353
use arrow::array::{Float64Array, Int32Array, ListArray, StringArray};
353354
use arrow::compute::cast;
354355
use arrow::datatypes::{DataType, Field, Int32Type, Schema};
@@ -502,6 +503,121 @@ mod tests {
502503
Ok(())
503504
}
504505

506+
fn build_compressible_batch() -> RecordBatch {
507+
let schema = Arc::new(Schema::new(vec![
508+
Field::new("a", DataType::Utf8, false),
509+
Field::new("b", DataType::Int32, false),
510+
Field::new("c", DataType::Int32, true),
511+
]));
512+
513+
let a: ArrayRef = Arc::new(StringArray::from_iter_values(
514+
std::iter::repeat("repeated").take(100),
515+
));
516+
let b: ArrayRef = Arc::new(Int32Array::from(vec![1; 100]));
517+
let c: ArrayRef = Arc::new(Int32Array::from(vec![2; 100]));
518+
519+
RecordBatch::try_new(schema, vec![a, b, c]).unwrap()
520+
}
521+
522+
async fn validate(
523+
spill_manager: &SpillManager,
524+
spill_file: RefCountedTempFile,
525+
num_rows: usize,
526+
schema: SchemaRef,
527+
batch_count: usize,
528+
) -> Result<()> {
529+
let spilled_rows = spill_manager.metrics.spilled_rows.value();
530+
assert_eq!(spilled_rows, num_rows);
531+
532+
let stream = spill_manager.read_spill_as_stream(spill_file)?;
533+
assert_eq!(stream.schema(), schema);
534+
535+
let batches = collect(stream).await?;
536+
assert_eq!(batches.len(), batch_count);
537+
538+
Ok(())
539+
}
540+
541+
#[tokio::test]
542+
async fn test_spill_compression() -> Result<()> {
543+
let batch = build_compressible_batch();
544+
let num_rows = batch.num_rows();
545+
let schema = batch.schema();
546+
let batch_count = 1;
547+
let batches = [batch];
548+
549+
// Construct SpillManager
550+
let env = Arc::new(RuntimeEnv::default());
551+
let uncompressed_metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
552+
let lz4_metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
553+
let zstd_metrics = SpillMetrics::new(&ExecutionPlanMetricsSet::new(), 0);
554+
let uncompressed_spill_manager = SpillManager::new(
555+
env.clone(),
556+
uncompressed_metrics,
557+
Arc::clone(&schema),
558+
SpillCompression::Uncompressed,
559+
);
560+
let lz4_spill_manager = SpillManager::new(
561+
env.clone(),
562+
lz4_metrics,
563+
Arc::clone(&schema),
564+
SpillCompression::Lz4Frame,
565+
);
566+
let zstd_spill_manager = SpillManager::new(
567+
env,
568+
zstd_metrics,
569+
Arc::clone(&schema),
570+
SpillCompression::Zstd,
571+
);
572+
let uncompressed_spill_file = uncompressed_spill_manager
573+
.spill_record_batch_and_finish(&batches, "Test")?
574+
.unwrap();
575+
let lz4_spill_file = lz4_spill_manager
576+
.spill_record_batch_and_finish(&batches, "Lz4_Test")?
577+
.unwrap();
578+
let zstd_spill_file = zstd_spill_manager
579+
.spill_record_batch_and_finish(&batches, "ZSTD_Test")?
580+
.unwrap();
581+
assert!(uncompressed_spill_file.path().exists());
582+
assert!(lz4_spill_file.path().exists());
583+
assert!(zstd_spill_file.path().exists());
584+
585+
let lz4_spill_size = std::fs::metadata(lz4_spill_file.path())?.len();
586+
let zstd_spill_size = std::fs::metadata(zstd_spill_file.path())?.len();
587+
let uncompressed_spill_size =
588+
std::fs::metadata(uncompressed_spill_file.path())?.len();
589+
590+
assert!(uncompressed_spill_size > lz4_spill_size);
591+
assert!(uncompressed_spill_size > zstd_spill_size);
592+
593+
// TODO validate with function
594+
validate(
595+
&lz4_spill_manager,
596+
lz4_spill_file,
597+
num_rows,
598+
schema.clone(),
599+
batch_count,
600+
)
601+
.await?;
602+
validate(
603+
&zstd_spill_manager,
604+
zstd_spill_file,
605+
num_rows,
606+
schema.clone(),
607+
batch_count,
608+
)
609+
.await?;
610+
validate(
611+
&uncompressed_spill_manager,
612+
uncompressed_spill_file,
613+
num_rows,
614+
schema,
615+
batch_count,
616+
)
617+
.await?;
618+
Ok(())
619+
}
620+
505621
#[test]
506622
fn test_get_record_batch_memory_size() {
507623
// Create a simple record batch with two columns

0 commit comments

Comments
 (0)