Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
312 changes: 210 additions & 102 deletions arrow-avro/src/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
// under the License.

use crate::schema::{
Attributes, AvroSchema, ComplexType, PrimitiveType, Record, Schema, Type, TypeName,
Attributes, AvroSchema, ComplexType, Enum, PrimitiveType, Record, Schema, Type, TypeName,
AVRO_ENUM_SYMBOLS_METADATA_KEY,
};
use arrow_schema::{
Expand Down Expand Up @@ -48,7 +48,7 @@ pub(crate) enum ResolutionInfo {
Promotion(Promotion),
/// Indicates that a default value should be used for a field. (Implemented in a Follow-up PR)
DefaultValue(AvroLiteral),
/// Provides mapping information for resolving enums. (Implemented in a Follow-up PR)
/// Provides mapping information for resolving enums.
EnumMapping(EnumMapping),
/// Provides resolution information for record fields. (Implemented in a Follow-up PR)
Record(ResolvedRecord),
Expand Down Expand Up @@ -587,6 +587,63 @@ impl<'a> Resolver<'a> {
}
}

fn names_match(
writer_name: &str,
writer_aliases: &[&str],
reader_name: &str,
reader_aliases: &[&str],
) -> bool {
writer_name == reader_name
|| reader_aliases.contains(&writer_name)
|| writer_aliases.contains(&reader_name)
}

fn ensure_names_match(
data_type: &str,
writer_name: &str,
writer_aliases: &[&str],
reader_name: &str,
reader_aliases: &[&str],
) -> Result<(), ArrowError> {
if names_match(writer_name, writer_aliases, reader_name, reader_aliases) {
Ok(())
} else {
Err(ArrowError::ParseError(format!(
"{data_type} name mismatch writer={writer_name}, reader={reader_name}"
)))
}
}

fn primitive_of(schema: &Schema) -> Option<PrimitiveType> {
match schema {
Schema::TypeName(TypeName::Primitive(primitive)) => Some(*primitive),
Schema::Type(Type {
r#type: TypeName::Primitive(primitive),
..
}) => Some(*primitive),
_ => None,
}
}

fn nullable_union_variants<'x, 'y>(
variant: &'y [Schema<'x>],
) -> Option<(Nullability, &'y Schema<'x>)> {
if variant.len() != 2 {
return None;
}
let is_null = |schema: &Schema<'x>| {
matches!(
schema,
Schema::TypeName(TypeName::Primitive(PrimitiveType::Null))
)
};
match (is_null(&variant[0]), is_null(&variant[1])) {
(true, false) => Some((Nullability::NullFirst, &variant[1])),
(false, true) => Some((Nullability::NullSecond, &variant[0])),
_ => None,
}
}

