diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index 2ac1ad038bd7..681691ec8c22 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -15,11 +15,15 @@ // specific language governing permissions and limitations // under the License. -use crate::schema::{Attributes, ComplexType, PrimitiveType, Record, Schema, TypeName}; +use crate::schema::{ + Array, Attributes, ComplexType, Enum, Fixed, Map, PrimitiveType, Record, RecordField, Schema, + TypeName, +}; +use arrow_schema::DataType::*; use arrow_schema::{ - ArrowError, DataType, Field, FieldRef, IntervalUnit, SchemaBuilder, SchemaRef, TimeUnit, + ArrowError, DataType, Field, Fields, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, + DECIMAL128_MAX_SCALE, }; -use std::borrow::Cow; use std::collections::HashMap; use std::sync::Arc; @@ -45,19 +49,75 @@ pub struct AvroDataType { } impl AvroDataType { + /// Create a new AvroDataType with the given parts. + pub fn new( + codec: Codec, + nullability: Option, + metadata: HashMap, + ) -> Self { + AvroDataType { + codec, + nullability, + metadata, + } + } + + /// Create a new AvroDataType from a `Codec`, with default (no) nullability and empty metadata. + pub fn from_codec(codec: Codec) -> Self { + Self::new(codec, None, Default::default()) + } + /// Returns an arrow [`Field`] with the given name pub fn field_with_name(&self, name: &str) -> Field { let d = self.codec.data_type(); Field::new(name, d, self.nullability.is_some()).with_metadata(self.metadata.clone()) } + /// Return a reference to the inner `Codec`. pub fn codec(&self) -> &Codec { &self.codec } + /// Return the nullability for this Avro type, if any. pub fn nullability(&self) -> Option { self.nullability } + + /// Convert this `AvroDataType`, which encapsulates an Arrow data type (`codec`) + /// plus nullability and metadata, back into an Avro `Schema<'a>`. + /// + /// - If `metadata["namespace"]` is present, we'll store it in the resulting schema for named types + /// (record, enum, fixed). + pub fn to_avro_schema<'a>(&'a self, name: &'a str) -> Schema<'a> { + let inner_schema = self.codec.to_avro_schema(name); + let schema_with_namespace = maybe_add_namespace(inner_schema, self); + // If the field is nullable in Arrow, wrap Avro schema in a union: ["null", ]. + if self.nullability.is_some() { + Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + schema_with_namespace, + ]) + } else { + schema_with_namespace + } + } +} + +/// If this is a named complex type (Record, Enum, Fixed), attach `namespace` +/// from `dt.metadata["namespace"]` if present. Otherwise, return as-is. +fn maybe_add_namespace<'a>(mut schema: Schema<'a>, dt: &'a AvroDataType) -> Schema<'a> { + if let Some(ns_str) = dt.metadata.get("namespace") { + if let Schema::Complex(ref mut c) = schema { + match c { + ComplexType::Record(r) => r.namespace = Some(ns_str), + ComplexType::Enum(e) => e.namespace = Some(ns_str), + ComplexType::Fixed(f) => f.namespace = Some(ns_str), + // Arrays and Maps do not have a namespace field, so do nothing + _ => {} + } + } + } + schema } /// A named [`AvroDataType`] @@ -78,6 +138,7 @@ impl AvroField { &self.data_type } + /// Returns the name of this field pub fn name(&self) -> &str { &self.name } @@ -127,34 +188,228 @@ pub enum Codec { List(Arc), Struct(Arc<[AvroField]>), Interval, + /// In Arrow, use Dictionary(Int32, Utf8) for Enum. + Enum(Vec), + Map(Arc), + Decimal(usize, Option, Option), } impl Codec { + /// Convert this to an Arrow `DataType` fn data_type(&self) -> DataType { match self { - Self::Null => DataType::Null, - Self::Boolean => DataType::Boolean, - Self::Int32 => DataType::Int32, - Self::Int64 => DataType::Int64, - Self::Float32 => DataType::Float32, - Self::Float64 => DataType::Float64, - Self::Binary => DataType::Binary, - Self::Utf8 => DataType::Utf8, - Self::Date32 => DataType::Date32, - Self::TimeMillis => DataType::Time32(TimeUnit::Millisecond), - Self::TimeMicros => DataType::Time64(TimeUnit::Microsecond), + Self::Null => Null, + Self::Boolean => Boolean, + Self::Int32 => Int32, + Self::Int64 => Int64, + Self::Float32 => Float32, + Self::Float64 => Float64, + Self::Binary => Binary, + Self::Utf8 => Utf8, + Self::Date32 => Date32, + Self::TimeMillis => Time32(TimeUnit::Millisecond), + Self::TimeMicros => Time64(TimeUnit::Microsecond), Self::TimestampMillis(is_utc) => { - DataType::Timestamp(TimeUnit::Millisecond, is_utc.then(|| "+00:00".into())) + Timestamp(TimeUnit::Millisecond, is_utc.then(|| "+00:00".into())) } Self::TimestampMicros(is_utc) => { - DataType::Timestamp(TimeUnit::Microsecond, is_utc.then(|| "+00:00".into())) + Timestamp(TimeUnit::Microsecond, is_utc.then(|| "+00:00".into())) + } + Self::Interval => Interval(IntervalUnit::MonthDayNano), + Self::Fixed(size) => FixedSizeBinary(*size), + Self::List(f) => List(Arc::new(f.field_with_name(Field::LIST_FIELD_DEFAULT_NAME))), + Self::Struct(f) => Struct(f.iter().map(|x| x.field()).collect()), + Self::Enum(_symbols) => { + // Produce a Dictionary type with index = Int32, value = Utf8 + Dictionary(Box::new(Int32), Box::new(Utf8)) + } + Self::Map(values) => Map( + Arc::new(Field::new( + "entries", + Struct(Fields::from(vec![ + Field::new("key", Utf8, false), + values.field_with_name("value"), + ])), + false, + )), + false, + ), + Self::Decimal(precision, scale, size) => match size { + Some(s) if *s > 16 => Decimal256(*precision as u8, scale.unwrap_or(0) as i8), + Some(s) => Decimal128(*precision as u8, scale.unwrap_or(0) as i8), + None if *precision <= DECIMAL128_MAX_PRECISION as usize + && scale.unwrap_or(0) <= DECIMAL128_MAX_SCALE as usize => + { + Decimal128(*precision as u8, scale.unwrap_or(0) as i8) + } + _ => Decimal256(*precision as u8, scale.unwrap_or(0) as i8), + }, + } + } + + /// Convert this `Codec` variant to an Avro `Schema<'a>`. + pub fn to_avro_schema<'a>(&'a self, name: &'a str) -> Schema<'a> { + match self { + Codec::Null => Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + Codec::Boolean => Schema::TypeName(TypeName::Primitive(PrimitiveType::Boolean)), + Codec::Int32 => Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), + Codec::Int64 => Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)), + Codec::Float32 => Schema::TypeName(TypeName::Primitive(PrimitiveType::Float)), + Codec::Float64 => Schema::TypeName(TypeName::Primitive(PrimitiveType::Double)), + Codec::Binary => Schema::TypeName(TypeName::Primitive(PrimitiveType::Bytes)), + Codec::Utf8 => Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), + // date32 => Avro int + logicalType=date + Codec::Date32 => Schema::Type(crate::schema::Type { + r#type: TypeName::Primitive(PrimitiveType::Int), + attributes: Attributes { + logical_type: Some("date"), + additional: Default::default(), + }, + }), + // time-millis => Avro int with logicalType=time-millis + Codec::TimeMillis => Schema::Type(crate::schema::Type { + r#type: TypeName::Primitive(PrimitiveType::Int), + attributes: Attributes { + logical_type: Some("time-millis"), + additional: Default::default(), + }, + }), + // time-micros => Avro long with logicalType=time-micros + Codec::TimeMicros => Schema::Type(crate::schema::Type { + r#type: TypeName::Primitive(PrimitiveType::Long), + attributes: Attributes { + logical_type: Some("time-micros"), + additional: Default::default(), + }, + }), + // timestamp-millis => Avro long with logicalType=timestamp-millis or local-timestamp-millis + Codec::TimestampMillis(is_utc) => { + let logical_type = Some(if *is_utc { + "timestamp-millis" + } else { + "local-timestamp-millis" + }); + Schema::Type(crate::schema::Type { + r#type: TypeName::Primitive(PrimitiveType::Long), + attributes: Attributes { + logical_type, + additional: Default::default(), + }, + }) + } + // timestamp-micros => Avro long with logicalType=timestamp-micros or local-timestamp-micros + Codec::TimestampMicros(is_utc) => { + let logical_type = Some(if *is_utc { + "timestamp-micros" + } else { + "local-timestamp-micros" + }); + Schema::Type(crate::schema::Type { + r#type: TypeName::Primitive(PrimitiveType::Long), + attributes: Attributes { + logical_type, + additional: Default::default(), + }, + }) + } + Codec::Interval => Schema::Type(crate::schema::Type { + r#type: TypeName::Primitive(PrimitiveType::Bytes), + attributes: Attributes { + logical_type: Some("duration"), + additional: Default::default(), + }, + }), + Codec::Fixed(size) => { + // Convert Arrow FixedSizeBinary => Avro fixed with name & size + Schema::Complex(ComplexType::Fixed(Fixed { + name, + namespace: None, + aliases: vec![], + size: *size as usize, + attributes: Attributes::default(), + })) + } + Codec::List(item_type) => { + // Avro array with "items" recursively derived + let items_schema = item_type.to_avro_schema("items"); + Schema::Complex(ComplexType::Array(Array { + items: Box::new(items_schema), + attributes: Attributes::default(), + })) + } + Codec::Struct(fields) => { + // Avro record with nested fields + let record_fields = fields + .iter() + .map(|f| { + let child_schema = f.data_type().to_avro_schema(f.name()); + RecordField { + name: f.name(), + doc: None, + r#type: child_schema, + default: None, + } + }) + .collect(); + Schema::Complex(ComplexType::Record(Record { + name, + namespace: None, + doc: None, + aliases: vec![], + fields: record_fields, + attributes: Attributes::default(), + })) + } + Codec::Enum(symbols) => { + // If there's a namespace in metadata, we will apply it later in maybe_add_namespace. + Schema::Complex(ComplexType::Enum(Enum { + name, + namespace: None, + doc: None, + aliases: vec![], + symbols: symbols.iter().map(|s| s.as_str()).collect(), + default: None, + attributes: Attributes::default(), + })) } - Self::Interval => DataType::Interval(IntervalUnit::MonthDayNano), - Self::Fixed(size) => DataType::FixedSizeBinary(*size), - Self::List(f) => { - DataType::List(Arc::new(f.field_with_name(Field::LIST_FIELD_DEFAULT_NAME))) + Codec::Map(values) => { + let val_schema = values.to_avro_schema("values"); + Schema::Complex(ComplexType::Map(Map { + values: Box::new(val_schema), + attributes: Attributes::default(), + })) + } + Codec::Decimal(precision, scale, size) => { + // If size is Some(n), produce Avro "fixed", else "bytes". + if let Some(n) = size { + Schema::Complex(ComplexType::Fixed(Fixed { + name, + namespace: None, + aliases: vec![], + size: *n, + attributes: Attributes { + logical_type: Some("decimal"), + additional: HashMap::from([ + ("precision", serde_json::json!(*precision)), + ("scale", serde_json::json!(scale.unwrap_or(0))), + ("size", serde_json::json!(*n)), + ]), + }, + })) + } else { + // "type":"bytes", "logicalType":"decimal" + Schema::Type(crate::schema::Type { + r#type: TypeName::Primitive(PrimitiveType::Bytes), + attributes: Attributes { + logical_type: Some("decimal"), + additional: HashMap::from([ + ("precision", serde_json::json!(*precision)), + ("scale", serde_json::json!(scale.unwrap_or(0))), + ]), + }, + }) + } } - Self::Struct(f) => DataType::Struct(f.iter().map(|x| x.field()).collect()), } } } @@ -203,8 +458,6 @@ impl<'a> Resolver<'a> { /// /// `name`: is name used to refer to `schema` in its parent /// `namespace`: an optional qualifier used as part of a type hierarchy -/// -/// See [`Resolver`] for more information fn make_data_type<'a>( schema: &Schema<'a>, namespace: Option<&'a str>, @@ -218,7 +471,7 @@ fn make_data_type<'a>( }), Schema::TypeName(TypeName::Ref(name)) => resolver.resolve(name, namespace), Schema::Union(f) => { - // Special case the common case of nullable primitives + // Special case the common case of nullable primitives or single-type let null = f .iter() .position(|x| x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null))); @@ -251,7 +504,6 @@ fn make_data_type<'a>( }) }) .collect::>()?; - let field = AvroDataType { nullability: None, codec: Codec::Struct(fields), @@ -269,35 +521,122 @@ fn make_data_type<'a>( }) } ComplexType::Fixed(f) => { + // Possibly decimal with logicalType=decimal let size = f.size.try_into().map_err(|e| { ArrowError::ParseError(format!("Overflow converting size to i32: {e}")) })?; - + if let Some("decimal") = f.attributes.logical_type { + let precision = f + .attributes + .additional + .get("precision") + .and_then(|v| v.as_u64()) + .ok_or_else(|| { + ArrowError::ParseError("Decimal requires precision".to_string()) + })?; + let size_val = f + .attributes + .additional + .get("size") + .and_then(|v| v.as_u64()) + .ok_or_else(|| { + ArrowError::ParseError("Decimal requires size".to_string()) + })?; + let scale = f + .attributes + .additional + .get("scale") + .and_then(|v| v.as_u64()) + .or(Some(0)); + let field = AvroDataType { + nullability: None, + metadata: f.attributes.field_metadata(), + codec: Codec::Decimal( + precision as usize, + Some(scale.unwrap_or(0) as usize), + Some(size_val as usize), + ), + }; + resolver.register(f.name, namespace, field.clone()); + Ok(field) + } else { + let field = AvroDataType { + nullability: None, + metadata: f.attributes.field_metadata(), + codec: Codec::Fixed(size), + }; + resolver.register(f.name, namespace, field.clone()); + Ok(field) + } + } + ComplexType::Enum(e) => { + let symbols = e + .symbols + .iter() + .map(|sym| sym.to_string()) + .collect::>(); let field = AvroDataType { nullability: None, - metadata: f.attributes.field_metadata(), - codec: Codec::Fixed(size), + metadata: e.attributes.field_metadata(), + codec: Codec::Enum(symbols), + }; + resolver.register(e.name, namespace, field.clone()); + Ok(field) + } + ComplexType::Map(m) => { + let values_data_type = make_data_type(m.values.as_ref(), namespace, resolver)?; + let field = AvroDataType { + nullability: None, + metadata: m.attributes.field_metadata(), + codec: Codec::Map(Arc::new(values_data_type)), }; - resolver.register(f.name, namespace, field.clone()); Ok(field) } - ComplexType::Enum(e) => Err(ArrowError::NotYetImplemented(format!( - "Enum of {e:?} not currently supported" - ))), - ComplexType::Map(m) => Err(ArrowError::NotYetImplemented(format!( - "Map of {m:?} not currently supported" - ))), }, Schema::Type(t) => { + // Possibly decimal, or other logical types let mut field = make_data_type(&Schema::TypeName(t.r#type.clone()), namespace, resolver)?; - - // https://avro.apache.org/docs/1.11.1/specification/#logical-types match (t.attributes.logical_type, &mut field.codec) { (Some("decimal"), c @ Codec::Fixed(_)) => { - return Err(ArrowError::NotYetImplemented( - "Decimals are not currently supported".to_string(), - )) + *c = Codec::Decimal( + t.attributes + .additional + .get("precision") + .and_then(|v| v.as_u64()) + .unwrap_or(10) as usize, + Some( + t.attributes + .additional + .get("scale") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize, + ), + Some( + t.attributes + .additional + .get("size") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize, + ), + ); + } + (Some("decimal"), c @ Codec::Binary) => { + *c = Codec::Decimal( + t.attributes + .additional + .get("precision") + .and_then(|v| v.as_u64()) + .unwrap_or(10) as usize, + Some( + t.attributes + .additional + .get("scale") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize, + ), + None, + ); } (Some("date"), c @ Codec::Int32) => *c = Codec::Date32, (Some("time-millis"), c @ Codec::Int32) => *c = Codec::TimeMillis, @@ -312,7 +651,7 @@ fn make_data_type<'a>( } (Some("duration"), c @ Codec::Fixed(12)) => *c = Codec::Interval, (Some(logical), _) => { - // Insert unrecognized logical type into metadata map + // Insert unrecognized logical type into metadata field.metadata.insert("logicalType".into(), logical.into()); } (None, _) => {} @@ -327,3 +666,412 @@ fn make_data_type<'a>( } } } + +/// Convert an Arrow `Field` into an `AvroField`. +pub fn arrow_field_to_avro_field(arrow_field: &Field) -> AvroField { + let codec = arrow_type_to_codec(arrow_field.data_type()); + let nullability = if arrow_field.is_nullable() { + Some(Nullability::NullFirst) + } else { + None + }; + let mut metadata = arrow_field.metadata().clone(); + let avro_data_type = AvroDataType { + nullability, + metadata, + codec, + }; + AvroField { + name: arrow_field.name().clone(), + data_type: avro_data_type, + } +} + +/// Maps an Arrow `DataType` to a `Codec`. +fn arrow_type_to_codec(dt: &DataType) -> Codec { + match dt { + Null => Codec::Null, + Boolean => Codec::Boolean, + Int8 | Int16 | Int32 => Codec::Int32, + Int64 => Codec::Int64, + Float32 => Codec::Float32, + Float64 => Codec::Float64, + Utf8 => Codec::Utf8, + Binary | LargeBinary => Codec::Binary, + Date32 => Codec::Date32, + Time32(TimeUnit::Millisecond) => Codec::TimeMillis, + Time64(TimeUnit::Microsecond) => Codec::TimeMicros, + Timestamp(TimeUnit::Millisecond, None) => Codec::TimestampMillis(false), + Timestamp(TimeUnit::Microsecond, None) => Codec::TimestampMicros(false), + Timestamp(TimeUnit::Millisecond, Some(tz)) if tz.as_ref() == "UTC" => { + Codec::TimestampMillis(true) + } + Timestamp(TimeUnit::Microsecond, Some(tz)) if tz.as_ref() == "UTC" => { + Codec::TimestampMicros(true) + } + FixedSizeBinary(n) => Codec::Fixed(*n), + Decimal128(prec, scale) => Codec::Decimal(*prec as usize, Some(*scale as usize), Some(16)), + Decimal256(prec, scale) => Codec::Decimal(*prec as usize, Some(*scale as usize), Some(32)), + Dictionary(index_type, value_type) => { + if let Utf8 = **value_type { + Codec::Enum(vec![]) + } else { + // Fallback to Utf8 + Codec::Utf8 + } + } + Map(field, _keys_sorted) => { + if let Struct(child_fields) = field.data_type() { + let value_field = &child_fields[1]; + let sub_codec = arrow_type_to_codec(value_field.data_type()); + Codec::Map(Arc::new(AvroDataType { + nullability: value_field.is_nullable().then_some(Nullability::NullFirst), + metadata: value_field.metadata().clone(), + codec: sub_codec, + })) + } else { + Codec::Map(Arc::new(AvroDataType::from_codec(Codec::Utf8))) + } + } + Struct(child_fields) => { + let avro_fields: Vec = child_fields + .iter() + .map(|f_ref| arrow_field_to_avro_field(f_ref.as_ref())) + .collect(); + Codec::Struct(Arc::from(avro_fields)) + } + _ => Codec::Utf8, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_schema::Field; + use serde_json::json; + use std::sync::Arc; + + #[test] + fn test_decimal256_tuple_variant_fixed() { + let c = arrow_type_to_codec(&Decimal256(60, 3)); + match c { + Codec::Decimal(p, s, Some(32)) => { + assert_eq!(p, 60); + assert_eq!(s, Some(3)); + } + _ => panic!("Expected decimal(60,3,Some(32))"), + } + let avro_dt = AvroDataType::from_codec(c); + let avro_schema = avro_dt.to_avro_schema("FixedDec"); + let j = serde_json::to_value(&avro_schema).unwrap(); + let expected = json!({ + "type": "fixed", + "name": "FixedDec", + "aliases": [], + "size": 32, + "logicalType": "decimal", + "precision": 60, + "scale": 3 + }); + assert_eq!(j, expected); + } + + #[test] + fn test_decimal128_tuple_variant_fixed() { + let c = Codec::Decimal(6, Some(2), Some(4)); + let dt = c.data_type(); + match dt { + Decimal128(p, s) => { + assert_eq!(p, 6); + assert_eq!(s, 2); + } + _ => panic!("Expected decimal(6,2) arrow type"), + } + let avro_dt = AvroDataType::from_codec(c); + let schema = avro_dt.to_avro_schema("FixedDec"); + let j = serde_json::to_value(&schema).unwrap(); + let expected = json!({ + "type": "fixed", + "name": "FixedDec", + "aliases": [], + "size": 4, + "logicalType": "decimal", + "precision": 6, + "scale": 2, + }); + assert_eq!(j, expected); + } + + #[test] + fn test_decimal_size_decision() { + let codec = Codec::Decimal(10, Some(3), Some(16)); + let dt = codec.data_type(); + match dt { + Decimal128(precision, scale) => { + assert_eq!(precision, 10); + assert_eq!(scale, 3); + } + _ => panic!("Expected Decimal128"), + } + let codec = Codec::Decimal(18, Some(4), Some(32)); + let dt = codec.data_type(); + match dt { + Decimal256(precision, scale) => { + assert_eq!(precision, 18); + assert_eq!(scale, 4); + } + _ => panic!("Expected Decimal256"), + } + let codec = Codec::Decimal(8, Some(2), None); + let dt = codec.data_type(); + match dt { + Decimal128(precision, scale) => { + assert_eq!(precision, 8); + assert_eq!(scale, 2); + } + _ => panic!("Expected Decimal128"), + } + } + + #[test] + fn test_avro_data_type_new_and_from_codec() { + let dt1 = AvroDataType::new( + Codec::Int32, + Some(Nullability::NullFirst), + HashMap::from([("namespace".into(), "my.ns".into())]), + ); + let actual_str = format!("{:?}", dt1.nullability()); + let expected_str = format!("{:?}", Some(Nullability::NullFirst)); + assert_eq!(actual_str, expected_str); + let actual_str2 = format!("{:?}", dt1.codec()); + let expected_str2 = format!("{:?}", &Codec::Int32); + assert_eq!(actual_str2, expected_str2); + assert_eq!(dt1.metadata.get("namespace"), Some(&"my.ns".to_string())); + let dt2 = AvroDataType::from_codec(Codec::Float64); + let actual_str4 = format!("{:?}", dt2.codec()); + let expected_str4 = format!("{:?}", &Codec::Float64); + assert_eq!(actual_str4, expected_str4); + assert!(dt2.metadata.is_empty()); + } + + #[test] + fn test_avro_data_type_field_with_name() { + let dt = AvroDataType::new( + Codec::Binary, + None, + HashMap::from([("something".into(), "else".into())]), + ); + let f = dt.field_with_name("bin_col"); + assert_eq!(f.name(), "bin_col"); + assert_eq!(f.data_type(), &Binary); + assert!(!f.is_nullable()); + assert_eq!(f.metadata().get("something"), Some(&"else".to_string())); + } + + #[test] + fn test_avro_data_type_to_avro_schema_with_namespace_record() { + let mut meta = HashMap::new(); + meta.insert("namespace".to_string(), "com.example".to_string()); + let fields = Arc::from(vec![ + AvroField { + name: "id".to_string(), + data_type: AvroDataType::from_codec(Codec::Int32), + }, + AvroField { + name: "label".to_string(), + data_type: AvroDataType::new( + Codec::Utf8, + Some(Nullability::NullFirst), + Default::default(), + ), + }, + ]); + let top_level = AvroDataType::new(Codec::Struct(fields), None, meta); + let avro_schema = top_level.to_avro_schema("TopRecord"); + let json_val = serde_json::to_value(&avro_schema).unwrap(); + let expected = json!({ + "type": "record", + "name": "TopRecord", + "namespace": "com.example", + "doc": null, + "logicalType": null, + "aliases": [], + "fields": [ + { "name": "id", "doc": null, "type": "int" }, + { "name": "label", "doc": null, "type": ["null","string"] } + ], + }); + assert_eq!(json_val, expected); + } + + #[test] + fn test_avro_data_type_to_avro_schema_with_namespace_enum() { + let mut meta = HashMap::new(); + meta.insert("namespace".to_string(), "com.example.enum".to_string()); + + let enum_dt = AvroDataType::new( + Codec::Enum(vec!["A".to_string(), "B".to_string(), "C".to_string()]), + None, + meta, + ); + let avro_schema = enum_dt.to_avro_schema("MyEnum"); + let json_val = serde_json::to_value(&avro_schema).unwrap(); + let expected = json!({ + "type": "enum", + "name": "MyEnum", + "logicalType": null, + "namespace": "com.example.enum", + "doc": null, + "aliases": [], + "symbols": ["A","B","C"] + }); + assert_eq!(json_val, expected); + } + + #[test] + fn test_avro_data_type_to_avro_schema_with_namespace_fixed() { + let mut meta = HashMap::new(); + meta.insert("namespace".to_string(), "com.example.fixed".to_string()); + let fixed_dt = AvroDataType::new(Codec::Fixed(8), None, meta); + let avro_schema = fixed_dt.to_avro_schema("MyFixed"); + let json_val = serde_json::to_value(&avro_schema).unwrap(); + let expected = json!({ + "type": "fixed", + "name": "MyFixed", + "logicalType": null, + "namespace": "com.example.fixed", + "aliases": [], + "size": 8 + }); + assert_eq!(json_val, expected); + } + + #[test] + fn test_avro_field() { + let field_codec = AvroDataType::from_codec(Codec::Int64); + let avro_field = AvroField { + name: "long_col".to_string(), + data_type: field_codec.clone(), + }; + assert_eq!(avro_field.name(), "long_col"); + let actual_str = format!("{:?}", avro_field.data_type().codec()); + let expected_str = format!("{:?}", &Codec::Int64); + assert_eq!(actual_str, expected_str, "Codec debug output mismatch"); + let arrow_field = avro_field.field(); + assert_eq!(arrow_field.name(), "long_col"); + assert_eq!(arrow_field.data_type(), &Int64); + assert!(!arrow_field.is_nullable()); + } + + #[test] + fn test_arrow_field_to_avro_field() { + let arrow_field = Field::new("test_meta", Utf8, true).with_metadata(HashMap::from([( + "namespace".to_string(), + "arrow_meta_ns".to_string(), + )])); + let avro_field = arrow_field_to_avro_field(&arrow_field); + assert_eq!(avro_field.name(), "test_meta"); + let actual_str = format!("{:?}", avro_field.data_type().codec()); + let expected_str = format!("{:?}", &Codec::Utf8); + assert_eq!(actual_str, expected_str); + let actual_str = format!("{:?}", avro_field.data_type().nullability()); + let expected_str = format!("{:?}", Some(Nullability::NullFirst)); + assert_eq!(actual_str, expected_str); + assert_eq!( + avro_field.data_type().metadata.get("namespace"), + Some(&"arrow_meta_ns".to_string()) + ); + } + + #[test] + fn test_codec_struct() { + let fields = Arc::from(vec![ + AvroField { + name: "a".to_string(), + data_type: AvroDataType::from_codec(Codec::Boolean), + }, + AvroField { + name: "b".to_string(), + data_type: AvroDataType::from_codec(Codec::Float64), + }, + ]); + let codec = Codec::Struct(fields); + let dt = codec.data_type(); + match dt { + Struct(fields) => { + assert_eq!(fields.len(), 2); + assert_eq!(fields[0].name(), "a"); + assert_eq!(fields[0].data_type(), &Boolean); + assert_eq!(fields[1].name(), "b"); + assert_eq!(fields[1].data_type(), &Float64); + } + _ => panic!("Expected Struct data type"), + } + } + + #[test] + fn test_codec_fixedsizebinary() { + let codec = Codec::Fixed(12); + let dt = codec.data_type(); + match dt { + FixedSizeBinary(n) => assert_eq!(n, 12), + _ => panic!("Expected FixedSizeBinary(12)"), + } + } + + #[test] + fn test_utc_timestamp_millis() { + let arrow_field = Field::new( + "utc_ts_ms", + Timestamp(TimeUnit::Millisecond, Some(Arc::from("UTC"))), + false, + ); + let avro_field = arrow_field_to_avro_field(&arrow_field); + let codec = avro_field.data_type().codec(); + assert!( + matches!(codec, Codec::TimestampMillis(true)), + "Expected Codec::TimestampMillis(true), got: {:?}", + codec + ); + } + + #[test] + fn test_utc_timestamp_micros() { + let arrow_field = Field::new( + "utc_ts_us", + Timestamp(TimeUnit::Microsecond, Some(Arc::from("UTC"))), + false, + ); + let avro_field = arrow_field_to_avro_field(&arrow_field); + let codec = avro_field.data_type().codec(); + assert!( + matches!(codec, Codec::TimestampMicros(true)), + "Expected Codec::TimestampMicros(true), got: {:?}", + codec + ); + } + + #[test] + fn test_local_timestamp_millis() { + let arrow_field = Field::new("local_ts_ms", Timestamp(TimeUnit::Millisecond, None), false); + let avro_field = arrow_field_to_avro_field(&arrow_field); + let codec = avro_field.data_type().codec(); + assert!( + matches!(codec, Codec::TimestampMillis(false)), + "Expected Codec::TimestampMillis(false), got: {:?}", + codec + ); + } + + #[test] + fn test_local_timestamp_micros() { + let arrow_field = Field::new("local_ts_us", Timestamp(TimeUnit::Microsecond, None), false); + let avro_field = arrow_field_to_avro_field(&arrow_field); + let codec = avro_field.data_type().codec(); + assert!( + matches!(codec, Codec::TimestampMicros(false)), + "Expected Codec::TimestampMicros(false), got: {:?}", + codec + ); + } +} diff --git a/arrow-avro/src/lib.rs b/arrow-avro/src/lib.rs index d01d681b7af0..ef3bd082d0e8 100644 --- a/arrow-avro/src/lib.rs +++ b/arrow-avro/src/lib.rs @@ -29,6 +29,7 @@ mod schema; mod compression; mod codec; +mod writer; #[cfg(test)] mod test_util { diff --git a/arrow-avro/src/reader/cursor.rs b/arrow-avro/src/reader/cursor.rs index 4b6a5a4d65db..9e38a78c63ec 100644 --- a/arrow-avro/src/reader/cursor.rs +++ b/arrow-avro/src/reader/cursor.rs @@ -14,7 +14,6 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. - use crate::reader::vlq::read_varint; use arrow_schema::ArrowError; @@ -65,6 +64,7 @@ impl<'a> AvroCursor<'a> { Ok(val) } + /// Decode a zig-zag encoded Avro int (32-bit). #[inline] pub(crate) fn get_int(&mut self) -> Result { let varint = self.read_vlq()?; @@ -74,18 +74,20 @@ impl<'a> AvroCursor<'a> { Ok((val >> 1) as i32 ^ -((val & 1) as i32)) } + /// Decode a zig-zag encoded Avro long (64-bit). #[inline] pub(crate) fn get_long(&mut self) -> Result { let val = self.read_vlq()?; Ok((val >> 1) as i64 ^ -((val & 1) as i64)) } + /// Read a variable-length byte array from Avro (where the length is stored as an Avro long). pub(crate) fn get_bytes(&mut self) -> Result<&'a [u8], ArrowError> { let len: usize = self.get_long()?.try_into().map_err(|_| { ArrowError::ParseError("offset overflow reading avro bytes".to_string()) })?; - if (self.buf.len() < len) { + if self.buf.len() < len { return Err(ArrowError::ParseError( "Unexpected EOF reading bytes".to_string(), )); @@ -95,9 +97,10 @@ impl<'a> AvroCursor<'a> { Ok(ret) } + /// Read a little-endian 32-bit float #[inline] pub(crate) fn get_float(&mut self) -> Result { - if (self.buf.len() < 4) { + if self.buf.len() < 4 { return Err(ArrowError::ParseError( "Unexpected EOF reading float".to_string(), )); @@ -107,15 +110,28 @@ impl<'a> AvroCursor<'a> { Ok(ret) } + /// Read a little-endian 64-bit float #[inline] pub(crate) fn get_double(&mut self) -> Result { - if (self.buf.len() < 8) { + if self.buf.len() < 8 { return Err(ArrowError::ParseError( - "Unexpected EOF reading float".to_string(), + "Unexpected EOF reading double".to_string(), )); } let ret = f64::from_le_bytes(self.buf[..8].try_into().unwrap()); self.buf = &self.buf[8..]; Ok(ret) } + + /// Read exactly `n` bytes from the buffer (e.g. for Avro `fixed`). + pub(crate) fn get_fixed(&mut self, n: usize) -> Result<&'a [u8], ArrowError> { + if self.buf.len() < n { + return Err(ArrowError::ParseError( + "Unexpected EOF reading fixed".to_string(), + )); + } + let ret = &self.buf[..n]; + self.buf = &self.buf[n..]; + Ok(ret) + } } diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 52a58cf63303..6fe4ae87bef3 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -16,44 +16,50 @@ // under the License. use crate::codec::{AvroDataType, Codec, Nullability}; -use crate::reader::block::{Block, BlockDecoder}; use crate::reader::cursor::AvroCursor; -use crate::reader::header::Header; -use crate::schema::*; +use arrow_array::builder::{Decimal128Builder, Decimal256Builder, PrimitiveBuilder}; use arrow_array::types::*; use arrow_array::*; use arrow_buffer::*; use arrow_schema::{ - ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as ArrowSchema, SchemaRef, + ArrowError, DataType, Field as ArrowField, FieldRef, Fields, IntervalUnit, + Schema as ArrowSchema, SchemaRef, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, }; -use std::collections::HashMap; use std::io::Read; use std::sync::Arc; -/// Decodes avro encoded data into [`RecordBatch`] +/// The default capacity used for internal buffers +const DEFAULT_CAPACITY: usize = 1024; + +/// A decoder that converts Avro-encoded data into an Arrow [`RecordBatch`]. pub struct RecordDecoder { schema: SchemaRef, fields: Vec, } impl RecordDecoder { + /// Create a new [`RecordDecoder`] from an [`AvroDataType`] expected to be a `Record`. pub fn try_new(data_type: &AvroDataType) -> Result { match Decoder::try_new(data_type)? { Decoder::Record(fields, encodings) => Ok(Self { schema: Arc::new(ArrowSchema::new(fields)), fields: encodings, }), - encoding => Err(ArrowError::ParseError(format!( - "Expected record got {encoding:?}" + other => Err(ArrowError::ParseError(format!( + "Expected record got {other:?}" ))), } } + /// Return the [`SchemaRef`] describing the Arrow schema of rows produced by this decoder. pub fn schema(&self) -> &SchemaRef { &self.schema } - /// Decode `count` records from `buf` + /// Decode `count` Avro records from `buf`. + /// + /// This accumulates data in internal buffers. Once done reading, call + /// [`Self::flush`] to yield an Arrow [`RecordBatch`]. pub fn decode(&mut self, buf: &[u8], count: usize) -> Result { let mut cursor = AvroCursor::new(buf); for _ in 0..count { @@ -64,7 +70,7 @@ impl RecordDecoder { Ok(cursor.position()) } - /// Flush the decoded records into a [`RecordBatch`] + /// Flush the accumulated data into a [`RecordBatch`], clearing internal state. pub fn flush(&mut self) -> Result { let arrays = self .fields @@ -76,30 +82,68 @@ impl RecordDecoder { } } +/// Decoder for Avro data of various shapes. #[derive(Debug)] enum Decoder { + /// Avro `null` Null(usize), + /// Avro `boolean` Boolean(BooleanBufferBuilder), + /// Avro `int` => i32 Int32(Vec), + /// Avro `long` => i64 Int64(Vec), + /// Avro `float` => f32 Float32(Vec), + /// Avro `double` => f64 Float64(Vec), + /// Avro `date` => Date32 Date32(Vec), + /// Avro `time-millis` => Time32(Millisecond) TimeMillis(Vec), + /// Avro `time-micros` => Time64(Microsecond) TimeMicros(Vec), + /// Avro `timestamp-millis` (bool = UTC?) TimestampMillis(bool, Vec), + /// Avro `timestamp-micros` (bool = UTC?) TimestampMicros(bool, Vec), + /// Avro `bytes` => Arrow Binary Binary(OffsetBufferBuilder, Vec), + /// Avro `string` => Arrow String String(OffsetBufferBuilder, Vec), + /// Avro `fixed(n)` => Arrow `FixedSizeBinaryArray` + Fixed(i32, Vec), + /// Avro `interval` => Arrow `IntervalMonthDayNanoType` (12 bytes) + Interval(Vec), + /// Avro `array` List(FieldRef, OffsetBufferBuilder, Box), + /// Avro `record` Record(Fields, Vec), + /// Avro union that includes `null` Nullable(Nullability, NullBufferBuilder, Box), + /// Avro `enum` => Dictionary(int32 -> string) + Enum(Vec, Vec), + /// Avro `map` + Map( + FieldRef, + OffsetBufferBuilder, + OffsetBufferBuilder, + Vec, + Box, + usize, + ), + /// Avro decimal => Arrow decimal + Decimal(usize, Option, Option, DecimalBuilder), } impl Decoder { - fn try_new(data_type: &AvroDataType) -> Result { - let nyi = |s: &str| Err(ArrowError::NotYetImplemented(s.to_string())); + /// Checks if the Decoder is nullable, i.e. wrapped in `Nullable`. + fn is_nullable(&self) -> bool { + matches!(self, Self::Nullable(_, _, _)) + } + /// Create a `Decoder` from an [`AvroDataType`]. + fn try_new(data_type: &AvroDataType) -> Result { let decoder = match data_type.codec() { Codec::Null => Self::Null(0), Codec::Boolean => Self::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)), @@ -124,163 +168,418 @@ impl Decoder { Codec::TimestampMicros(is_utc) => { Self::TimestampMicros(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) } - Codec::Fixed(_) => return nyi("decoding fixed"), - Codec::Interval => return nyi("decoding interval"), + Codec::Fixed(n) => Self::Fixed(*n, Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Interval => Self::Interval(Vec::with_capacity(DEFAULT_CAPACITY)), Codec::List(item) => { - let decoder = Self::try_new(item)?; + let item_decoder = Box::new(Self::try_new(item)?); Self::List( Arc::new(item.field_with_name("item")), OffsetBufferBuilder::new(DEFAULT_CAPACITY), - Box::new(decoder), + item_decoder, ) } - Codec::Struct(fields) => { - let mut arrow_fields = Vec::with_capacity(fields.len()); - let mut encodings = Vec::with_capacity(fields.len()); - for avro_field in fields.iter() { - let encoding = Self::try_new(avro_field.data_type())?; + Codec::Struct(avro_fields) => { + let mut arrow_fields = Vec::with_capacity(avro_fields.len()); + let mut decoders = Vec::with_capacity(avro_fields.len()); + for avro_field in avro_fields.iter() { + let d = Self::try_new(avro_field.data_type())?; arrow_fields.push(avro_field.field()); - encodings.push(encoding); + decoders.push(d); } - Self::Record(arrow_fields.into(), encodings) + Self::Record(arrow_fields.into(), decoders) + } + Codec::Enum(symbols) => { + Self::Enum(symbols.clone(), Vec::with_capacity(DEFAULT_CAPACITY)) + } + Codec::Map(value_type) => { + let map_field = Arc::new(ArrowField::new( + "entries", + DataType::Struct(Fields::from(vec![ + Arc::new(ArrowField::new("key", DataType::Utf8, false)), + Arc::new(value_type.field_with_name("value")), + ])), + false, + )); + Self::Map( + map_field, + OffsetBufferBuilder::new(DEFAULT_CAPACITY), + OffsetBufferBuilder::new(DEFAULT_CAPACITY), + Vec::with_capacity(DEFAULT_CAPACITY), + Box::new(Self::try_new(value_type)?), + 0, + ) + } + Codec::Decimal(precision, scale, size) => { + let builder = DecimalBuilder::new(*precision, *scale, *size)?; + Self::Decimal(*precision, *scale, *size, builder) } }; - Ok(match data_type.nullability() { - Some(nullability) => Self::Nullable( - nullability, + // Wrap in Nullable if needed + match data_type.nullability() { + Some(nb) => Ok(Self::Nullable( + nb, NullBufferBuilder::new(DEFAULT_CAPACITY), Box::new(decoder), - ), - None => decoder, - }) + )), + None => Ok(decoder), + } } - /// Append a null record + /// Append a null to this decoder. fn append_null(&mut self) { match self { - Self::Null(count) => *count += 1, + Self::Null(n) => *n += 1, Self::Boolean(b) => b.append(false), Self::Int32(v) | Self::Date32(v) | Self::TimeMillis(v) => v.push(0), Self::Int64(v) | Self::TimeMicros(v) | Self::TimestampMillis(_, v) | Self::TimestampMicros(_, v) => v.push(0), - Self::Float32(v) => v.push(0.), - Self::Float64(v) => v.push(0.), - Self::Binary(offsets, _) | Self::String(offsets, _) => offsets.push_length(0), - Self::List(_, offsets, e) => { - offsets.push_length(0); - e.append_null(); - } - Self::Record(_, e) => e.iter_mut().for_each(|e| e.append_null()), - Self::Nullable(_, _, _) => unreachable!("Nulls cannot be nested"), + Self::Float32(v) => v.push(0.0), + Self::Float64(v) => v.push(0.0), + Self::Binary(off, _) | Self::String(off, _) => off.push_length(0), + Self::Fixed(fsize, buf) => { + // For a null, push `fsize` zeroed bytes + buf.extend(std::iter::repeat(0u8).take(*fsize as usize)); + } + Self::Interval(intervals) => { + // null => store a 12-byte zero => months=0, days=0, nanos=0 + intervals.push(IntervalMonthDayNano { + months: 0, + days: 0, + nanoseconds: 0, + }); + } + Self::List(_, off, child) => { + off.push_length(0); + child.append_null(); + } + Self::Record(_, children) => { + for c in children.iter_mut() { + c.append_null(); + } + } + Self::Enum(_, indices) => indices.push(0), + Self::Map(_, key_off, map_off, _, _, entry_count) => { + key_off.push_length(0); + map_off.push_length(*entry_count); + } + Self::Decimal(_, _, _, builder) => { + let _ = builder.append_null(); + } + Self::Nullable(_, _, _) => { /* The null bit is stored in the NullBufferBuilder */ } } } - /// Decode a single record from `buf` + /// Decode a single row of data from `buf`. fn decode(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> { match self { - Self::Null(x) => *x += 1, + Self::Null(count) => *count += 1, Self::Boolean(values) => values.append(buf.get_bool()?), - Self::Int32(values) | Self::Date32(values) | Self::TimeMillis(values) => { - values.push(buf.get_int()?) - } - Self::Int64(values) - | Self::TimeMicros(values) - | Self::TimestampMillis(_, values) - | Self::TimestampMicros(_, values) => values.push(buf.get_long()?), + Self::Int32(values) => values.push(buf.get_int()?), + Self::Date32(values) => values.push(buf.get_int()?), + Self::Int64(values) => values.push(buf.get_long()?), + Self::TimeMillis(values) => values.push(buf.get_int()?), + Self::TimeMicros(values) => values.push(buf.get_long()?), + Self::TimestampMillis(_, values) => values.push(buf.get_long()?), + Self::TimestampMicros(_, values) => values.push(buf.get_long()?), Self::Float32(values) => values.push(buf.get_float()?), Self::Float64(values) => values.push(buf.get_double()?), - Self::Binary(offsets, values) | Self::String(offsets, values) => { - let data = buf.get_bytes()?; - offsets.push_length(data.len()); - values.extend_from_slice(data); - } - Self::List(_, _, _) => { - return Err(ArrowError::NotYetImplemented( - "Decoding ListArray".to_string(), - )) - } - Self::Record(_, encodings) => { - for encoding in encodings { - encoding.decode(buf)?; + Self::Binary(off, data) | Self::String(off, data) => { + let bytes = buf.get_bytes()?; + off.push_length(bytes.len()); + data.extend_from_slice(bytes); + } + Self::Fixed(fsize, accum) => accum.extend_from_slice(buf.get_fixed(*fsize as usize)?), + Self::Interval(intervals) => { + let raw = buf.get_fixed(12)?; + let months = i32::from_le_bytes(raw[0..4].try_into().unwrap()); + let days = i32::from_le_bytes(raw[4..8].try_into().unwrap()); + let millis = i32::from_le_bytes(raw[8..12].try_into().unwrap()); + let nanos = millis as i64 * 1_000_000; + let val = IntervalMonthDayNano { + months, + days, + nanoseconds: nanos, + }; + intervals.push(val); + } + Self::List(_, off, child) => { + let total_items = read_array_blocks(buf, |b| child.decode(b))?; + off.push_length(total_items); + } + Self::Record(_, children) => { + for c in children.iter_mut() { + c.decode(buf)?; } } - Self::Nullable(nullability, nulls, e) => { - let is_valid = buf.get_bool()? == matches!(nullability, Nullability::NullFirst); - nulls.append(is_valid); - match is_valid { - true => e.decode(buf)?, - false => e.append_null(), + Self::Nullable(_, nulls, child) => match buf.get_int()? { + 0 => { + nulls.append(true); + child.decode(buf)?; + } + 1 => { + nulls.append(false); + child.append_null(); + } + other => { + return Err(ArrowError::ParseError(format!( + "Unsupported union branch index {other} for Nullable" + ))); } + }, + Self::Enum(_, indices) => indices.push(buf.get_int()?), + Self::Map(_, key_off, map_off, key_data, val_decoder, entry_count) => { + let newly_added = read_map_blocks(buf, |b| { + let kb = b.get_bytes()?; + key_off.push_length(kb.len()); + key_data.extend_from_slice(kb); + val_decoder.decode(b) + })?; + *entry_count += newly_added; + map_off.push_length(*entry_count); + } + Self::Decimal(_, _, size, builder) => { + let bytes = match *size { + Some(sz) => buf.get_fixed(sz)?, + None => buf.get_bytes()?, + }; + builder.append_bytes(bytes)?; } } Ok(()) } - /// Flush decoded records to an [`ArrayRef`] + /// Flush buffered data into an [`ArrayRef`], optionally applying `nulls`. fn flush(&mut self, nulls: Option) -> Result { - Ok(match self { - Self::Nullable(_, n, e) => e.flush(n.finish())?, - Self::Null(size) => Arc::new(NullArray::new(std::mem::replace(size, 0))), - Self::Boolean(b) => Arc::new(BooleanArray::new(b.finish(), nulls)), - Self::Int32(values) => Arc::new(flush_primitive::(values, nulls)), - Self::Date32(values) => Arc::new(flush_primitive::(values, nulls)), - Self::Int64(values) => Arc::new(flush_primitive::(values, nulls)), - Self::TimeMillis(values) => { - Arc::new(flush_primitive::(values, nulls)) - } - Self::TimeMicros(values) => { - Arc::new(flush_primitive::(values, nulls)) - } - Self::TimestampMillis(is_utc, values) => Arc::new( - flush_primitive::(values, nulls) - .with_timezone_opt(is_utc.then(|| "+00:00")), - ), - Self::TimestampMicros(is_utc, values) => Arc::new( - flush_primitive::(values, nulls) - .with_timezone_opt(is_utc.then(|| "+00:00")), - ), - Self::Float32(values) => Arc::new(flush_primitive::(values, nulls)), - Self::Float64(values) => Arc::new(flush_primitive::(values, nulls)), - - Self::Binary(offsets, values) => { - let offsets = flush_offsets(offsets); - let values = flush_values(values).into(); - Arc::new(BinaryArray::new(offsets, values, nulls)) + match self { + // For a nullable wrapper => flush the child with the built null buffer + Self::Nullable(_, nb, child) => { + let mask = nb.finish(); + child.flush(mask) + } + // Null => produce NullArray + Self::Null(len) => { + let count = std::mem::replace(len, 0); + Ok(Arc::new(NullArray::new(count))) + } + // boolean => flush to BooleanArray + Self::Boolean(b) => { + let bits = b.finish(); + Ok(Arc::new(BooleanArray::new(bits, nulls))) + } + // int32 => flush to Int32Array + Self::Int32(vals) => { + let arr = flush_primitive::(vals, nulls); + Ok(Arc::new(arr)) + } + // date32 => flush to Date32Array + Self::Date32(vals) => { + let arr = flush_primitive::(vals, nulls); + Ok(Arc::new(arr)) + } + // int64 => flush to Int64Array + Self::Int64(vals) => { + let arr = flush_primitive::(vals, nulls); + Ok(Arc::new(arr)) + } + // time-millis => Time32Millisecond + Self::TimeMillis(vals) => { + let arr = flush_primitive::(vals, nulls); + Ok(Arc::new(arr)) + } + // time-micros => Time64Microsecond + Self::TimeMicros(vals) => { + let arr = flush_primitive::(vals, nulls); + Ok(Arc::new(arr)) } - Self::String(offsets, values) => { - let offsets = flush_offsets(offsets); - let values = flush_values(values).into(); - Arc::new(StringArray::new(offsets, values, nulls)) + // timestamp-millis => TimestampMillisecond + Self::TimestampMillis(is_utc, vals) => { + let arr = flush_primitive::(vals, nulls) + .with_timezone_opt::>(is_utc.then(|| "+00:00".into())); + Ok(Arc::new(arr)) } - Self::List(field, offsets, values) => { - let values = values.flush(None)?; - let offsets = flush_offsets(offsets); - Arc::new(ListArray::new(field.clone(), offsets, values, nulls)) + // timestamp-micros => TimestampMicrosecond + Self::TimestampMicros(is_utc, vals) => { + let arr = flush_primitive::(vals, nulls) + .with_timezone_opt::>(is_utc.then(|| "+00:00".into())); + Ok(Arc::new(arr)) } - Self::Record(fields, encodings) => { - let arrays = encodings - .iter_mut() - .map(|x| x.flush(None)) - .collect::, _>>()?; - Arc::new(StructArray::new(fields.clone(), arrays, nulls)) + // float32 => flush to Float32Array + Self::Float32(vals) => { + let arr = flush_primitive::(vals, nulls); + Ok(Arc::new(arr)) } - }) + // float64 => flush to Float64Array + Self::Float64(vals) => { + let arr = flush_primitive::(vals, nulls); + Ok(Arc::new(arr)) + } + // Avro bytes => BinaryArray + Self::Binary(off, data) => { + let offsets = flush_offsets(off); + let values = flush_values(data).into(); + Ok(Arc::new(BinaryArray::new(offsets, values, nulls))) + } + // Avro string => StringArray + Self::String(off, data) => { + let offsets = flush_offsets(off); + let values = flush_values(data).into(); + Ok(Arc::new(StringArray::new(offsets, values, nulls))) + } + // Avro fixed => FixedSizeBinaryArray + Self::Fixed(fsize, raw) => { + let size = *fsize; + let buf: Buffer = flush_values(raw).into(); + let total_len = buf.len() / (size as usize); + let array = FixedSizeBinaryArray::try_new(size, buf, nulls) + .map_err(|e| ArrowError::ParseError(e.to_string()))?; + Ok(Arc::new(array)) + } + // Avro interval => IntervalMonthDayNanoType + Self::Interval(vals) => { + let data_len = vals.len(); + let mut builder = + PrimitiveBuilder::::with_capacity(data_len); + for v in vals.drain(..) { + builder.append_value(v); + } + let arr = builder + .finish() + .with_data_type(DataType::Interval(IntervalUnit::MonthDayNano)); + if let Some(nb) = nulls { + // "merge" the newly built array with the nulls + let arr_data = arr.into_data().into_builder().nulls(Some(nb)); + let arr_data = unsafe { arr_data.build_unchecked() }; + Ok(Arc::new(PrimitiveArray::::from( + arr_data, + ))) + } else { + Ok(Arc::new(arr)) + } + } + // Avro array => ListArray + Self::List(field, off, item_dec) => { + let child_arr = item_dec.flush(None)?; + let offsets = flush_offsets(off); + let arr = ListArray::new(field.clone(), offsets, child_arr, nulls); + Ok(Arc::new(arr)) + } + // Avro record => StructArray + Self::Record(fields, children) => { + let mut arrays = Vec::with_capacity(children.len()); + for c in children.iter_mut() { + let a = c.flush(None)?; + arrays.push(a); + } + Ok(Arc::new(StructArray::new(fields.clone(), arrays, nulls))) + } + // Avro enum => DictionaryArray utf8> + Self::Enum(symbols, indices) => { + let dict_values = StringArray::from_iter_values(symbols.iter()); + let idxs: Int32Array = match nulls { + Some(b) => { + let buff = Buffer::from_slice_ref(&indices); + PrimitiveArray::::try_new( + arrow_buffer::ScalarBuffer::from(buff), + Some(b), + )? + } + None => Int32Array::from_iter_values(indices.iter().cloned()), + }; + let dict = DictionaryArray::::try_new(idxs, Arc::new(dict_values))?; + indices.clear(); // reset + Ok(Arc::new(dict)) + } + // Avro map => MapArray + Self::Map(field, key_off, map_off, key_data, val_dec, entry_count) => { + let moff = flush_offsets(map_off); + let koff = flush_offsets(key_off); + let kd = flush_values(key_data).into(); + let val_arr = val_dec.flush(None)?; + let is_nullable = matches!(**val_dec, Self::Nullable(_, _, _)); + let key_arr = StringArray::new(koff, kd, None); + let struct_fields = vec![ + Arc::new(ArrowField::new("key", DataType::Utf8, false)), + Arc::new(ArrowField::new( + "value", + val_arr.data_type().clone(), + is_nullable, + )), + ]; + let entries = StructArray::new( + Fields::from(struct_fields), + vec![Arc::new(key_arr), val_arr], + None, + ); + let map_arr = MapArray::new(field.clone(), moff, entries, nulls, false); + *entry_count = 0; + Ok(Arc::new(map_arr)) + } + // Avro decimal => Arrow decimal + Self::Decimal(prec, sc, sz, builder) => { + let precision = *prec; + let scale = sc.unwrap_or(0); + let new_builder = DecimalBuilder::new(precision, *sc, *sz)?; + let old_builder = std::mem::replace(builder, new_builder); + let arr = old_builder.finish(nulls, precision, scale)?; + Ok(arr) + } + } } } -#[inline] -fn flush_values(values: &mut Vec) -> Vec { - std::mem::replace(values, Vec::with_capacity(DEFAULT_CAPACITY)) +/// Decode an Avro array in blocks until a 0 block_count signals end. +fn read_array_blocks( + buf: &mut AvroCursor, + mut decode_item: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, +) -> Result { + let mut total_items = 0usize; + loop { + let block_count = buf.get_long()?; + match block_count { + 0 => break, // If block_count is 0, exit the loop + n if n < 0 => { + // If block_count is negative + let item_count = (-n) as usize; + let _block_size = buf.get_long()?; // Read but ignore block size + for _ in 0..item_count { + decode_item(buf)?; + } + total_items += item_count; + } + n => { + // If block_count is positive + let item_count = n as usize; + for _ in 0..item_count { + decode_item(buf)?; + } + total_items += item_count; + } + } + } + Ok(total_items) } -#[inline] -fn flush_offsets(offsets: &mut OffsetBufferBuilder) -> OffsetBuffer { - std::mem::replace(offsets, OffsetBufferBuilder::new(DEFAULT_CAPACITY)).finish() +/// Decode an Avro map in blocks until 0 block_count => end. +fn read_map_blocks( + buf: &mut AvroCursor, + mut decode_entry: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, +) -> Result { + let block_count = buf.get_long()?; + if block_count <= 0 { + Ok(0) + } else { + let n = block_count as usize; + for _ in 0..n { + decode_entry(buf)?; + } + Ok(n) + } } +/// Flush a [`Vec`] of primitive values to a [`PrimitiveArray`], applying optional `nulls`. #[inline] fn flush_primitive( values: &mut Vec, @@ -289,4 +588,677 @@ fn flush_primitive( PrimitiveArray::new(flush_values(values).into(), nulls) } -const DEFAULT_CAPACITY: usize = 1024; +/// Flush an [`OffsetBufferBuilder`]. +#[inline] +fn flush_offsets(offsets: &mut OffsetBufferBuilder) -> OffsetBuffer { + std::mem::replace(offsets, OffsetBufferBuilder::new(DEFAULT_CAPACITY)).finish() +} + +/// Take ownership of `values`. +#[inline] +fn flush_values(values: &mut Vec) -> Vec { + std::mem::replace(values, Vec::with_capacity(DEFAULT_CAPACITY)) +} + +/// A builder for Avro decimal, either 128-bit or 256-bit. +#[derive(Debug)] +enum DecimalBuilder { + Decimal128(Decimal128Builder), + Decimal256(Decimal256Builder), +} + +impl DecimalBuilder { + /// Create a new DecimalBuilder given precision, scale, and optional byte-size (`fixed`). + fn new( + precision: usize, + scale: Option, + size: Option, + ) -> Result { + match size { + Some(s) if s > 16 && s <= 32 => Ok(Self::Decimal256( + Decimal256Builder::new() + .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?, + )), + Some(s) if s <= 16 => Ok(Self::Decimal128( + Decimal128Builder::new() + .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?, + )), + None => { + // infer from precision + if precision <= DECIMAL128_MAX_PRECISION as usize { + Ok(Self::Decimal128( + Decimal128Builder::new() + .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?, + )) + } else if precision <= DECIMAL256_MAX_PRECISION as usize { + Ok(Self::Decimal256( + Decimal256Builder::new() + .with_precision_and_scale(precision as u8, scale.unwrap_or(0) as i8)?, + )) + } else { + Err(ArrowError::ParseError(format!( + "Decimal precision {} exceeds maximum supported", + precision + ))) + } + } + _ => Err(ArrowError::ParseError(format!( + "Unsupported decimal size: {:?}", + size + ))), + } + } + + /// Append sign-extended bytes to this decimal builder + fn append_bytes(&mut self, raw: &[u8]) -> Result<(), ArrowError> { + match self { + Self::Decimal128(b) => { + let padded = sign_extend_to_16(raw)?; + let val = i128::from_be_bytes(padded); + b.append_value(val); + } + Self::Decimal256(b) => { + let padded = sign_extend_to_32(raw)?; + let val = i256::from_be_bytes(padded); + b.append_value(val); + } + } + Ok(()) + } + + /// Append a null decimal value (0) + fn append_null(&mut self) -> Result<(), ArrowError> { + match self { + Self::Decimal128(b) => { + let zero = [0u8; 16]; + b.append_value(i128::from_be_bytes(zero)); + } + Self::Decimal256(b) => { + let zero = [0u8; 32]; + b.append_value(i256::from_be_bytes(zero)); + } + } + Ok(()) + } + + /// Finish building the decimal array, returning an [`ArrayRef`]. + fn finish( + self, + nulls: Option, + precision: usize, + scale: usize, + ) -> Result { + match self { + Self::Decimal128(mut b) => { + let arr = b.finish(); + let vals = arr.values().clone(); + let dec = Decimal128Array::new(vals, nulls) + .with_precision_and_scale(precision as u8, scale as i8)?; + Ok(Arc::new(dec)) + } + Self::Decimal256(mut b) => { + let arr = b.finish(); + let vals = arr.values().clone(); + let dec = Decimal256Array::new(vals, nulls) + .with_precision_and_scale(precision as u8, scale as i8)?; + Ok(Arc::new(dec)) + } + } + } +} + +/// Sign-extend `raw` to 16 bytes. +fn sign_extend_to_16(raw: &[u8]) -> Result<[u8; 16], ArrowError> { + let extended = sign_extend(raw, 16); + if extended.len() != 16 { + return Err(ArrowError::ParseError(format!( + "Failed to extend to 16 bytes, got {} bytes", + extended.len() + ))); + } + let mut arr = [0u8; 16]; + arr.copy_from_slice(&extended); + Ok(arr) +} + +/// Sign-extend `raw` to 32 bytes. +fn sign_extend_to_32(raw: &[u8]) -> Result<[u8; 32], ArrowError> { + let extended = sign_extend(raw, 32); + if extended.len() != 32 { + return Err(ArrowError::ParseError(format!( + "Failed to extend to 32 bytes, got {} bytes", + extended.len() + ))); + } + let mut arr = [0u8; 32]; + arr.copy_from_slice(&extended); + Ok(arr) +} + +/// Sign-extend the first byte to produce `target_len` bytes total. +fn sign_extend(raw: &[u8], target_len: usize) -> Vec { + if raw.is_empty() { + return vec![0; target_len]; + } + let sign_bit = raw[0] & 0x80; + let mut out = Vec::with_capacity(target_len); + if sign_bit != 0 { + out.resize(target_len - raw.len(), 0xFF); + } else { + out.resize(target_len - raw.len(), 0x00); + } + out.extend_from_slice(raw); + out +} + +/// Convenience helper to build a field with `name`, `DataType` and `nullable`. +fn field_with_type(name: &str, dt: DataType, nullable: bool) -> FieldRef { + Arc::new(ArrowField::new(name, dt, nullable)) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::{ + cast::AsArray, Array, Decimal128Array, DictionaryArray, FixedSizeBinaryArray, + IntervalMonthDayNanoArray, ListArray, MapArray, StringArray, StructArray, + }; + + // --------------- + // Zig-Zag Helpers + // --------------- + fn encode_avro_int(value: i32) -> Vec { + let mut buf = Vec::new(); + let mut v = (value << 1) ^ (value >> 31); + while v & !0x7F != 0 { + buf.push(((v & 0x7F) | 0x80) as u8); + v >>= 7; + } + buf.push(v as u8); + buf + } + + fn encode_avro_long(value: i64) -> Vec { + let mut buf = Vec::new(); + let mut v = (value << 1) ^ (value >> 63); + while v & !0x7F != 0 { + buf.push(((v & 0x7F) | 0x80) as u8); + v >>= 7; + } + buf.push(v as u8); + buf + } + + fn encode_avro_bytes(bytes: &[u8]) -> Vec { + let mut buf = encode_avro_long(bytes.len() as i64); + buf.extend_from_slice(bytes); + buf + } + + // ----------------- + // Test Fixed + // ----------------- + #[test] + fn test_fixed_decoding() { + // `fixed(4)` => Arrow FixedSizeBinary(4) + let dt = AvroDataType::from_codec(Codec::Fixed(4)); + let mut dec = Decoder::try_new(&dt).unwrap(); + // 2 rows, each row => 4 bytes + let row1 = [0xDE, 0xAD, 0xBE, 0xEF]; + let row2 = [0x01, 0x23, 0x45, 0x67]; + let mut data = Vec::new(); + data.extend_from_slice(&row1); + data.extend_from_slice(&row2); + let mut cursor = AvroCursor::new(&data); + dec.decode(&mut cursor).unwrap(); + dec.decode(&mut cursor).unwrap(); + let arr = dec.flush(None).unwrap(); + let fsb = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(fsb.len(), 2); + assert_eq!(fsb.value_length(), 4); + assert_eq!(fsb.value(0), row1); + assert_eq!(fsb.value(1), row2); + } + + #[test] + fn test_fixed_with_nulls() { + // Avro union => [ fixed(2), null] + let dt = AvroDataType::from_codec(Codec::Fixed(2)); + let child = Decoder::try_new(&dt).unwrap(); + let mut dec = Decoder::Nullable( + Nullability::NullFirst, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(child), + ); + // Decode 3 rows: row1 => branch=0 => [0x00], then 2 bytes + // row2 => branch=1 => null => [0x02] + // row3 => branch=0 => 2 bytes + let row1 = [0x11, 0x22]; + let row3 = [0x55, 0x66]; + let mut data = Vec::new(); + // row1 => union=0 => child => 2 bytes + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&row1); + // row2 => union=1 => null + data.extend_from_slice(&encode_avro_int(1)); + // row3 => union=0 => child => 2 bytes + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&row3); + let mut cursor = AvroCursor::new(&data); + dec.decode(&mut cursor).unwrap(); // row1 + dec.decode(&mut cursor).unwrap(); // row2 => null + dec.decode(&mut cursor).unwrap(); // row3 + let arr = dec.flush(None).unwrap(); + let fsb = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(fsb.len(), 3); + assert!(fsb.is_valid(0)); + assert!(!fsb.is_valid(1)); + assert!(fsb.is_valid(2)); + assert_eq!(fsb.value_length(), 2); + assert_eq!(fsb.value(0), row1); + assert_eq!(fsb.value(2), row3); + } + + // ----------------- + // Test Interval + // ----------------- + #[test] + fn test_interval_decoding() { + // Avro interval => 12 bytes => [ months i32, days i32, ms i32 ] + // decode 2 rows => row1 => months=1, days=2, ms=100 => row2 => months=-1, days=10, ms=9999 + let dt = AvroDataType::from_codec(Codec::Interval); + let mut dec = Decoder::try_new(&dt).unwrap(); + // row1 => months=1 => 01,00,00,00, days=2 => 02,00,00,00, ms=100 => 64,00,00,00 + // row2 => months=-1 => 0xFF,0xFF,0xFF,0xFF, days=10 => 0x0A,0x00,0x00,0x00, ms=9999 => 0x0F,0x27,0x00,0x00 + let row1 = [ + 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x64, 0x00, 0x00, 0x00, + ]; + let row2 = [ + 0xFF, 0xFF, 0xFF, 0xFF, 0x0A, 0x00, 0x00, 0x00, 0x0F, 0x27, 0x00, 0x00, + ]; + let mut data = Vec::new(); + data.extend_from_slice(&row1); + data.extend_from_slice(&row2); + let mut cursor = AvroCursor::new(&data); + dec.decode(&mut cursor).unwrap(); + dec.decode(&mut cursor).unwrap(); + let arr = dec.flush(None).unwrap(); + let intervals = arr + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(intervals.len(), 2); + // row0 => months=1, days=2, ms=100 => nanos=100_000_000 + // row1 => months=-1, days=10, ms=9999 => nanos=9999_000_000 + let val0 = intervals.value(0); + assert_eq!(val0.months, 1); + assert_eq!(val0.days, 2); + assert_eq!(val0.nanoseconds, 100_000_000); + let val1 = intervals.value(1); + assert_eq!(val1.months, -1); + assert_eq!(val1.days, 10); + assert_eq!(val1.nanoseconds, 9_999_000_000); + } + + #[test] + fn test_interval_decoding_with_nulls() { + // Avro union => [ interval, null] + let dt = AvroDataType::from_codec(Codec::Interval); + let child = Decoder::try_new(&dt).unwrap(); + let mut dec = Decoder::Nullable( + Nullability::NullFirst, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(child), + ); + // We'll decode 2 rows: row1 => interval => months=2, days=3, ms=500 => row2 => null + // row1 => union=0 => child => 12 bytes + // row2 => union=1 => null => no data + let row1 = [ + 0x02, 0x00, 0x00, 0x00, // months=2 + 0x03, 0x00, 0x00, 0x00, // days=3 + 0xF4, 0x01, 0x00, 0x00, + ]; // ms=500 => nanos=500_000_000 + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(0)); // union=0 => child + data.extend_from_slice(&row1); + data.extend_from_slice(&encode_avro_int(1)); // union=1 => null + let mut cursor = AvroCursor::new(&data); + dec.decode(&mut cursor).unwrap(); // row1 + dec.decode(&mut cursor).unwrap(); // row2 => null + let arr = dec.flush(None).unwrap(); + let intervals = arr + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(intervals.len(), 2); + assert!(intervals.is_valid(0)); + assert!(!intervals.is_valid(1)); + let val0 = intervals.value(0); + assert_eq!(val0.months, 2); + assert_eq!(val0.days, 3); + assert_eq!(val0.nanoseconds, 500_000_000); + } + + // ------------------- + // Tests for Enum + // ------------------- + #[test] + fn test_enum_decoding() { + let symbols = vec!["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()]; + let enum_dt = AvroDataType::from_codec(Codec::Enum(symbols.clone())); + let mut decoder = Decoder::try_new(&enum_dt).unwrap(); + // Encode the indices [1, 0, 2] => zigzag => 1->2, 0->0, 2->4 + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(1)); // => [2] + data.extend_from_slice(&encode_avro_int(0)); // => [0] + data.extend_from_slice(&encode_avro_int(2)); // => [4] + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); // => GREEN + decoder.decode(&mut cursor).unwrap(); // => RED + decoder.decode(&mut cursor).unwrap(); // => BLUE + let array = decoder.flush(None).unwrap(); + let dict_arr = array + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(dict_arr.len(), 3); + let keys = dict_arr.keys(); + assert_eq!(keys.value(0), 1); + assert_eq!(keys.value(1), 0); + assert_eq!(keys.value(2), 2); + let dict_values = dict_arr.values().as_string::(); + assert_eq!(dict_values.value(0), "RED"); + assert_eq!(dict_values.value(1), "GREEN"); + assert_eq!(dict_values.value(2), "BLUE"); + } + + #[test] + fn test_enum_decoding_with_nulls() { + // Union => [Enum(...), null] + // "child" => branch_index=0 => [0x00], "null" => 1 => [0x02] + let symbols = vec!["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()]; + let enum_dt = AvroDataType::from_codec(Codec::Enum(symbols.clone())); + let mut inner_decoder = Decoder::try_new(&enum_dt).unwrap(); + let mut nullable_decoder = Decoder::Nullable( + Nullability::NullFirst, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(inner_decoder), + ); + // Indices: [1, null, 2] => in Avro union + let mut data = Vec::new(); + // Row1 => union branch=0 => child => [0x00] + data.extend_from_slice(&encode_avro_int(0)); + // Then child's enum index=1 => [0x02] + data.extend_from_slice(&encode_avro_int(1)); + // Row2 => union branch=1 => null => [0x02] + data.extend_from_slice(&encode_avro_int(1)); + // Row3 => union branch=0 => child => [0x00] + data.extend_from_slice(&encode_avro_int(0)); + // Then child's enum index=2 => [0x04] + data.extend_from_slice(&encode_avro_int(2)); + let mut cursor = AvroCursor::new(&data); + nullable_decoder.decode(&mut cursor).unwrap(); // => GREEN + nullable_decoder.decode(&mut cursor).unwrap(); // => null + nullable_decoder.decode(&mut cursor).unwrap(); // => BLUE + let array = nullable_decoder.flush(None).unwrap(); + let dict_arr = array + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(dict_arr.len(), 3); + // [GREEN, null, BLUE] + assert!(dict_arr.is_valid(0)); + assert!(!dict_arr.is_valid(1)); + assert!(dict_arr.is_valid(2)); + let keys = dict_arr.keys(); + // keys.value(0) => 1 => GREEN + // keys.value(2) => 2 => BLUE + let dict_values = dict_arr.values().as_string::(); + assert_eq!(dict_values.value(0), "RED"); + assert_eq!(dict_values.value(1), "GREEN"); + assert_eq!(dict_values.value(2), "BLUE"); + } + + // ------------------- + // Tests for Map + // ------------------- + #[test] + fn test_map_decoding_one_entry() { + let value_type = AvroDataType::from_codec(Codec::Utf8); + let map_type = AvroDataType::from_codec(Codec::Map(Arc::new(value_type))); + let mut decoder = Decoder::try_new(&map_type).unwrap(); + // Encode a single map with one entry: {"hello": "world"} + let mut data = Vec::new(); + // block_count=1 => zigzag => [0x02] + data.extend_from_slice(&encode_avro_long(1)); + data.extend_from_slice(&encode_avro_bytes(b"hello")); // key + data.extend_from_slice(&encode_avro_bytes(b"world")); // value + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + let array = decoder.flush(None).unwrap(); + let map_arr = array.as_any().downcast_ref::().unwrap(); + assert_eq!(map_arr.len(), 1); // one map + assert_eq!(map_arr.value_length(0), 1); + let entries = map_arr.value(0); + let struct_entries = entries.as_any().downcast_ref::().unwrap(); + assert_eq!(struct_entries.len(), 1); + let key_arr = struct_entries + .column_by_name("key") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let val_arr = struct_entries + .column_by_name("value") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(key_arr.value(0), "hello"); + assert_eq!(val_arr.value(0), "world"); + } + + #[test] + fn test_map_decoding_empty() { + // block_count=0 => empty map + let value_type = AvroDataType::from_codec(Codec::Utf8); + let map_type = AvroDataType::from_codec(Codec::Map(Arc::new(value_type))); + let mut decoder = Decoder::try_new(&map_type).unwrap(); + // Encode an empty map => block_count=0 => [0x00] + let data = encode_avro_long(0); + decoder.decode(&mut AvroCursor::new(&data)).unwrap(); + let array = decoder.flush(None).unwrap(); + let map_arr = array.as_any().downcast_ref::().unwrap(); + assert_eq!(map_arr.len(), 1); + assert_eq!(map_arr.value_length(0), 0); + } + + // ------------------- + // Tests for Decimal + // ------------------- + #[test] + fn test_decimal_decoding_fixed128() { + let dt = AvroDataType::from_codec(Codec::Decimal(5, Some(2), Some(16))); + let mut decoder = Decoder::try_new(&dt).unwrap(); + // Row1 => 123.45 => unscaled=12345 => i128 0x000...3039 + // Row2 => -1.23 => unscaled=-123 => i128 0xFFFF...FF85 + let row1 = [ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x30, 0x39, + ]; + let row2 = [ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0x85, + ]; + + let mut data = Vec::new(); + data.extend_from_slice(&row1); + data.extend_from_slice(&row2); + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + let arr = decoder.flush(None).unwrap(); + let dec = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec.len(), 2); + assert_eq!(dec.value_as_string(0), "123.45"); + assert_eq!(dec.value_as_string(1), "-1.23"); + } + + #[test] + fn test_decimal_decoding_bytes_with_nulls() { + // Avro union => [ Decimal(4,1), null ] + // child => index=0 => [0x00], null => index=1 => [0x02] + let dt = AvroDataType::from_codec(Codec::Decimal(4, Some(1), None)); + let mut inner = Decoder::try_new(&dt).unwrap(); + let mut decoder = Decoder::Nullable( + Nullability::NullFirst, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(inner), + ); + // Decode three rows: [123.4, null, -123.4] + let mut data = Vec::new(); + // Row1 => child => [0x00], then decimal => e.g. 0x04D2 => 1234 => "123.4" + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&encode_avro_bytes(&[0x04, 0xD2])); + // Row2 => null => [0x02] + data.extend_from_slice(&encode_avro_int(1)); + // Row3 => child => [0x00], then decimal => 0xFB2E => -1234 => "-123.4" + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&encode_avro_bytes(&[0xFB, 0x2E])); + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + let arr = decoder.flush(None).unwrap(); + let dec_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec_arr.len(), 3); + assert!(dec_arr.is_valid(0)); + assert!(!dec_arr.is_valid(1)); + assert!(dec_arr.is_valid(2)); + assert_eq!(dec_arr.value_as_string(0), "123.4"); + assert_eq!(dec_arr.value_as_string(2), "-123.4"); + } + + #[test] + fn test_decimal_decoding_bytes_with_nulls_fixed_size() { + // Avro union => [Decimal(6,2,16), null] + let dt = AvroDataType::from_codec(Codec::Decimal(6, Some(2), Some(16))); + let mut inner = Decoder::try_new(&dt).unwrap(); + let mut decoder = Decoder::Nullable( + Nullability::NullFirst, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(inner), + ); + // Decode [1234.56, null, -1234.56] + let row1 = [ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + 0xE2, 0x40, + ]; + let row3 = [ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFE, + 0x1D, 0xC0, + ]; + let mut data = Vec::new(); + // Row1 => child => [0x00] + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&row1); + // Row2 => null => [0x02] + data.extend_from_slice(&encode_avro_int(1)); + // Row3 => child => [0x00] + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&row3); + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + let arr = decoder.flush(None).unwrap(); + let dec_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec_arr.len(), 3); + assert!(dec_arr.is_valid(0)); + assert!(!dec_arr.is_valid(1)); + assert!(dec_arr.is_valid(2)); + assert_eq!(dec_arr.value_as_string(0), "1234.56"); + assert_eq!(dec_arr.value_as_string(2), "-1234.56"); + } + + // ------------------- + // Tests for List + // ------------------- + #[test] + fn test_list_decoding() { + // Avro array => block1(count=2), item1, item2, block2(count=0 => end) + // + // 1. Create 2 rows: + // Row1 => [10, 20] + // Row2 => [ ] + // + // 2. flush => should yield 2-element array => first row has 2 items, second row has 0 items + let item_dt = AvroDataType::from_codec(Codec::Int32); + let list_dt = AvroDataType::from_codec(Codec::List(Arc::new(item_dt))); + let mut decoder = Decoder::try_new(&list_dt).unwrap(); + // Row1 => block_count=2 => item=10 => item=20 => block_count=0 => end + // - 2 => zigzag => [0x04] + // - item=10 => zigzag => [0x14] + // - item=20 => zigzag => [0x28] + // - 0 => [0x00] + let mut row1 = Vec::new(); + row1.extend_from_slice(&encode_avro_long(2)); // block_count=2 + row1.extend_from_slice(&encode_avro_int(10)); // item=10 + row1.extend_from_slice(&encode_avro_int(20)); // item=20 + row1.extend_from_slice(&encode_avro_long(0)); // end of array + + // Row2 => block_count=0 => empty array + let mut row2 = Vec::new(); + row2.extend_from_slice(&encode_avro_long(0)); + let mut cursor = AvroCursor::new(&row1); + decoder.decode(&mut cursor).unwrap(); + let mut cursor2 = AvroCursor::new(&row2); + decoder.decode(&mut cursor2).unwrap(); + let array = decoder.flush(None).unwrap(); + let list_arr = array.as_any().downcast_ref::().unwrap(); + assert_eq!(list_arr.len(), 2); + // row0 => 2 items => [10, 20] + // row1 => 0 items + let offsets = list_arr.value_offsets(); + assert_eq!(offsets, &[0, 2, 2]); + let values = list_arr.values(); + let int_arr = values.as_primitive::(); + assert_eq!(int_arr.len(), 2); + assert_eq!(int_arr.value(0), 10); + assert_eq!(int_arr.value(1), 20); + } + + #[test] + fn test_list_decoding_with_negative_block_count() { + // Start with single row => [1, 2, 3] + // We'll store them in a single negative block => block_count=-3 => #items=3 + // Then read block_size => let's pretend it's 9 bytes, etc. Then the items. + // Then a block_count=0 => done + let item_dt = AvroDataType::from_codec(Codec::Int32); + let list_dt = AvroDataType::from_codec(Codec::List(Arc::new(item_dt))); + let mut decoder = Decoder::try_new(&list_dt).unwrap(); + // block_count=-3 => zigzag => (-3 << 1) ^ (-3 >> 63) + // => -6 ^ -1 => ... + // Encode directly with `encode_avro_long(-3)`. + let mut data = encode_avro_long(-3); + // Next => block_size => let's pretend 12 => encode_avro_long(12) + data.extend_from_slice(&encode_avro_long(12)); + // Then 3 items => [1, 2, 3] + data.extend_from_slice(&encode_avro_int(1)); + data.extend_from_slice(&encode_avro_int(2)); + data.extend_from_slice(&encode_avro_int(3)); + // Then block_count=0 => done + data.extend_from_slice(&encode_avro_long(0)); + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + let array = decoder.flush(None).unwrap(); + let list_arr = array.as_any().downcast_ref::().unwrap(); + assert_eq!(list_arr.len(), 1); + assert_eq!(list_arr.value_length(0), 3); + let values = list_arr.values().as_primitive::(); + assert_eq!(values.len(), 3); + assert_eq!(values.value(0), 1); + assert_eq!(values.value(1), 2); + assert_eq!(values.value(2), 3); + } +} diff --git a/arrow-avro/src/schema.rs b/arrow-avro/src/schema.rs index a9d91e47948b..8e3f23ffbb5e 100644 --- a/arrow-avro/src/schema.rs +++ b/arrow-avro/src/schema.rs @@ -123,28 +123,28 @@ pub enum ComplexType<'a> { pub struct Record<'a> { #[serde(borrow)] pub name: &'a str, - #[serde(borrow, default)] + #[serde(borrow, default, skip_serializing_if = "Option::is_none")] pub namespace: Option<&'a str>, #[serde(borrow, default)] pub doc: Option<&'a str>, #[serde(borrow, default)] pub aliases: Vec<&'a str>, #[serde(borrow)] - pub fields: Vec>, + pub fields: Vec>, #[serde(flatten)] pub attributes: Attributes<'a>, } /// A field within a [`Record`] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct Field<'a> { +pub struct RecordField<'a> { #[serde(borrow)] pub name: &'a str, #[serde(borrow, default)] pub doc: Option<&'a str>, #[serde(borrow)] pub r#type: Schema<'a>, - #[serde(borrow, default)] + #[serde(borrow, default, skip_serializing_if = "Option::is_none")] pub default: Option<&'a str>, } @@ -155,7 +155,7 @@ pub struct Field<'a> { pub struct Enum<'a> { #[serde(borrow)] pub name: &'a str, - #[serde(borrow, default)] + #[serde(borrow, default, skip_serializing_if = "Option::is_none")] pub namespace: Option<&'a str>, #[serde(borrow, default)] pub doc: Option<&'a str>, @@ -163,7 +163,7 @@ pub struct Enum<'a> { pub aliases: Vec<&'a str>, #[serde(borrow)] pub symbols: Vec<&'a str>, - #[serde(borrow, default)] + #[serde(borrow, default, skip_serializing_if = "Option::is_none")] pub default: Option<&'a str>, #[serde(flatten)] pub attributes: Attributes<'a>, @@ -198,7 +198,7 @@ pub struct Map<'a> { pub struct Fixed<'a> { #[serde(borrow)] pub name: &'a str, - #[serde(borrow, default)] + #[serde(borrow, default, skip_serializing_if = "Option::is_none")] pub namespace: Option<&'a str>, #[serde(borrow, default)] pub aliases: Vec<&'a str>, @@ -237,7 +237,7 @@ mod tests { "logicalType":"timestamp-micros" }"#, ) - .unwrap(); + .unwrap(); let timestamp = Type { r#type: TypeName::Primitive(PrimitiveType::Long), @@ -260,7 +260,7 @@ mod tests { "scale":2 }"#, ) - .unwrap(); + .unwrap(); let decimal = ComplexType::Fixed(Fixed { name: "fixed", @@ -300,7 +300,7 @@ mod tests { ] }"#, ) - .unwrap(); + .unwrap(); assert_eq!( schema, @@ -309,7 +309,7 @@ mod tests { namespace: None, doc: None, aliases: vec![], - fields: vec![Field { + fields: vec![RecordField { name: "value", doc: None, r#type: Schema::Union(vec![ @@ -333,7 +333,7 @@ mod tests { ] }"#, ) - .unwrap(); + .unwrap(); assert_eq!( schema, @@ -343,13 +343,13 @@ mod tests { doc: None, aliases: vec!["LinkedLongs"], fields: vec![ - Field { + RecordField { name: "value", doc: None, r#type: Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)), default: None, }, - Field { + RecordField { name: "next", doc: None, r#type: Schema::Union(vec![ @@ -392,7 +392,7 @@ mod tests { ] }"#, ) - .unwrap(); + .unwrap(); assert_eq!( schema, @@ -402,7 +402,7 @@ mod tests { doc: None, aliases: vec![], fields: vec![ - Field { + RecordField { name: "id", doc: None, r#type: Schema::Union(vec![ @@ -411,7 +411,7 @@ mod tests { ]), default: None, }, - Field { + RecordField { name: "timestamp_col", doc: None, r#type: Schema::Union(vec![ @@ -453,7 +453,7 @@ mod tests { ] }"#, ) - .unwrap(); + .unwrap(); assert_eq!( schema, @@ -463,7 +463,7 @@ mod tests { doc: None, aliases: vec![], fields: vec![ - Field { + RecordField { name: "clientHash", doc: None, r#type: Schema::Complex(ComplexType::Fixed(Fixed { @@ -475,7 +475,7 @@ mod tests { })), default: None, }, - Field { + RecordField { name: "clientProtocol", doc: None, r#type: Schema::Union(vec![ @@ -484,13 +484,13 @@ mod tests { ]), default: None, }, - Field { + RecordField { name: "serverHash", doc: None, r#type: Schema::TypeName(TypeName::Ref("MD5")), default: None, }, - Field { + RecordField { name: "meta", doc: None, r#type: Schema::Union(vec![ diff --git a/arrow-avro/src/writer/mod.rs b/arrow-avro/src/writer/mod.rs new file mode 100644 index 000000000000..635333718ac7 --- /dev/null +++ b/arrow-avro/src/writer/mod.rs @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod schema; +mod vlq; + +#[cfg(test)] +mod test { + use arrow_array::RecordBatch; + use std::fs::File; + use std::io::BufWriter; + + fn write_file(file: &str, batch: &RecordBatch) { + let file = File::open(file).unwrap(); + let mut writer = BufWriter::new(file); + } +} diff --git a/arrow-avro/src/writer/schema.rs b/arrow-avro/src/writer/schema.rs new file mode 100644 index 000000000000..521ea9e6b107 --- /dev/null +++ b/arrow-avro/src/writer/schema.rs @@ -0,0 +1,277 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::codec::{AvroDataType, AvroField, Codec}; +use crate::schema::Schema; +use arrow_array::RecordBatch; +use std::sync::Arc; + +fn record_batch_to_avro_schema<'a>( + batch: &'a RecordBatch, + record_name: &'a str, + top_level_data_type: &'a AvroDataType, +) -> Schema<'a> { + top_level_data_type.to_avro_schema(record_name) +} + +pub fn to_avro_json_schema( + batch: &RecordBatch, + record_name: &str, +) -> Result { + let avro_fields: Vec = batch + .schema() + .fields() + .iter() + .map(|arrow_field| crate::codec::arrow_field_to_avro_field(arrow_field)) + .collect(); + let top_level_data_type = AvroDataType::from_codec(Codec::Struct(Arc::from(avro_fields))); + let avro_schema = record_batch_to_avro_schema(batch, record_name, &top_level_data_type); + serde_json::to_string_pretty(&avro_schema) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray, StructArray}; + use arrow_schema::{DataType, Field, Fields, Schema as ArrowSchema}; + use serde_json::{json, Value}; + use std::sync::Arc; + + #[test] + fn test_record_batch_to_avro_schema_basic() { + let arrow_schema = Arc::new(ArrowSchema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])); + + let col_id = Arc::new(Int32Array::from(vec![1, 2, 3])); + let col_name = Arc::new(StringArray::from(vec![Some("foo"), None, Some("bar")])); + let batch = RecordBatch::try_new(arrow_schema, vec![col_id, col_name]) + .expect("Failed to create RecordBatch"); + + // Convert the batch -> Avro `Schema` + let avro_schema = to_avro_json_schema(&batch, "MyTestRecord") + .expect("Failed to convert RecordBatch to Avro JSON schema"); + let actual_json: Value = serde_json::from_str(&avro_schema) + .expect("Invalid JSON returned by to_avro_json_schema"); + + let expected_json = json!({ + "type": "record", + "name": "MyTestRecord", + "aliases": [], + "doc": null, + "logicalType": null, + "fields": [ + { + "name": "id", + "doc": null, + "type": "int" + }, + { + "name": "name", + "doc": null, + "type": ["null", "string"] + } + ] + }); + assert_eq!( + actual_json, expected_json, + "Avro Schema JSON does not match expected" + ); + } + + #[test] + fn test_to_avro_json_schema_basic() { + let arrow_schema = Arc::new(ArrowSchema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("desc", DataType::Utf8, true), + ])); + let col_id = Arc::new(Int32Array::from(vec![10, 20, 30])); + let col_desc = Arc::new(StringArray::from(vec![Some("a"), Some("b"), None])); + let batch = RecordBatch::try_new(arrow_schema, vec![col_id, col_desc]) + .expect("Failed to create RecordBatch"); + let json_schema_string = to_avro_json_schema(&batch, "AnotherTestRecord") + .expect("Failed to convert RecordBatch to Avro JSON schema"); + let actual_json: Value = serde_json::from_str(&json_schema_string) + .expect("Invalid JSON returned by to_avro_json_schema"); + let expected_json = json!({ + "type": "record", + "name": "AnotherTestRecord", + "aliases": [], + "doc": null, + "logicalType": null, + "fields": [ + { + "name": "id", + "type": "int", + "doc": null, + }, + { + "name": "desc", + "type": ["null", "string"], + "doc": null, + } + ] + }); + assert_eq!( + actual_json, expected_json, + "JSON schema mismatch for to_avro_json_schema" + ); + } + + #[test] + fn test_to_avro_json_schema_single_nonnull_int() { + let arrow_schema = Arc::new(arrow_schema::Schema::new(vec![Field::new( + "id", + DataType::Int32, + false, + )])); + let col_id = Arc::new(Int32Array::from(vec![1, 2, 3])); + let batch = + RecordBatch::try_new(arrow_schema, vec![col_id]).expect("Failed to create RecordBatch"); + let avro_json_string = to_avro_json_schema(&batch, "MySingleIntRecord") + .expect("Failed to generate Avro JSON schema"); + let actual_json: Value = + serde_json::from_str(&avro_json_string).expect("Failed to parse Avro JSON schema"); + let expected_json = json!({ + "type": "record", + "name": "MySingleIntRecord", + "aliases": [], + "doc": null, + "logicalType": null, + "fields": [ + { + "name": "id", + "type": "int", + "doc": null, + } + ] + }); + assert_eq!(actual_json, expected_json, "Avro JSON schema mismatch"); + } + + #[test] + fn test_to_avro_json_schema_two_fields_nullable_string() { + let arrow_schema = Arc::new(arrow_schema::Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ])); + let col_id = Arc::new(Int32Array::from(vec![1, 2, 3])); + let col_name = Arc::new(StringArray::from(vec![Some("foo"), None, Some("bar")])); + let batch = RecordBatch::try_new(arrow_schema, vec![col_id, col_name]) + .expect("Failed to create RecordBatch"); + let avro_json_string = + to_avro_json_schema(&batch, "MyRecord").expect("Failed to generate Avro JSON schema"); + let actual_json: Value = + serde_json::from_str(&avro_json_string).expect("Failed to parse Avro JSON schema"); + let expected_json = json!({ + "type": "record", + "name": "MyRecord", + "aliases": [], + "doc": null, + "logicalType": null, + "fields": [ + { + "name": "id", + "type": "int", + "doc": null, + }, + { + "name": "name", + "doc": null, + "type": [ + "null", + "string", + ] + } + ] + }); + assert_eq!(actual_json, expected_json, "Avro JSON schema mismatch"); + } + + #[test] + fn test_to_avro_json_schema_nested_struct() { + let inner_fields = Fields::from(vec![ + Field::new("inner_int", DataType::Int32, false), + Field::new("inner_str", DataType::Utf8, true), + ]); + let arrow_schema = Arc::new(arrow_schema::Schema::new(vec![Field::new( + "my_struct", + DataType::Struct(inner_fields), + true, + )])); + let inner_int_col = Arc::new(Int32Array::from(vec![10, 20, 30])) as ArrayRef; + let inner_str_col = + Arc::new(StringArray::from(vec![Some("a"), None, Some("c")])) as ArrayRef; + let fields_arrays = vec![ + ( + Arc::new(Field::new("inner_int", DataType::Int32, false)), + inner_int_col, + ), + ( + Arc::new(Field::new("inner_str", DataType::Utf8, true)), + inner_str_col, + ), + ]; + let struct_array = StructArray::from(fields_arrays); + let batch = RecordBatch::try_new(arrow_schema, vec![Arc::new(struct_array)]) + .expect("Failed to create RecordBatch"); + let avro_json_string = to_avro_json_schema(&batch, "NestedRecord") + .expect("Failed to generate Avro JSON schema"); + let actual_json: Value = + serde_json::from_str(&avro_json_string).expect("Failed to parse Avro JSON schema"); + let expected_json = json!({ + "type": "record", + "name": "NestedRecord", + "aliases": [], + "doc": null, + "logicalType": null, + "fields": [ + { + "name": "my_struct", + "doc": null, + "type": [ + "null", + { + "type": "record", + "name": "my_struct", + "aliases": [], + "doc": null, + "logicalType": null, + "fields": [ + { + "name": "inner_int", + "type": "int", + "doc": null, + }, + { + "name": "inner_str", + "doc": null, + "type": [ + "null", + "string", + ] + } + ] + } + ] + } + ] + }); + assert_eq!(actual_json, expected_json, "Avro JSON schema mismatch"); + } +} diff --git a/arrow-avro/src/writer/vlq.rs b/arrow-avro/src/writer/vlq.rs new file mode 100644 index 000000000000..4cf26e23856d --- /dev/null +++ b/arrow-avro/src/writer/vlq.rs @@ -0,0 +1,114 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Encoder for zig-zag encoded variable length integers +/// +/// This complements the VLQ decoding logic used by Avro. Zig-zag encoding maps signed integers +/// to unsigned integers so that small magnitudes (both positive and negative) produce smaller varints. +/// After zig-zag encoding, values are encoded as a series of bytes where the lower 7 bits are data +/// and the high bit indicates if another byte follows. +/// +/// See also: +/// +/// +#[derive(Debug, Default)] +pub struct VLQEncoder; + +impl VLQEncoder { + /// Encode a signed 64-bit integer `value` into `output` using Avro's zig-zag varint encoding. + /// + /// Zig-zag encoding: + /// ```text + /// encoded = (value << 1) ^ (value >> 63) + /// ``` + /// + /// Then `encoded` is written as a variable-length integer (varint): + /// - Extract 7 bits at a time + /// - If more bits remain, set the MSB of the current byte to 1 and continue + /// - Otherwise, MSB is 0 and this is the last byte + pub fn long(&mut self, value: i64, output: &mut Vec) { + let zigzag = ((value << 1) ^ (value >> 63)) as u64; + self.encode_varint(zigzag, output); + } + + fn encode_varint(&self, mut val: u64, output: &mut Vec) { + while (val & !0x7F) != 0 { + output.push(((val & 0x7F) as u8) | 0x80); + val >>= 7; + } + output.push(val as u8); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn decode_varint(buf: &mut &[u8]) -> Option { + let mut value = 0_u64; + for i in 0..10 { + let b = buf.get(i).copied()?; + let lower_7 = (b & 0x7F) as u64; + value |= lower_7 << (7 * i); + if b & 0x80 == 0 { + *buf = &buf[i + 1..]; + return Some(value); + } + } + None // more than 10 bytes or not terminated properly + } + + fn decode_zigzag(val: u64) -> i64 { + ((val >> 1) as i64) ^ -((val & 1) as i64) + } + + fn decode_long(buf: &mut &[u8]) -> Option { + let val = decode_varint(buf)?; + Some(decode_zigzag(val)) + } + + fn round_trip(value: i64) { + let mut encoder = VLQEncoder; + let mut buf = Vec::new(); + encoder.long(value, &mut buf); + let mut slice = buf.as_slice(); + let decoded = decode_long(&mut slice).expect("Failed to decode value"); + assert_eq!(decoded, value, "Round-trip mismatch for value {}", value); + assert!(slice.is_empty(), "Not all bytes consumed"); + } + + #[test] + fn test_round_trip() { + round_trip(0); + round_trip(1); + round_trip(-1); + round_trip(12345678); + round_trip(-12345678); + round_trip(i64::MAX); + round_trip(i64::MIN); + } + + #[test] + fn test_random_values() { + use rand::Rng; + let mut rng = rand::thread_rng(); + for _ in 0..1000 { + let val: i64 = rng.gen(); + round_trip(val); + } + } +}