Skip to content

Commit

Permalink
Added visit_pyarrow dispatch for pyarrow field
Browse files Browse the repository at this point in the history
Signed-off-by: Christian Molina <[email protected]>
  • Loading branch information
DevChrisCross committed Jan 11, 2025
1 parent f1b8f50 commit cce35c7
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 50 deletions.
46 changes: 20 additions & 26 deletions pyiceberg/io/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,6 @@
from pyiceberg.conversions import to_bytes
from pyiceberg.exceptions import (
ResolveError,
UnsupportedPyArrowIntegerTypeException,
UnsupportedPyArrowTimestampTypeException,
UnsupportedPyArrowTypeException,
)
from pyiceberg.expressions import AlwaysTrue, BooleanExpression, BoundIsNaN, BoundIsNull, BoundTerm, Not, Or
Expand Down Expand Up @@ -965,13 +963,7 @@ def _(obj: pa.Schema, visitor: PyArrowSchemaVisitor[T]) -> T:

@visit_pyarrow.register(pa.StructType)
def _(obj: pa.StructType, visitor: PyArrowSchemaVisitor[T]) -> T:
    """Visit every child field of a struct and combine the per-field results.

    Each child is dispatched through ``visit_pyarrow`` as a ``pa.Field`` so that
    the before/after hooks and the unsupported-type error wrapping are handled
    in a single place (the ``pa.Field`` handler) instead of being repeated here.

    Args:
        obj: The PyArrow struct type whose fields are visited in order.
        visitor: The schema visitor receiving the callbacks.

    Returns:
        Whatever ``visitor.struct`` returns for the struct and its field results.
    """
    # Visit each field exactly once; a second manual loop over the same fields
    # would fire the visitor hooks twice and throw the first results away.
    results = [visit_pyarrow(field, visitor) for field in obj]

    return visitor.struct(obj, results)

Expand Down Expand Up @@ -1009,6 +1001,20 @@ def _(obj: pa.DictionaryType, visitor: PyArrowSchemaVisitor[T]) -> T:
return visit_pyarrow(obj.value_type, visitor)


@visit_pyarrow.register(pa.Field)
def _(obj: pa.Field, visitor: PyArrowSchemaVisitor[T]) -> T:
    """Visit a single PyArrow field between the visitor's field hooks.

    Args:
        obj: The PyArrow field to visit.
        visitor: The schema visitor receiving the callbacks.

    Returns:
        The value produced by ``visitor.field`` for this field.

    Raises:
        UnsupportedPyArrowTypeException: If visiting the field's type raises a
            ``TypeError``. The exception carries the offending field and chains
            the original ``TypeError`` as its ``__cause__``.
    """
    field_type = obj.type

    visitor.before_field(obj)
    # Keep the try body to the one call whose TypeError means "unsupported type";
    # a TypeError raised by visitor.field itself must not be relabelled.
    try:
        result = visit_pyarrow(field_type, visitor)
    except TypeError as e:
        raise UnsupportedPyArrowTypeException(obj, f"Column '{obj.name}' has an unsupported type: {field_type}") from e

    # Call field() *before* after_field(), matching the order of the struct loop
    # this handler replaces: a visitor may set per-field state in before_field
    # (e.g. pushing the field name) that field() still needs to read.
    field_result = visitor.field(obj, result)
    visitor.after_field(obj)

    return field_result


