Skip to content
126 changes: 109 additions & 17 deletions arrow-avro/src/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ impl<'a> TryFrom<&Schema<'a>> for AvroField {
match schema {
Schema::Complex(ComplexType::Record(r)) => {
let mut resolver = Resolver::default();
let data_type = make_data_type(schema, None, &mut resolver, false)?;
let data_type = make_data_type(schema, None, &mut resolver, false, false)?;
Ok(AvroField {
data_type,
name: r.name.to_string(),
Expand All @@ -161,6 +161,60 @@ impl<'a> TryFrom<&Schema<'a>> for AvroField {
}
}

/// Builder for an [`AvroField`]
///
/// Holds a borrowed Avro [`Schema`] together with the conversion options that
/// are handed to `make_data_type` when the field is built.
#[derive(Debug)]
pub struct AvroFieldBuilder<'a> {
    /// The Avro schema the field is built from; `build` only accepts a record schema.
    schema: &'a Schema<'a>,
    /// Forwarded to `make_data_type` as its `use_utf8view` argument.
    use_utf8view: bool,
    /// Forwarded to `make_data_type` as its `strict_mode` argument.
    strict_mode: bool,
}

impl<'a> AvroFieldBuilder<'a> {
    /// Creates a new [`AvroFieldBuilder`] for the given schema.
    ///
    /// Utf8View support and strict mode both start out disabled.
    pub fn new(schema: &'a Schema<'a>) -> Self {
        Self {
            strict_mode: false,
            use_utf8view: false,
            schema,
        }
    }

    /// Enable or disable Utf8View support.
    ///
    /// Consumes the builder and returns it with the flag updated.
    pub fn with_utf8view(self, use_utf8view: bool) -> Self {
        Self {
            use_utf8view,
            ..self
        }
    }

    /// Enable or disable strict mode.
    ///
    /// Consumes the builder and returns it with the flag updated.
    pub fn with_strict_mode(self, strict_mode: bool) -> Self {
        Self {
            strict_mode,
            ..self
        }
    }

    /// Build an [`AvroField`] from the builder.
    ///
    /// # Errors
    ///
    /// Returns [`ArrowError::ParseError`] when the schema is not a record, and
    /// propagates any error produced while converting the schema's data type.
    pub fn build(self) -> Result<AvroField, ArrowError> {
        let Self {
            schema,
            use_utf8view,
            strict_mode,
        } = self;

        // Only a top-level record can be turned into an `AvroField`.
        let record = match schema {
            Schema::Complex(ComplexType::Record(record)) => record,
            other => {
                return Err(ArrowError::ParseError(format!(
                    "Expected a Record schema to build an AvroField, but got {:?}",
                    other
                )))
            }
        };

        let mut resolver = Resolver::default();
        let data_type = make_data_type(schema, None, &mut resolver, use_utf8view, strict_mode)?;
        let name = record.name.to_string();
        Ok(AvroField { name, data_type })
    }
}
/// An Avro encoding
///
/// <https://avro.apache.org/docs/1.11.1/specification/#encodings>
Expand Down Expand Up @@ -409,6 +463,7 @@ fn make_data_type<'a>(
namespace: Option<&'a str>,
resolver: &mut Resolver<'a>,
use_utf8view: bool,
strict_mode: bool,
) -> Result<AvroDataType, ArrowError> {
match schema {
Schema::TypeName(TypeName::Primitive(p)) => {
Expand All @@ -428,12 +483,20 @@ fn make_data_type<'a>(
.position(|x| x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)));
match (f.len() == 2, null) {
(true, Some(0)) => {
let mut field = make_data_type(&f[1], namespace, resolver, use_utf8view)?;
let mut field =
make_data_type(&f[1], namespace, resolver, use_utf8view, strict_mode)?;
field.nullability = Some(Nullability::NullFirst);
Ok(field)
}
(true, Some(1)) => {
let mut field = make_data_type(&f[0], namespace, resolver, use_utf8view)?;
if strict_mode {
return Err(ArrowError::SchemaError(
"Found Avro union of the form ['T','null'], which is disallowed in strict_mode"
.to_string(),
));
}
let mut field =
make_data_type(&f[0], namespace, resolver, use_utf8view, strict_mode)?;
field.nullability = Some(Nullability::NullSecond);
Ok(field)
}
Expand All @@ -456,6 +519,7 @@ fn make_data_type<'a>(
namespace,
resolver,
use_utf8view,
strict_mode,
)?,
})
})
Expand All @@ -469,8 +533,13 @@ fn make_data_type<'a>(
Ok(field)
}
ComplexType::Array(a) => {
let mut field =
make_data_type(a.items.as_ref(), namespace, resolver, use_utf8view)?;
let mut field = make_data_type(
a.items.as_ref(),
namespace,
resolver,
use_utf8view,
strict_mode,
)?;
Ok(AvroDataType {
nullability: None,
metadata: a.attributes.field_metadata(),
Expand Down Expand Up @@ -535,7 +604,8 @@ fn make_data_type<'a>(
Ok(field)
}
ComplexType::Map(m) => {
let val = make_data_type(&m.values, namespace, resolver, use_utf8view)?;
let val =
make_data_type(&m.values, namespace, resolver, use_utf8view, strict_mode)?;
Ok(AvroDataType {
nullability: None,
metadata: m.attributes.field_metadata(),
Expand All @@ -549,6 +619,7 @@ fn make_data_type<'a>(
namespace,
resolver,
use_utf8view,
strict_mode,
)?;

// https://avro.apache.org/docs/1.11.1/specification/#logical-types
Expand Down Expand Up @@ -630,7 +701,7 @@ mod tests {
let schema = create_schema_with_logical_type(PrimitiveType::Int, "date");

let mut resolver = Resolver::default();
let result = make_data_type(&schema, None, &mut resolver, false).unwrap();
let result = make_data_type(&schema, None, &mut resolver, false, false).unwrap();

assert!(matches!(result.codec, Codec::Date32));
}
Expand All @@ -640,7 +711,7 @@ mod tests {
let schema = create_schema_with_logical_type(PrimitiveType::Int, "time-millis");

let mut resolver = Resolver::default();
let result = make_data_type(&schema, None, &mut resolver, false).unwrap();
let result = make_data_type(&schema, None, &mut resolver, false, false).unwrap();

assert!(matches!(result.codec, Codec::TimeMillis));
}
Expand All @@ -650,7 +721,7 @@ mod tests {
let schema = create_schema_with_logical_type(PrimitiveType::Long, "time-micros");

let mut resolver = Resolver::default();
let result = make_data_type(&schema, None, &mut resolver, false).unwrap();
let result = make_data_type(&schema, None, &mut resolver, false, false).unwrap();

assert!(matches!(result.codec, Codec::TimeMicros));
}
Expand All @@ -660,7 +731,7 @@ mod tests {
let schema = create_schema_with_logical_type(PrimitiveType::Long, "timestamp-millis");

let mut resolver = Resolver::default();
let result = make_data_type(&schema, None, &mut resolver, false).unwrap();
let result = make_data_type(&schema, None, &mut resolver, false, false).unwrap();

assert!(matches!(result.codec, Codec::TimestampMillis(true)));
}
Expand All @@ -670,7 +741,7 @@ mod tests {
let schema = create_schema_with_logical_type(PrimitiveType::Long, "timestamp-micros");

let mut resolver = Resolver::default();
let result = make_data_type(&schema, None, &mut resolver, false).unwrap();
let result = make_data_type(&schema, None, &mut resolver, false, false).unwrap();

assert!(matches!(result.codec, Codec::TimestampMicros(true)));
}
Expand All @@ -680,7 +751,7 @@ mod tests {
let schema = create_schema_with_logical_type(PrimitiveType::Long, "local-timestamp-millis");

let mut resolver = Resolver::default();
let result = make_data_type(&schema, None, &mut resolver, false).unwrap();
let result = make_data_type(&schema, None, &mut resolver, false, false).unwrap();

assert!(matches!(result.codec, Codec::TimestampMillis(false)));
}
Expand All @@ -690,7 +761,7 @@ mod tests {
let schema = create_schema_with_logical_type(PrimitiveType::Long, "local-timestamp-micros");

let mut resolver = Resolver::default();
let result = make_data_type(&schema, None, &mut resolver, false).unwrap();
let result = make_data_type(&schema, None, &mut resolver, false, false).unwrap();

assert!(matches!(result.codec, Codec::TimestampMicros(false)));
}
Expand Down Expand Up @@ -745,7 +816,7 @@ mod tests {
let schema = create_schema_with_logical_type(PrimitiveType::Int, "custom-type");

let mut resolver = Resolver::default();
let result = make_data_type(&schema, None, &mut resolver, false).unwrap();
let result = make_data_type(&schema, None, &mut resolver, false, false).unwrap();

assert_eq!(
result.metadata.get("logicalType"),
Expand All @@ -758,7 +829,7 @@ mod tests {
let schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::String));

let mut resolver = Resolver::default();
let result = make_data_type(&schema, None, &mut resolver, true).unwrap();
let result = make_data_type(&schema, None, &mut resolver, true, false).unwrap();

assert!(matches!(result.codec, Codec::Utf8View));
}
Expand All @@ -768,7 +839,7 @@ mod tests {
let schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::String));

let mut resolver = Resolver::default();
let result = make_data_type(&schema, None, &mut resolver, false).unwrap();
let result = make_data_type(&schema, None, &mut resolver, false, false).unwrap();

assert!(matches!(result.codec, Codec::Utf8));
}
Expand Down Expand Up @@ -796,7 +867,7 @@ mod tests {
let schema = Schema::Complex(ComplexType::Record(record));

let mut resolver = Resolver::default();
let result = make_data_type(&schema, None, &mut resolver, true).unwrap();
let result = make_data_type(&schema, None, &mut resolver, true, false).unwrap();

if let Codec::Struct(fields) = &result.codec {
let first_field_codec = &fields[0].data_type().codec;
Expand All @@ -805,4 +876,25 @@ mod tests {
panic!("Expected Struct codec");
}
}

#[test]
fn test_union_with_strict_mode() {
    // A two-member union with `null` in the second position: ['string', 'null'].
    let members = vec![
        Schema::TypeName(TypeName::Primitive(PrimitiveType::String)),
        Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)),
    ];
    let null_second_union = Schema::Union(members);
    let expected_message =
        "Found Avro union of the form ['T','null'], which is disallowed in strict_mode";

    // use_utf8view = false, strict_mode = true
    let mut resolver = Resolver::default();
    let outcome = make_data_type(&null_second_union, None, &mut resolver, false, true);

    assert!(outcome.is_err());
    if let Err(ArrowError::SchemaError(message)) = outcome {
        assert!(message.contains(expected_message));
    } else {
        panic!("Expected SchemaError");
    }
}
}
Loading
Loading