Skip to content

Commit faeb309

Browse files
nevi-mesunchao
authored and committed
ARROW-4749: [Rust] Return Result for RecordBatch::new()
Adds more validation between schemas and columns, returning an error when record types mismatch the schema Author: Neville Dipale <[email protected]> Author: Andy Grove <[email protected]> Closes #3800 from nevi-me/ARROW-4749 and squashes the following commits: 586395d <Neville Dipale> RecordBatch::try -> RecordBatch::try_new fbdde0f <Neville Dipale> fix csv writer tests d92aeb6 <Andy Grove> fix aggr schema 5dd6bd0 <Neville Dipale> rebase against master, update record batch in in-memory source 0c402f5 <Neville Dipale> ARROW-4749: Return Result for RecordBatch::new()
1 parent dd3ba0c commit faeb309

File tree

13 files changed

+139
-57
lines changed

13 files changed

+139
-57
lines changed

rust/arrow/benches/csv_writer.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ fn record_batches_to_csv() {
4848
let c3 = PrimitiveArray::<UInt32Type>::from(vec![3, 2, 1]);
4949
let c4 = PrimitiveArray::<BooleanType>::from(vec![Some(true), Some(false), None]);
5050

51-
let b = RecordBatch::new(
51+
let b = RecordBatch::try_new(
5252
Arc::new(schema),
5353
vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)],
5454
);

rust/arrow/examples/dynamic_types.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@ extern crate arrow;
2222

2323
use arrow::array::*;
2424
use arrow::datatypes::*;
25+
use arrow::error::Result;
2526
use arrow::record_batch::*;
2627

