@@ -31,8 +31,8 @@ use std::task::{Context, Poll};
3131use arrow:: array:: ArrayData ;
3232use arrow:: datatypes:: { Schema , SchemaRef } ;
3333use arrow:: ipc:: writer:: IpcWriteOptions ;
34+ use arrow:: ipc:: MetadataVersion ;
3435use arrow:: ipc:: { reader:: StreamReader , writer:: StreamWriter } ;
35- use arrow:: ipc:: { CompressionType , MetadataVersion } ;
3636use arrow:: record_batch:: RecordBatch ;
3737
3838use datafusion_common:: config:: SpillCompression ;
@@ -349,6 +349,7 @@ mod tests {
349349 use crate :: metrics:: SpillMetrics ;
350350 use crate :: spill:: spill_manager:: SpillManager ;
351351 use crate :: test:: build_table_i32;
352+ use arrow:: array:: ArrayRef ;
352353 use arrow:: array:: { Float64Array , Int32Array , ListArray , StringArray } ;
353354 use arrow:: compute:: cast;
354355 use arrow:: datatypes:: { DataType , Field , Int32Type , Schema } ;
@@ -502,6 +503,121 @@ mod tests {
502503 Ok ( ( ) )
503504 }
504505
506+ fn build_compressible_batch ( ) -> RecordBatch {
507+ let schema = Arc :: new ( Schema :: new ( vec ! [
508+ Field :: new( "a" , DataType :: Utf8 , false ) ,
509+ Field :: new( "b" , DataType :: Int32 , false ) ,
510+ Field :: new( "c" , DataType :: Int32 , true ) ,
511+ ] ) ) ;
512+
513+ let a: ArrayRef = Arc :: new ( StringArray :: from_iter_values (
514+ std:: iter:: repeat ( "repeated" ) . take ( 100 ) ,
515+ ) ) ;
516+ let b: ArrayRef = Arc :: new ( Int32Array :: from ( vec ! [ 1 ; 100 ] ) ) ;
517+ let c: ArrayRef = Arc :: new ( Int32Array :: from ( vec ! [ 2 ; 100 ] ) ) ;
518+
519+ RecordBatch :: try_new ( schema, vec ! [ a, b, c] ) . unwrap ( )
520+ }
521+
522+ async fn validate (
523+ spill_manager : & SpillManager ,
524+ spill_file : RefCountedTempFile ,
525+ num_rows : usize ,
526+ schema : SchemaRef ,
527+ batch_count : usize ,
528+ ) -> Result < ( ) > {
529+ let spilled_rows = spill_manager. metrics . spilled_rows . value ( ) ;
530+ assert_eq ! ( spilled_rows, num_rows) ;
531+
532+ let stream = spill_manager. read_spill_as_stream ( spill_file) ?;
533+ assert_eq ! ( stream. schema( ) , schema) ;
534+
535+ let batches = collect ( stream) . await ?;
536+ assert_eq ! ( batches. len( ) , batch_count) ;
537+
538+ Ok ( ( ) )
539+ }
540+
541+ #[ tokio:: test]
542+ async fn test_spill_compression ( ) -> Result < ( ) > {
543+ let batch = build_compressible_batch ( ) ;
544+ let num_rows = batch. num_rows ( ) ;
545+ let schema = batch. schema ( ) ;
546+ let batch_count = 1 ;
547+ let batches = [ batch] ;
548+
549+ // Construct SpillManager
550+ let env = Arc :: new ( RuntimeEnv :: default ( ) ) ;
551+ let uncompressed_metrics = SpillMetrics :: new ( & ExecutionPlanMetricsSet :: new ( ) , 0 ) ;
552+ let lz4_metrics = SpillMetrics :: new ( & ExecutionPlanMetricsSet :: new ( ) , 0 ) ;
553+ let zstd_metrics = SpillMetrics :: new ( & ExecutionPlanMetricsSet :: new ( ) , 0 ) ;
554+ let uncompressed_spill_manager = SpillManager :: new (
555+ env. clone ( ) ,
556+ uncompressed_metrics,
557+ Arc :: clone ( & schema) ,
558+ SpillCompression :: Uncompressed ,
559+ ) ;
560+ let lz4_spill_manager = SpillManager :: new (
561+ env. clone ( ) ,
562+ lz4_metrics,
563+ Arc :: clone ( & schema) ,
564+ SpillCompression :: Lz4Frame ,
565+ ) ;
566+ let zstd_spill_manager = SpillManager :: new (
567+ env,
568+ zstd_metrics,
569+ Arc :: clone ( & schema) ,
570+ SpillCompression :: Zstd ,
571+ ) ;
572+ let uncompressed_spill_file = uncompressed_spill_manager
573+ . spill_record_batch_and_finish ( & batches, "Test" ) ?
574+ . unwrap ( ) ;
575+ let lz4_spill_file = lz4_spill_manager
576+ . spill_record_batch_and_finish ( & batches, "Lz4_Test" ) ?
577+ . unwrap ( ) ;
578+ let zstd_spill_file = zstd_spill_manager
579+ . spill_record_batch_and_finish ( & batches, "ZSTD_Test" ) ?
580+ . unwrap ( ) ;
581+ assert ! ( uncompressed_spill_file. path( ) . exists( ) ) ;
582+ assert ! ( lz4_spill_file. path( ) . exists( ) ) ;
583+ assert ! ( zstd_spill_file. path( ) . exists( ) ) ;
584+
585+ let lz4_spill_size = std:: fs:: metadata ( lz4_spill_file. path ( ) ) ?. len ( ) ;
586+ let zstd_spill_size = std:: fs:: metadata ( zstd_spill_file. path ( ) ) ?. len ( ) ;
587+ let uncompressed_spill_size =
588+ std:: fs:: metadata ( uncompressed_spill_file. path ( ) ) ?. len ( ) ;
589+
590+ assert ! ( uncompressed_spill_size > lz4_spill_size) ;
591+ assert ! ( uncompressed_spill_size > zstd_spill_size) ;
592+
593+ // TODO validate with function
594+ validate (
595+ & lz4_spill_manager,
596+ lz4_spill_file,
597+ num_rows,
598+ schema. clone ( ) ,
599+ batch_count,
600+ )
601+ . await ?;
602+ validate (
603+ & zstd_spill_manager,
604+ zstd_spill_file,
605+ num_rows,
606+ schema. clone ( ) ,
607+ batch_count,
608+ )
609+ . await ?;
610+ validate (
611+ & uncompressed_spill_manager,
612+ uncompressed_spill_file,
613+ num_rows,
614+ schema,
615+ batch_count,
616+ )
617+ . await ?;
618+ Ok ( ( ) )
619+ }
620+
505621 #[ test]
506622 fn test_get_record_batch_memory_size ( ) {
507623 // Create a simple record batch with two columns
0 commit comments