diff --git a/arrow/src/csv/writer.rs b/arrow/src/csv/writer.rs index e9d8565b2a5b..f2f4ce813c65 100644 --- a/arrow/src/csv/writer.rs +++ b/arrow/src/csv/writer.rs @@ -128,13 +128,13 @@ impl Writer { /// Convert a record to a string vector fn convert( &self, - batch: &RecordBatch, + batch: &[ArrayRef], row_index: usize, buffer: &mut [String], ) -> Result<()> { // TODO: it'd be more efficient if we could create `record: Vec<&[u8]> for (col_index, item) in buffer.iter_mut().enumerate() { - let col = batch.column(col_index); + let col = &batch[col_index]; if col.is_null(row_index) { // write an empty value *item = "".to_string(); @@ -274,10 +274,22 @@ impl Writer { self.beginning = false; } + let columns: Vec<_> = batch + .columns() + .iter() + .map(|array| match array.data_type() { + DataType::Dictionary(_, value_type) => { + crate::compute::kernels::cast::cast(array, &value_type) + .expect("cannot cast dictionary to underlying values") + } + _ => array.clone(), + }) + .collect(); + let mut buffer = vec!["".to_string(); batch.num_columns()]; for row_index in 0..batch.num_rows() { - self.convert(batch, row_index, &mut buffer)?; + self.convert(columns.as_slice(), row_index, &mut buffer)?; self.writer.write_record(&buffer)?; } self.writer.flush()?; @@ -420,6 +432,11 @@ mod tests { Field::new("c4", DataType::Boolean, true), Field::new("c5", DataType::Timestamp(TimeUnit::Millisecond, None), true), Field::new("c6", DataType::Time32(TimeUnit::Second), false), + Field::new( + "c7", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + false, + ), ]); let c1 = StringArray::from(vec![ @@ -439,6 +456,8 @@ mod tests { None, ); let c6 = Time32SecondArray::from(vec![1234, 24680, 85563]); + let c7: DictionaryArray = + vec!["cupcakes", "cupcakes", "foo"].into_iter().collect(); let batch = RecordBatch::try_new( Arc::new(schema), @@ -449,6 +468,7 @@ mod tests { Arc::new(c4), Arc::new(c5), Arc::new(c6), + Arc::new(c7), ], ) .unwrap(); @@ -466,13 +486,13 @@ mod tests { file.read_to_end(&mut buffer).unwrap(); assert_eq!( - r#"c1,c2,c3,c4,c5,c6 -Lorem ipsum dolor sit amet,123.564532,3,true,,00:20:34 -consectetur adipiscing elit,,2,false,2019-04-18T10:54:47.378000000,06:51:20 -sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03 -Lorem ipsum dolor sit amet,123.564532,3,true,,00:20:34 -consectetur adipiscing elit,,2,false,2019-04-18T10:54:47.378000000,06:51:20 -sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03 + r#"c1,c2,c3,c4,c5,c6,c7 +Lorem ipsum dolor sit amet,123.564532,3,true,,00:20:34,cupcakes +consectetur adipiscing elit,,2,false,2019-04-18T10:54:47.378000000,06:51:20,cupcakes +sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03,foo +Lorem ipsum dolor sit amet,123.564532,3,true,,00:20:34,cupcakes +consectetur adipiscing elit,,2,false,2019-04-18T10:54:47.378000000,06:51:20,cupcakes +sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555000000,23:46:03,foo "# .to_string(), String::from_utf8(buffer).unwrap() diff --git a/arrow/src/json/writer.rs b/arrow/src/json/writer.rs index 27c1ff138aa1..8587c1d4ac6a 100644 --- a/arrow/src/json/writer.rs +++ b/arrow/src/json/writer.rs @@ -480,6 +480,12 @@ fn set_column_for_json_rows( } }); } + DataType::Dictionary(_, value_type) => { + let slice = array.slice(0, row_count); + let hydrated = crate::compute::kernels::cast::cast(&slice, &value_type) + .expect("cannot cast dictionary to underlying values"); + set_column_for_json_rows(rows, row_count, &hydrated, col_name) + } _ => { panic!("Unsupported datatype: {:#?}", array.data_type()); } @@ -681,8 +687,8 @@ mod tests { #[test] fn write_simple_rows() { let schema = Schema::new(vec![ - Field::new("c1", DataType::Int32, false), - Field::new("c2", DataType::Utf8, false), + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Utf8, true), ]); let a = Int32Array::from(vec![Some(1), Some(2), Some(3), None, Some(5)]); @@ -709,6 +715,56 @@ mod tests { ); } + #[test] + fn write_dictionary() { + let schema = Schema::new(vec![ + Field::new( + "c1", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + true, + ), + Field::new( + "c2", + DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)), + true, + ), + ]); + + let a: DictionaryArray = vec![ + Some("cupcakes"), + Some("foo"), + Some("foo"), + None, + Some("cupcakes"), + ] + .into_iter() + .collect(); + let b: DictionaryArray = + vec![Some("sdsd"), Some("sdsd"), None, Some("sd"), Some("sdsd")] + .into_iter() + .collect(); + + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]) + .unwrap(); + + let mut buf = Vec::new(); + { + let mut writer = LineDelimitedWriter::new(&mut buf); + writer.write_batches(&[batch]).unwrap(); + } + + assert_eq!( + String::from_utf8(buf).unwrap(), + r#"{"c1":"cupcakes","c2":"sdsd"} +{"c1":"foo","c2":"sdsd"} +{"c1":"foo"} +{"c2":"sd"} +{"c1":"cupcakes","c2":"sdsd"} +"# + ); + } + #[test] fn write_timestamps() { let ts_string = "2018-11-13T17:11:10.011375885995";