27-
fn main() {
28+
fn main() -> Result<()> {
2829
// define schema
2930
let schema = Schema::new(vec![
3031
Field::new("id", DataType::Int32, false),
@@ -58,9 +59,10 @@ fn main() {
5859
]);
5960

6061
// build a record batch
61-
let batch = RecordBatch::new(Arc::new(schema), vec![Arc::new(id), Arc::new(nested)]);
62+
let batch =
63+
RecordBatch::try_new(Arc::new(schema), vec![Arc::new(id), Arc::new(nested)])?;
6264

63-
process(&batch);
65+
Ok(process(&batch))
6466
}
6567

6668
/// Create a new batch by performing a projection of id, nested.c
@@ -88,7 +90,7 @@ fn process(batch: &RecordBatch) {
8890
Field::new("sum", DataType::Float64, false),
8991
]);
9092

91-
let _ = RecordBatch::new(
93+
let _ = RecordBatch::try_new(
9294
Arc::new(projected_schema),
9395
vec![
9496
id.clone(), // NOTE: this is cloning the Arc not the array data

rust/arrow/src/csv/reader.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,10 @@ impl<R: Read> Reader<R> {
329329
let projected_schema = Arc::new(Schema::new(projected_fields));
330330

331331
match arrays {
332-
Ok(arr) => Ok(Some(RecordBatch::new(projected_schema, arr))),
332+
Ok(arr) => match RecordBatch::try_new(projected_schema, arr) {
333+
Ok(batch) => Ok(Some(batch)),
334+
Err(e) => Err(e),
335+
},
333336
Err(e) => Err(e),
334337
}
335338
}

rust/arrow/src/csv/writer.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@
5050
//! let c3 = PrimitiveArray::<UInt32Type>::from(vec![3, 2, 1]);
5151
//! let c4 = PrimitiveArray::<BooleanType>::from(vec![Some(true), Some(false), None]);
5252
//!
53-
//! let batch = RecordBatch::new(
53+
//! let batch = RecordBatch::try_new(
5454
//! Arc::new(schema),
5555
//! vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)],
56-
//! );
56+
//! ).unwrap();
5757
//!
5858
//! let file = get_temp_file("out.csv", &[]);
5959
//!
@@ -287,10 +287,11 @@ mod tests {
287287
let c3 = PrimitiveArray::<UInt32Type>::from(vec![3, 2, 1]);
288288
let c4 = PrimitiveArray::<BooleanType>::from(vec![Some(true), Some(false), None]);
289289

290-
let batch = RecordBatch::new(
290+
let batch = RecordBatch::try_new(
291291
Arc::new(schema),
292292
vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)],
293-
);
293+
)
294+
.unwrap();
294295

295296
let file = get_temp_file("columns.csv", &[]);
296297

@@ -331,10 +332,11 @@ mod tests {
331332
let c3 = PrimitiveArray::<UInt32Type>::from(vec![3, 2, 1]);
332333
let c4 = PrimitiveArray::<BooleanType>::from(vec![Some(true), Some(false), None]);
333334

334-
let batch = RecordBatch::new(
335+
let batch = RecordBatch::try_new(
335336
Arc::new(schema),
336337
vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)],
337-
);
338+
)
339+
.unwrap();
338340

339341
let file = get_temp_file("custom_options.csv", &[]);
340342

rust/arrow/src/error.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ pub enum ArrowError {
3030
CsvError(String),
3131
JsonError(String),
3232
IoError(String),
33+
InvalidArgumentError(String),
3334
}
3435

3536
impl From<::std::io::Error> for ArrowError {

rust/arrow/src/json/reader.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,10 @@ impl<R: Read> Reader<R> {
487487
let projected_schema = Arc::new(Schema::new(projected_fields));
488488

489489
match arrays {
490-
Ok(arr) => Ok(Some(RecordBatch::new(projected_schema, arr))),
490+
Ok(arr) => match RecordBatch::try_new(projected_schema, arr) {
491+
Ok(batch) => Ok(Some(batch)),
492+
Err(e) => Err(e),
493+
},
491494
Err(e) => Err(e),
492495
}
493496
}

rust/arrow/src/record_batch.rs

Lines changed: 64 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ use std::sync::Arc;
2525

2626
use crate::array::*;
2727
use crate::datatypes::*;
28+
use crate::error::{ArrowError, Result};
2829

2930
/// A batch of column-oriented data
3031
#[derive(Clone)]
@@ -34,36 +35,61 @@ pub struct RecordBatch {
3435
}
3536

3637
impl RecordBatch {
37-
pub fn new(schema: Arc<Schema>, columns: Vec<ArrayRef>) -> Self {
38-
// assert that there are some columns
39-
assert!(
40-
columns.len() > 0,
41-
"at least one column must be defined to create a record batch"
42-
);
43-
// assert that all columns have the same row count
38+
/// Creates a `RecordBatch` from a schema and columns
39+
///
40+
/// Expects the following:
41+
/// * the vec of columns to not be empty
42+
/// * the schema and column data types to have equal lengths and match
43+
/// * each array in columns to have the same length
44+
pub fn try_new(schema: Arc<Schema>, columns: Vec<ArrayRef>) -> Result<Self> {
45+
// check that there are some columns
46+
if columns.is_empty() {
47+
return Err(ArrowError::InvalidArgumentError(
48+
"at least one column must be defined to create a record batch"
49+
.to_string(),
50+
));
51+
}
52+
// check that number of fields in schema match column length
53+
if schema.fields().len() != columns.len() {
54+
return Err(ArrowError::InvalidArgumentError(
55+
"number of columns must match number of fields in schema".to_string(),
56+
));
57+
}
58+
// check that all columns have the same row count, and match the schema
4459
let len = columns[0].data().len();
45-
for i in 1..columns.len() {
46-
assert_eq!(
47-
len,
48-
columns[i].len(),
49-
"all columns in a record batch must have the same length"
50-
);
60+
for i in 0..columns.len() {
61+
if columns[i].len() != len {
62+
return Err(ArrowError::InvalidArgumentError(
63+
"all columns in a record batch must have the same length".to_string(),
64+
));
65+
}
66+
if columns[i].data_type() != schema.field(i).data_type() {
67+
return Err(ArrowError::InvalidArgumentError(format!(
68+
"column types must match schema types, expected {:?} but found {:?} at column index {}",
69+
schema.field(i).data_type(),
70+
columns[i].data_type(),
71+
i)));
72+
}
5173
}
52-
RecordBatch { schema, columns }
74+
Ok(RecordBatch { schema, columns })
5375
}
5476

77+
/// Returns the schema of the record batch
5578
pub fn schema(&self) -> &Arc<Schema> {
5679
&self.schema
5780
}
5881

82+
/// Number of columns in the record batch
5983
pub fn num_columns(&self) -> usize {
6084
self.columns.len()
6185
}
6286

87+
/// Number of rows in each column
6388
pub fn num_rows(&self) -> usize {
6489
self.columns[0].data().len()
6590
}
6691

92+
/// Get a reference to a column's array by index
6793
pub fn column(&self, i: usize) -> &ArrayRef {
6894
&self.columns[i]
6995
}
@@ -103,7 +129,8 @@ mod tests {
103129
let b = BinaryArray::from(array_data);
104130

105131
let record_batch =
106-
RecordBatch::new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]);
132+
RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)])
133+
.unwrap();
107134

108135
assert_eq!(5, record_batch.num_rows());
109136
assert_eq!(2, record_batch.num_columns());
@@ -112,4 +139,26 @@ mod tests {
112139
assert_eq!(5, record_batch.column(0).data().len());
113140
assert_eq!(5, record_batch.column(1).data().len());
114141
}
142+
143+
#[test]
144+
fn create_record_batch_schema_mismatch() {
145+
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
146+
147+
let a = Int64Array::from(vec![1, 2, 3, 4, 5]);
148+
149+
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]);
150+
assert!(!batch.is_ok());
151+
}
152+
153+
#[test]
154+
fn create_record_batch_record_mismatch() {
155+
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
156+
157+
let a = Int32Array::from(vec![1, 2, 3, 4, 5]);
158+
let b = Int32Array::from(vec![1, 2, 3, 4, 5]);
159+
160+
let batch =
161+
RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a), Arc::new(b)]);
162+
assert!(!batch.is_ok());
163+
}
115164
}

