diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py b/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py
index 06b9ad92677a2d..75e8fe1d6f7a6f 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py
@@ -37,7 +37,7 @@
     gen_database_key,
 )
 from datahub.metadata.com.linkedin.pegasus2avro.schema import SchemaField
-from datahub.metadata.schema_classes import RecordTypeClass
+from datahub.metadata.schema_classes import MapTypeClass, RecordTypeClass
 from datahub.utilities.hive_schema_to_avro import get_avro_schema_for_hive_column
 from datahub.utilities.sqlalchemy_type_converter import (
     MapType,
@@ -46,7 +46,10 @@
 
 logger = logging.getLogger(__name__)
 
+if STRUCT is None:
+    raise ImportError("required type modules are not available")
 register_custom_type(STRUCT, RecordTypeClass)
+register_custom_type(MapType, MapTypeClass)
 
 
 class CustomAthenaRestDialect(AthenaRestDialect):
diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py
index be03858ec3ef91..494d02fc3441ad 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/sql/sql_common.py
@@ -90,7 +90,6 @@
 from datahub.utilities.lossy_collections import LossyList
 from datahub.utilities.registries.domain_registry import DomainRegistry
 from datahub.utilities.sqlalchemy_query_combiner import SQLAlchemyQueryCombinerReport
-from datahub.utilities.sqlalchemy_type_converter import MapType
 
 if TYPE_CHECKING:
     from datahub.ingestion.source.ge_data_profiler import (
@@ -140,6 +139,7 @@ class SqlWorkUnit(MetadataWorkUnit):
 
 
 _field_type_mapping: Dict[Type[TypeEngine], Type] = {
+    # Note: to add dialect-specific types to this mapping, use the `register_custom_type` function.
     types.Integer: NumberTypeClass,
     types.Numeric: NumberTypeClass,
     types.Boolean: BooleanTypeClass,
@@ -156,8 +156,6 @@ class SqlWorkUnit(MetadataWorkUnit):
     types.DATETIME: TimeTypeClass,
     types.TIMESTAMP: TimeTypeClass,
     types.JSON: RecordTypeClass,
-    # additional type definitions that are used by the Athena source
-    MapType: MapTypeClass,  # type: ignore
     # Because the postgresql dialect is used internally by many other dialects,
     # we add some postgres types here. This is ok to do because the postgresql
     # dialect is built-in to sqlalchemy.
diff --git a/metadata-ingestion/src/datahub/utilities/sqlalchemy_type_converter.py b/metadata-ingestion/src/datahub/utilities/sqlalchemy_type_converter.py
index 1d5ec5dae35190..5d2fc6872c7bd9 100644
--- a/metadata-ingestion/src/datahub/utilities/sqlalchemy_type_converter.py
+++ b/metadata-ingestion/src/datahub/utilities/sqlalchemy_type_converter.py
@@ -4,7 +4,6 @@
 from typing import Any, Dict, List, Optional, Type, Union
 
 from sqlalchemy import types
-from sqlalchemy_bigquery import STRUCT
 
 from datahub.ingestion.extractor.schema_util import avro_schema_to_mce_fields
 from datahub.metadata.com.linkedin.pegasus2avro.schema import SchemaField
@@ -12,6 +11,12 @@
 
 logger = logging.getLogger(__name__)
 
+try:
+    # This is used for both BigQuery and Athena.
+    from sqlalchemy_bigquery import STRUCT
+except ImportError:
+    STRUCT = None
+
 
 class MapType(types.TupleType):
     # Wrapper class around SQLalchemy's TupleType to increase compatibility with DataHub
@@ -42,7 +47,16 @@ def get_avro_type(
     ) -> Dict[str, Any]:
         """Determines the concrete AVRO schema type for a SQLalchemy-typed column"""
 
-        if type(column_type) in cls.PRIMITIVE_SQL_ALCHEMY_TYPE_TO_AVRO_TYPE.keys():
+        if isinstance(
+            column_type, tuple(cls.PRIMITIVE_SQL_ALCHEMY_TYPE_TO_AVRO_TYPE.keys())
+        ):
             return {
-                "type": cls.PRIMITIVE_SQL_ALCHEMY_TYPE_TO_AVRO_TYPE[type(column_type)],
+                # The isinstance() check above also matches subclasses of the mapped
+                # types, so resolve through the MRO (most specific class first); an
+                # exact-type lookup would raise KeyError for those subclasses.
+                "type": next(
+                    cls.PRIMITIVE_SQL_ALCHEMY_TYPE_TO_AVRO_TYPE[base]
+                    for base in type(column_type).__mro__
+                    if base in cls.PRIMITIVE_SQL_ALCHEMY_TYPE_TO_AVRO_TYPE
+                ),
                 "native_data_type": str(column_type),
@@ -88,7 +102,7 @@ def get_avro_type(
                 "key_type": cls.get_avro_type(column_type=key_type, nullable=nullable),
                 "key_native_data_type": str(key_type),
             }
-        if isinstance(column_type, STRUCT):
+        if STRUCT and isinstance(column_type, STRUCT):
             fields = []
             for field_def in column_type._STRUCT_fields:
                 field_name, field_type = field_def