@visit_pyarrow.register(pa.DataType)
def _(obj: pa.DataType, visitor: PyArrowSchemaVisitor[T]) -> T:
if pa.types.is_nested(obj):
Expand Down Expand Up @@ -1148,12 +1154,6 @@ def map(self, map_type: pa.MapType, key_result: IcebergType, value_result: Icebe
return MapType(key_id, key_result, value_id, value_result, value_required=not value_field.nullable)

def primitive(self, primitive: pa.DataType) -> PrimitiveType:
field_name = None
unsupported_prefix = "Unsupported"
if len(self._field_names) > 0:
field_name = self._field_names[-1]
unsupported_prefix = f"Column '{field_name}' has an unsupported"

if pa.types.is_boolean(primitive):
return BooleanType()
elif pa.types.is_integer(primitive):
Expand All @@ -1164,7 +1164,7 @@ def primitive(self, primitive: pa.DataType) -> PrimitiveType:
return LongType()
else:
# Does not exist (yet)
raise UnsupportedPyArrowIntegerTypeException(self._field, f"{unsupported_prefix} integer type: {primitive}")
raise TypeError(f"Unsupported integer type: {primitive}")
elif pa.types.is_float32(primitive):
return FloatType()
elif pa.types.is_float64(primitive):
Expand All @@ -1187,17 +1187,11 @@ def primitive(self, primitive: pa.DataType) -> PrimitiveType:
if self._downcast_ns_timestamp_to_us:
logger.warning("Iceberg does not yet support 'ns' timestamp precision. Downcasting to 'us'.")
else:
column_name_message = ""
if field_name:
column_name_message = f", making the column '{field_name}' unsupported"
raise UnsupportedPyArrowTimestampTypeException(
self._field,
f"Iceberg does not yet support 'ns' timestamp precision{column_name_message}. Use 'downcast-ns-timestamp-to-us-on-write' configuration property to automatically downcast 'ns' to 'us' on write.",
raise TypeError(
"Iceberg does not yet support 'ns' timestamp precision. Use 'downcast-ns-timestamp-to-us-on-write' configuration property to automatically downcast 'ns' to 'us' on write.",
)
else:
raise UnsupportedPyArrowTimestampTypeException(
self._field, f"{unsupported_prefix} precision for timestamp type: {primitive.unit}"
)
raise TypeError(f"Unsupported precision for timestamp type: {primitive.unit}")

if primitive.tz in UTC_ALIASES:
return TimestamptzType()
Expand All @@ -1210,7 +1204,7 @@ def primitive(self, primitive: pa.DataType) -> PrimitiveType:
primitive = cast(pa.FixedSizeBinaryType, primitive)
return FixedType(primitive.byte_width)

raise UnsupportedPyArrowTypeException(self._field, f"{unsupported_prefix} type: {primitive}")
raise TypeError(f"Unsupported type: {primitive}")

def before_field(self, field: pa.Field) -> None:
self._field_names.append(field.name)
Expand Down
70 changes: 46 additions & 24 deletions tests/io/test_pyarrow_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import pyarrow as pa
import pytest

from pyiceberg.exceptions import UnsupportedPyArrowTimestampTypeException, UnsupportedPyArrowTypeException
from pyiceberg.exceptions import UnsupportedPyArrowTypeException
from pyiceberg.expressions import (
And,
BoundEqualTo,
Expand Down Expand Up @@ -91,7 +91,7 @@ def test_pyarrow_decimal256_to_iceberg() -> None:
precision = 26
scale = 20
pyarrow_type = pa.decimal256(precision, scale)
with pytest.raises(UnsupportedPyArrowTypeException, match=re.escape("Unsupported type: decimal256(26, 20)")):
with pytest.raises(TypeError, match=re.escape("Unsupported type: decimal256(26, 20)")):
visit_pyarrow(pyarrow_type, _ConvertToIceberg())


Expand Down Expand Up @@ -151,16 +151,16 @@ def test_pyarrow_date32_to_iceberg() -> None:

def test_pyarrow_date64_to_iceberg() -> None:
    """A bare ``date64`` type (no enclosing field) is rejected with a plain ``TypeError``."""
    pyarrow_type = pa.date64()
    # Visiting a bare type does not go through the pa.Field handler, so the
    # TypeError is not wrapped in UnsupportedPyArrowTypeException.
    with pytest.raises(TypeError, match=re.escape("Unsupported type: date64")):
        visit_pyarrow(pyarrow_type, _ConvertToIceberg())


def test_pyarrow_time32_to_iceberg() -> None:
    """Bare ``time32`` types in both ``ms`` and ``s`` units are rejected with ``TypeError``."""
    pyarrow_type = pa.time32("ms")
    with pytest.raises(TypeError, match=re.escape("Unsupported type: time32[ms]")):
        visit_pyarrow(pyarrow_type, _ConvertToIceberg())
    pyarrow_type = pa.time32("s")
    with pytest.raises(TypeError, match=re.escape("Unsupported type: time32[s]")):
        visit_pyarrow(pyarrow_type, _ConvertToIceberg())


Expand All @@ -173,7 +173,7 @@ def test_pyarrow_time64_us_to_iceberg() -> None:

def test_pyarrow_time64_ns_to_iceberg() -> None:
    """A bare nanosecond ``time64`` type is rejected with a plain ``TypeError``."""
    pyarrow_type = pa.time64("ns")
    with pytest.raises(TypeError, match=re.escape("Unsupported type: time64[ns]")):
        visit_pyarrow(pyarrow_type, _ConvertToIceberg())


Expand All @@ -189,7 +189,7 @@ def test_pyarrow_timestamp_to_iceberg(precision: str) -> None:
def test_pyarrow_timestamp_invalid_units() -> None:
pyarrow_type = pa.timestamp(unit="ns")
with pytest.raises(
UnsupportedPyArrowTimestampTypeException,
TypeError,
match=re.escape(
"Iceberg does not yet support 'ns' timestamp precision. Use 'downcast-ns-timestamp-to-us-on-write' configuration property to automatically downcast 'ns' to 'us' on write."
),
Expand All @@ -211,7 +211,7 @@ def test_pyarrow_timestamp_tz_to_iceberg() -> None:
def test_pyarrow_timestamp_tz_invalid_units() -> None:
pyarrow_type = pa.timestamp(unit="ns", tz="UTC")
with pytest.raises(
UnsupportedPyArrowTimestampTypeException,
TypeError,
match=re.escape(
"Iceberg does not yet support 'ns' timestamp precision. Use 'downcast-ns-timestamp-to-us-on-write' configuration property to automatically downcast 'ns' to 'us' on write."
),
Expand All @@ -221,7 +221,7 @@ def test_pyarrow_timestamp_tz_invalid_units() -> None:

def test_pyarrow_timestamp_tz_invalid_tz() -> None:
    """A bare timestamp whose timezone is ``US/Pacific`` is rejected with ``TypeError``."""
    pyarrow_type = pa.timestamp(unit="us", tz="US/Pacific")
    with pytest.raises(TypeError, match=re.escape("Unsupported type: timestamp[us, tz=US/Pacific]")):
        visit_pyarrow(pyarrow_type, _ConvertToIceberg())


Expand Down Expand Up @@ -608,33 +608,55 @@ def test_pyarrow_schema_unsupported_type() -> None:
) as exc_info:
visit_pyarrow(schema, _ConvertToIcebergWithoutIDs())
assert exc_info.value.field == lat_field
exception_cause = exc_info.value.__cause__
assert isinstance(exception_cause, TypeError)
assert "Unsupported type: decimal256(20, 26)" in exception_cause.args[0]

quux_field = pa.field(
"quux",
pa.map_(
pa.field("key", pa.string(), nullable=False, metadata={"PARQUET:field_id": "7"}),
pa.field(
"value",
pa.map_(pa.field("key", pa.string(), nullable=False), pa.field("value", pa.decimal256(2, 3))),
nullable=False,
),
),
nullable=False,
metadata={"PARQUET:field_id": "6", "doc": "quux doc"},
)
schema = pa.schema(
[
pa.field("foo", pa.string(), nullable=False),
quux_field,
]
)
with pytest.raises(
UnsupportedPyArrowTypeException,
match=re.escape("Column 'quux' has an unsupported type: map<string, map<string, decimal256(2, 3)>>"),
) as exc_info:
visit_pyarrow(schema, _ConvertToIcebergWithoutIDs())
assert exc_info.value.field == quux_field
exception_cause = exc_info.value.__cause__
assert isinstance(exception_cause, TypeError)
assert "Unsupported type: decimal256(2, 3)" in exception_cause.args[0]

foo_field = pa.field("foo", pa.timestamp(unit="ns"), nullable=False)
schema = pa.schema(
[
foo_field,
pa.field(
"location",
pa.large_list(
pa.struct(
[
pa.field("latitude", pa.float32(), nullable=False),
pa.field("longitude", pa.float32(), nullable=False),
]
),
),
nullable=False,
),
pa.field("bar", pa.int32(), nullable=False),
]
)
with pytest.raises(
UnsupportedPyArrowTypeException,
match=re.escape(
"Iceberg does not yet support 'ns' timestamp precision, making the column 'foo' unsupported. Use 'downcast-ns-timestamp-to-us-on-write' configuration property to automatically downcast 'ns' to 'us' on write."
),
match=re.escape("Column 'foo' has an unsupported type: timestamp[ns]"),
) as exc_info:
visit_pyarrow(schema, _ConvertToIcebergWithoutIDs())
assert exc_info.value.field == foo_field
exception_cause = exc_info.value.__cause__
assert isinstance(exception_cause, TypeError)
assert "Iceberg does not yet support 'ns' timestamp precision" in exception_cause.args[0]


def test_pyarrow_schema_round_trip_ensure_large_types_and_then_small_types(pyarrow_schema_nested_without_ids: pa.Schema) -> None:
Expand Down

0 comments on commit cce35c7

Please sign in to comment.