rust/datafusion/src/datasource/memory.rs

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -102,20 +102,25 @@ impl Table for MemTable {
102102

103103
let projected_schema = Arc::new(Schema::new(projected_columns?));
104104

105-
Ok(Rc::new(RefCell::new(MemBatchIterator {
106-
schema: projected_schema.clone(),
107-
index: 0,
108-
batches: self
109-
.batches
110-
.iter()
111-
.map(|batch| {
112-
RecordBatch::new(
113-
projected_schema.clone(),
114-
columns.iter().map(|i| batch.column(*i).clone()).collect(),
115-
)
116-
})
117-
.collect(),
118-
})))
105+
let batches = self
106+
.batches
107+
.iter()
108+
.map(|batch| {
109+
RecordBatch::try_new(
110+
projected_schema.clone(),
111+
columns.iter().map(|i| batch.column(*i).clone()).collect(),
112+
)
113+
})
114+
.collect();
115+
116+
match batches {
117+
Ok(batches) => Ok(Rc::new(RefCell::new(MemBatchIterator {
118+
schema: projected_schema.clone(),
119+
index: 0,
120+
batches,
121+
}))),
122+
Err(e) => Err(ExecutionError::ArrowError(e)),
123+
}
119124
}
120125
}
121126

@@ -155,14 +160,15 @@ mod tests {
155160
Field::new("c", DataType::Int32, false),
156161
]));
157162

158-
let batch = RecordBatch::new(
163+
let batch = RecordBatch::try_new(
159164
schema.clone(),
160165
vec![
161166
Arc::new(Int32Array::from(vec![1, 2, 3])),
162167
Arc::new(Int32Array::from(vec![4, 5, 6])),
163168
Arc::new(Int32Array::from(vec![7, 8, 9])),
164169
],
165-
);
170+
)
171+
.unwrap();
166172

167173
let provider = MemTable::new(schema, vec![batch]).unwrap();
168174

@@ -183,14 +189,15 @@ mod tests {
183189
Field::new("c", DataType::Int32, false),
184190
]));
185191

