@@ -25,6 +25,7 @@ use std::sync::Arc;
2525
2626use  crate :: array:: * ; 
2727use  crate :: datatypes:: * ; 
28+ use  crate :: error:: { ArrowError ,  Result } ; 
2829
2930/// A batch of column-oriented data 
3031#[ derive( Clone ) ]  
@@ -34,36 +35,61 @@ pub struct RecordBatch {
3435} 
3536
3637impl  RecordBatch  { 
37-     pub  fn  new ( schema :  Arc < Schema > ,  columns :  Vec < ArrayRef > )  -> Self  { 
38-         // assert that there are some columns 
39-         assert ! ( 
40-             columns. len( )  > 0 , 
41-             "at least one column must be defined to create a record batch" 
42-         ) ; 
43-         // assert that all columns have the same row count 
38+     /// Creates a `RecordBatch` from a schema and columns 
39+      /// 
40+      /// Expects the following: 
41+      ///  * the vec of columns to not be empty 
42+      ///  * the schema and column data types to have equal lengths and match 
43+      ///  * each array in columns to have the same length 
44+      pub  fn  try_new ( schema :  Arc < Schema > ,  columns :  Vec < ArrayRef > )  -> Result < Self >  { 
45+         // check that there are some columns 
46+         if  columns. is_empty ( )  { 
47+             return  Err ( ArrowError :: InvalidArgumentError ( 
48+                 "at least one column must be defined to create a record batch" 
49+                     . to_string ( ) , 
50+             ) ) ; 
51+         } 
52+         // check that number of fields in schema match column length 
53+         if  schema. fields ( ) . len ( )  != columns. len ( )  { 
54+             return  Err ( ArrowError :: InvalidArgumentError ( 
55+                 "number of columns must match number of fields in schema" . to_string ( ) , 
56+             ) ) ; 
57+         } 
58+         // check that all columns have the same row count, and match the schema 
4459        let  len = columns[ 0 ] . data ( ) . len ( ) ; 
45-         for  i in  1 ..columns. len ( )  { 
46-             assert_eq ! ( 
47-                 len, 
48-                 columns[ i] . len( ) , 
49-                 "all columns in a record batch must have the same length" 
50-             ) ; 
60+         for  i in  0 ..columns. len ( )  { 
61+             if  columns[ i] . len ( )  != len { 
62+                 return  Err ( ArrowError :: InvalidArgumentError ( 
63+                     "all columns in a record batch must have the same length" . to_string ( ) , 
64+                 ) ) ; 
65+             } 
66+             if  columns[ i] . data_type ( )  != schema. field ( i) . data_type ( )  { 
67+                 return  Err ( ArrowError :: InvalidArgumentError ( format ! ( 
68+                     "column types must match schema types, expected {:?} but found {:?} at column index {}" ,  
69+                     schema. field( i) . data_type( ) , 
70+                     columns[ i] . data_type( ) , 
71+                     i) ) ) ; 
72+             } 
5173        } 
52-         RecordBatch  {  schema,  columns } 
74+         Ok ( RecordBatch  {  schema,  columns } ) 
5375    } 
5476
77+     /// Returns the schema of the record batch 
5578     pub  fn  schema ( & self )  -> & Arc < Schema >  { 
5679        & self . schema 
5780    } 
5881
82+     /// Number of columns in the record batch 
5983     pub  fn  num_columns ( & self )  -> usize  { 
6084        self . columns . len ( ) 
6185    } 
6286
87+     /// Number of rows in each column 
6388     pub  fn  num_rows ( & self )  -> usize  { 
6489        self . columns [ 0 ] . data ( ) . len ( ) 
6590    } 
6691
92+     /// Get a reference to a column's array by index 
6793     pub  fn  column ( & self ,  i :  usize )  -> & ArrayRef  { 
6894        & self . columns [ i] 
6995    } 
@@ -103,7 +129,8 @@ mod tests {
103129        let  b = BinaryArray :: from ( array_data) ; 
104130
105131        let  record_batch =
106-             RecordBatch :: new ( Arc :: new ( schema) ,  vec ! [ Arc :: new( a) ,  Arc :: new( b) ] ) ; 
132+             RecordBatch :: try_new ( Arc :: new ( schema) ,  vec ! [ Arc :: new( a) ,  Arc :: new( b) ] ) 
133+                 . unwrap ( ) ; 
107134
108135        assert_eq ! ( 5 ,  record_batch. num_rows( ) ) ; 
109136        assert_eq ! ( 2 ,  record_batch. num_columns( ) ) ; 
@@ -112,4 +139,26 @@ mod tests {
112139        assert_eq ! ( 5 ,  record_batch. column( 0 ) . data( ) . len( ) ) ; 
113140        assert_eq ! ( 5 ,  record_batch. column( 1 ) . data( ) . len( ) ) ; 
114141    } 
142+ 
143+     #[ test]  
144+     fn  create_record_batch_schema_mismatch ( )  { 
145+         let  schema = Schema :: new ( vec ! [ Field :: new( "a" ,  DataType :: Int32 ,  false ) ] ) ; 
146+ 
147+         let  a = Int64Array :: from ( vec ! [ 1 ,  2 ,  3 ,  4 ,  5 ] ) ; 
148+ 
149+         let  batch = RecordBatch :: try_new ( Arc :: new ( schema) ,  vec ! [ Arc :: new( a) ] ) ; 
150+         assert ! ( !batch. is_ok( ) ) ; 
151+     } 
152+ 
153+     #[ test]  
154+     fn  create_record_batch_record_mismatch ( )  { 
155+         let  schema = Schema :: new ( vec ! [ Field :: new( "a" ,  DataType :: Int32 ,  false ) ] ) ; 
156+ 
157+         let  a = Int32Array :: from ( vec ! [ 1 ,  2 ,  3 ,  4 ,  5 ] ) ; 
158+         let  b = Int32Array :: from ( vec ! [ 1 ,  2 ,  3 ,  4 ,  5 ] ) ; 
159+ 
160+         let  batch =
161+             RecordBatch :: try_new ( Arc :: new ( schema) ,  vec ! [ Arc :: new( a) ,  Arc :: new( b) ] ) ; 
162+         assert ! ( !batch. is_ok( ) ) ; 
163+     } 
115164} 
0 commit comments