/// Resolves Avro type names to [`AvroDataType`]
///
/// See <https://avro.apache.org/docs/1.11.1/specification/#names>
Expand Down Expand Up @@ -815,77 +872,36 @@ impl<'a> Maker<'a> {
reader_schema: &'s Schema<'a>,
namespace: Option<&'a str>,
) -> Result<AvroDataType, ArrowError> {
if let (Some(write_primitive), Some(read_primitive)) =
(primitive_of(writer_schema), primitive_of(reader_schema))
{
return self.resolve_primitives(write_primitive, read_primitive, reader_schema);
}
match (writer_schema, reader_schema) {
(
Schema::TypeName(TypeName::Primitive(writer_primitive)),
Schema::TypeName(TypeName::Primitive(reader_primitive)),
) => self.resolve_primitives(*writer_primitive, *reader_primitive, reader_schema),
(
Schema::Type(Type {
r#type: TypeName::Primitive(writer_primitive),
..
}),
Schema::Type(Type {
r#type: TypeName::Primitive(reader_primitive),
..
}),
) => self.resolve_primitives(*writer_primitive, *reader_primitive, reader_schema),
(
Schema::TypeName(TypeName::Primitive(writer_primitive)),
Schema::Type(Type {
r#type: TypeName::Primitive(reader_primitive),
..
}),
) => self.resolve_primitives(*writer_primitive, *reader_primitive, reader_schema),
(
Schema::Type(Type {
r#type: TypeName::Primitive(writer_primitive),
..
}),
Schema::TypeName(TypeName::Primitive(reader_primitive)),
) => self.resolve_primitives(*writer_primitive, *reader_primitive, reader_schema),
(
Schema::Complex(ComplexType::Record(writer_record)),
Schema::Complex(ComplexType::Record(reader_record)),
) => self.resolve_records(writer_record, reader_record, namespace),
(Schema::Union(writer_variants), Schema::Union(reader_variants)) => {
self.resolve_nullable_union(writer_variants, reader_variants, namespace)
}
(
Schema::Complex(ComplexType::Enum(writer_enum)),
Schema::Complex(ComplexType::Enum(reader_enum)),
) => self.resolve_enums(writer_enum, reader_enum, reader_schema, namespace),
(Schema::Union(writer_variants), Schema::Union(reader_variants)) => self
.resolve_nullable_union(
writer_variants.as_slice(),
reader_variants.as_slice(),
namespace,
),
(Schema::TypeName(TypeName::Ref(_)), _) => self.parse_type(reader_schema, namespace),
(_, Schema::TypeName(TypeName::Ref(_))) => self.parse_type(reader_schema, namespace),
// if both sides are the same complex kind (non-record), adopt the reader type.
// This aligns with Avro spec: arrays, maps, and enums resolve recursively;
// for identical shapes we can just parse the reader schema.
(Schema::Complex(ComplexType::Array(_)), Schema::Complex(ComplexType::Array(_)))
| (Schema::Complex(ComplexType::Map(_)), Schema::Complex(ComplexType::Map(_)))
| (Schema::Complex(ComplexType::Fixed(_)), Schema::Complex(ComplexType::Fixed(_)))
| (Schema::Complex(ComplexType::Enum(_)), Schema::Complex(ComplexType::Enum(_))) => {
| (Schema::Complex(ComplexType::Fixed(_)), Schema::Complex(ComplexType::Fixed(_))) => {
self.parse_type(reader_schema, namespace)
}
// Named-type references (equal on both sides) – parse reader side.
(Schema::TypeName(TypeName::Ref(_)), Schema::TypeName(TypeName::Ref(_)))
| (
Schema::Type(Type {
r#type: TypeName::Ref(_),
..
}),
Schema::Type(Type {
r#type: TypeName::Ref(_),
..
}),
)
| (
Schema::TypeName(TypeName::Ref(_)),
Schema::Type(Type {
r#type: TypeName::Ref(_),
..
}),
)
| (
Schema::Type(Type {
r#type: TypeName::Ref(_),
..
}),
Schema::TypeName(TypeName::Ref(_)),
) => self.parse_type(reader_schema, namespace),
_ => Err(ArrowError::NotYetImplemented(
"Other resolutions not yet implemented".to_string(),
)),
Expand Down Expand Up @@ -921,64 +937,156 @@ impl<'a> Maker<'a> {
Ok(datatype)
}

fn resolve_nullable_union(
fn resolve_nullable_union<'s>(
&mut self,
writer_variants: &[Schema<'a>],
reader_variants: &[Schema<'a>],
writer_variants: &'s [Schema<'a>],
reader_variants: &'s [Schema<'a>],
namespace: Option<&'a str>,
) -> Result<AvroDataType, ArrowError> {
// Only support unions with exactly two branches, one of which is `null` on both sides
if writer_variants.len() != 2 || reader_variants.len() != 2 {
return Err(ArrowError::NotYetImplemented(
"Only 2-branch unions are supported for schema resolution".to_string(),
));
}
let is_null = |s: &Schema<'a>| {
matches!(
s,
Schema::TypeName(TypeName::Primitive(PrimitiveType::Null))
)
};
let w_null_pos = writer_variants.iter().position(is_null);
let r_null_pos = reader_variants.iter().position(is_null);
match (w_null_pos, r_null_pos) {
(Some(wp), Some(rp)) => {
// Extract a non-null branch on each side
let w_nonnull = &writer_variants[1 - wp];
let r_nonnull = &reader_variants[1 - rp];
// Resolve the non-null branch
let mut dt = self.make_data_type(w_nonnull, Some(r_nonnull), namespace)?;
match (
nullable_union_variants(writer_variants),
nullable_union_variants(reader_variants),
) {
(Some((_, write_nonnull)), Some((read_nb, read_nonnull))) => {
let mut dt = self.make_data_type(write_nonnull, Some(read_nonnull), namespace)?;
// Adopt reader union null ordering
dt.nullability = Some(match rp {
0 => Nullability::NullFirst,
1 => Nullability::NullSecond,
_ => unreachable!(),
});
dt.nullability = Some(read_nb);
Ok(dt)
}
_ => Err(ArrowError::NotYetImplemented(
"Union resolution requires both writer and reader to be nullable unions"
"Union resolution requires both writer and reader to be 2-branch nullable unions"
.to_string(),
)),
}
}

// Resolve writer vs. reader enum schemas according to Avro 1.11.1.
//
// # How enums resolve (writer to reader)
// Per “Schema Resolution”:
// * The two schemas must refer to the same (unqualified) enum name (or match
// via alias rewriting).
// * If the writer’s symbol is not present in the reader’s enum and the reader
// enum has a `default`, that `default` symbol must be used; otherwise,
// error.
// https://avro.apache.org/docs/1.11.1/specification/#schema-resolution
// * Avro “Aliases” are applied from the reader side to rewrite the writer’s
// names during resolution. For robustness across ecosystems, we also accept
// symmetry here (see note below).
// https://avro.apache.org/docs/1.11.1/specification/#aliases
//
// # Rationale for this code path
// 1. Do the work once at schema‑resolution time. Avro serializes an enum as a
// writer‑side position. Mapping positions on the hot decoder path is expensive
// if done with string lookups. This method builds a `writer_index to reader_index`
// vector once, so decoding just does an O(1) table lookup.
// 2. Adopt the reader’s symbol set and order. We return an Arrow
// `Dictionary(Int32, Utf8)` whose dictionary values are the reader enum
// symbols. This makes downstream semantics match the reader schema, including
// Avro’s sort order rule that orders enums by symbol position in the schema.
// https://avro.apache.org/docs/1.11.1/specification/#sort-order
// 3. Honor Avro’s `default` for enums. Avro 1.9+ allows a type‑level default
// on the enum. When the writer emits a symbol unknown to the reader, we map it
// to the reader’s validated `default` symbol if present; otherwise we signal an
// error at decoding time.
// https://avro.apache.org/docs/1.11.1/specification/#enums
//
// # Implementation notes
// * We first check that enum names match or are*alias‑equivalent. The Avro
// spec describes alias rewriting using reader aliases; this implementation
// additionally treats writer aliases as acceptable for name matching to be
// resilient with schemas produced by different tooling.
// * We build `EnumMapping`:
// - `mapping[i]` = reader index of the writer symbol at writer index `i`.
// - If the writer symbol is absent and the reader has a default, we store the
// reader index of that default.
// - Otherwise we store `-1` as a sentinel meaning unresolvable; the decoder
// must treat encountering such a value as an error, per the spec.
// * We persist the reader symbol list in field metadata under
// `AVRO_ENUM_SYMBOLS_METADATA_KEY`, so consumers can inspect the dictionary
// without needing the original Avro schema.
// * The Arrow representation is `Dictionary(Int32, Utf8)`, which aligns with
// Avro’s integer index encoding for enums.
//
// # Examples
// * Writer `["A","B","C"]`, Reader `["A","B"]`, Reader default `"A"`
// `mapping = [0, 1, 0]`, `default_index = 0`.
// * Writer `["A","B"]`, Reader `["B","A"]` (no default)
// `mapping = [1, 0]`, `default_index = -1`.
// * Writer `["A","B","C"]`, Reader `["A","B"]` (no default)
// `mapping = [0, 1, -1]` (decode must error on `"C"`).
fn resolve_enums(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add some more context about Avro enums here for future readers / maintainers?

I think more or less copying the contents of this PR's description in "Rationale for this change" is probably good enough.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@alamb I just pushed up a detailed comment on enums with links and examples. Let me know what you think!

&mut self,
writer_enum: &Enum<'a>,
reader_enum: &Enum<'a>,
reader_schema: &Schema<'a>,
namespace: Option<&'a str>,
) -> Result<AvroDataType, ArrowError> {
ensure_names_match(
"Enum",
writer_enum.name,
&writer_enum.aliases,
reader_enum.name,
&reader_enum.aliases,
)?;
if writer_enum.symbols == reader_enum.symbols {
return self.parse_type(reader_schema, namespace);
}
let reader_index: HashMap<&str, i32> = reader_enum
.symbols
.iter()
.enumerate()
.map(|(index, &symbol)| (symbol, index as i32))
.collect();
let default_index: i32 = match reader_enum.default {
Some(symbol) => *reader_index.get(symbol).ok_or_else(|| {
ArrowError::SchemaError(format!(
"Reader enum '{}' default symbol '{symbol}' not found in symbols list",
reader_enum.name,
))
})?,
None => -1,
};
let mapping: Vec<i32> = writer_enum
.symbols
.iter()
.map(|&write_symbol| {
reader_index
.get(write_symbol)
.copied()
.unwrap_or(default_index)
})
.collect();
if self.strict_mode && mapping.iter().any(|&m| m < 0) {
return Err(ArrowError::SchemaError(format!(
"Reader enum '{}' does not cover all writer symbols and no default is provided",
reader_enum.name
)));
}
let mut dt = self.parse_type(reader_schema, namespace)?;
dt.resolution = Some(ResolutionInfo::EnumMapping(EnumMapping {
mapping: Arc::from(mapping),
default_index,
}));
let reader_ns = reader_enum.namespace.or(namespace);
self.resolver
.register(reader_enum.name, reader_ns, dt.clone());
Ok(dt)
}

fn resolve_records(
&mut self,
writer_record: &Record<'a>,
reader_record: &Record<'a>,
namespace: Option<&'a str>,
) -> Result<AvroDataType, ArrowError> {
// Names must match or be aliased
let names_match = writer_record.name == reader_record.name
|| reader_record.aliases.contains(&writer_record.name)
|| writer_record.aliases.contains(&reader_record.name);
if !names_match {
return Err(ArrowError::ParseError(format!(
"Record name mismatch writer={}, reader={}",
writer_record.name, reader_record.name
)));
}
ensure_names_match(
"Record",
writer_record.name,
&writer_record.aliases,
reader_record.name,
&reader_record.aliases,
)?;
let writer_ns = writer_record.namespace.or(namespace);
let reader_ns = reader_record.namespace.or(namespace);
// Map writer field name -> index
Expand All @@ -995,7 +1103,7 @@ impl<'a> Maker<'a> {
// Build reader fields and mapping
for (reader_idx, r_field) in reader_record.fields.iter().enumerate() {
if let Some(&writer_idx) = writer_index_map.get(r_field.name) {
// Field exists in writer: resolve types (including promotions and union-of-null)
// Field exists in a writer: resolve types (including promotions and union-of-null)
let w_schema = &writer_record.fields[writer_idx].r#type;
let resolved_dt =
self.make_data_type(w_schema, Some(&r_field.r#type), reader_ns)?;
Expand Down
Loading
Loading