186-
let batch = RecordBatch::new(
192+
let batch = RecordBatch::try_new(
187193
schema.clone(),
188194
vec![
189195
Arc::new(Int32Array::from(vec![1, 2, 3])),
190196
Arc::new(Int32Array::from(vec![4, 5, 6])),
191197
Arc::new(Int32Array::from(vec![7, 8, 9])),
192198
],
193-
);
199+
)
200+
.unwrap();
194201

195202
let provider = MemTable::new(schema, vec![batch]).unwrap();
196203

@@ -208,14 +215,15 @@ mod tests {
208215
Field::new("c", DataType::Int32, false),
209216
]));
210217

211-
let batch = RecordBatch::new(
218+
let batch = RecordBatch::try_new(
212219
schema.clone(),
213220
vec![
214221
Arc::new(Int32Array::from(vec![1, 2, 3])),
215222
Arc::new(Int32Array::from(vec![4, 5, 6])),
216223
Arc::new(Int32Array::from(vec![7, 8, 9])),
217224
],
218-
);
225+
)
226+
.unwrap();
219227

220228
let provider = MemTable::new(schema, vec![batch]).unwrap();
221229

@@ -243,14 +251,15 @@ mod tests {
243251
Field::new("c", DataType::Int32, false),
244252
]));
245253

246-
let batch = RecordBatch::new(
254+
let batch = RecordBatch::try_new(
247255
schema1.clone(),
248256
vec![
249257
Arc::new(Int32Array::from(vec![1, 2, 3])),
250258
Arc::new(Int32Array::from(vec![4, 5, 6])),
251259
Arc::new(Int32Array::from(vec![7, 8, 9])),
252260
],
253-
);
261+
)
262+
.unwrap();
254263

255264
match MemTable::new(schema2, vec![batch]) {
256265
Err(ExecutionError::General(e)) => assert_eq!(

rust/datafusion/src/execution/aggregate.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -800,7 +800,10 @@ impl AggregateRelation {
800800
}
801801
}
802802

803-
Ok(Some(RecordBatch::new(self.schema.clone(), result_columns)))
803+
Ok(Some(RecordBatch::try_new(
804+
self.schema.clone(),
805+
result_columns,
806+
)?))
804807
}
805808

806809
fn with_group_by(&mut self) -> Result<Option<RecordBatch>> {
@@ -1008,7 +1011,10 @@ impl AggregateRelation {
10081011
result_arrays.push(array?);
10091012
}
10101013

1011-
Ok(Some(RecordBatch::new(self.schema.clone(), result_arrays)))
1014+
Ok(Some(RecordBatch::try_new(
1015+
self.schema.clone(),
1016+
result_arrays,
1017+
)?))
10121018
}
10131019
}
10141020

@@ -1136,7 +1142,7 @@ mod tests {
11361142
.unwrap();
11371143

11381144
let aggr_schema = Arc::new(Schema::new(vec![
1139-
Field::new("c2", DataType::Int32, false),
1145+
Field::new("c2", DataType::UInt32, false),
11401146
Field::new("min", DataType::Float64, false),
11411147
Field::new("max", DataType::Float64, false),
11421148
Field::new("sum", DataType::Float64, false),

rust/datafusion/src/execution/context.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,15 @@ impl ExecutionContext {
187187
.collect();
188188
let compiled_aggr_expr = compiled_aggr_expr_result?;
189189

190+
let mut output_fields: Vec<Field> = vec![];
191+
for expr in group_expr {
192+
output_fields.push(expr_to_field(expr, input_schema.as_ref()));
193+
}
194+
for expr in aggr_expr {
195+
output_fields.push(expr_to_field(expr, input_schema.as_ref()));
196+
}
190197
let rel = AggregateRelation::new(
191-
Arc::new(Schema::empty()), //(expr_to_field(&compiled_group_expr, &input_schema))),
198+
Arc::new(Schema::new(output_fields)),
192199
input_rel,
193200
compiled_group_expr,
194201
compiled_aggr_expr,

0 commit comments

Comments (0)