diff --git a/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py b/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py index 71cfd0268ee6b5..6f7decc79b1df2 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py +++ b/metadata-ingestion/src/datahub/ingestion/source/sql/athena.py @@ -26,6 +26,7 @@ platform_name, support_status, ) +from datahub.ingestion.api.source import StructuredLogLevel from datahub.ingestion.api.workunit import MetadataWorkUnit from datahub.ingestion.source.aws.s3_util import make_s3_urn from datahub.ingestion.source.common.subtypes import DatasetContainerSubTypes @@ -35,6 +36,7 @@ register_custom_type, ) from datahub.ingestion.source.sql.sql_config import SQLCommonConfig, make_sqlalchemy_uri +from datahub.ingestion.source.sql.sql_report import SQLSourceReport from datahub.ingestion.source.sql.sql_utils import ( add_table_to_schema_container, gen_database_container, @@ -48,6 +50,15 @@ get_schema_fields_for_sqlalchemy_column, ) +try: + from typing_extensions import override +except ImportError: + _F = typing.TypeVar("_F", bound=typing.Callable[..., typing.Any]) + + def override(f: _F, /) -> _F: # noqa: F811 + return f + + logger = logging.getLogger(__name__) assert STRUCT, "required type modules are not available" @@ -322,12 +333,15 @@ class AthenaSource(SQLAlchemySource): - Profiling when enabled. """ - table_partition_cache: Dict[str, Dict[str, Partitionitem]] = {} + config: AthenaConfig + report: SQLSourceReport def __init__(self, config, ctx): super().__init__(config, ctx, "athena") self.cursor: Optional[BaseCursor] = None + self.table_partition_cache: Dict[str, Dict[str, Partitionitem]] = {} + @classmethod def create(cls, config_dict, ctx): config = AthenaConfig.parse_obj(config_dict) @@ -452,6 +466,7 @@ def add_table_to_schema_container( ) # It seems like database/schema filter in the connection string does not work and this to work around that + @override def get_schema_names(self, inspector: Inspector) -> List[str]: athena_config = typing.cast(AthenaConfig, self.config) schemas = inspector.get_schema_names() @@ -459,34 +474,42 @@ def get_schema_names(self, inspector: Inspector) -> List[str]: return [schema for schema in schemas if schema == athena_config.database] return schemas - # Overwrite to get partitions + @classmethod + def _casted_partition_key(cls, key: str) -> str: + # We need to cast the partition keys to a VARCHAR, since otherwise + # Athena may throw an error during concatenation / comparison. + return f"CAST({key} as VARCHAR)" + + @override def get_partitions( self, inspector: Inspector, schema: str, table: str - ) -> List[str]: - partitions = [] - - athena_config = typing.cast(AthenaConfig, self.config) - - if not athena_config.extract_partitions: - return [] + ) -> Optional[List[str]]: + if not self.config.extract_partitions: + return None if not self.cursor: - return [] + return None metadata: AthenaTableMetadata = self.cursor.get_table_metadata( table_name=table, schema_name=schema ) - if metadata.partition_keys: - for key in metadata.partition_keys: - if key.name: - partitions.append(key.name) - - if not partitions: - return [] + partitions = [] + for key in metadata.partition_keys: + if key.name: + partitions.append(key.name) + if not partitions: + return [] - # We create an artiificaial concatenated partition key to be able to query max partition easier - part_concat = "|| '-' ||".join(partitions) + with self.report.report_exc( + message="Failed to extract partition details", + context=f"{schema}.{table}", + level=StructuredLogLevel.WARN, + ): + # We create an artifical concatenated partition key to be able to query max partition easier + part_concat = " || '-' || ".join( + self._casted_partition_key(key) for key in partitions + ) max_partition_query = f'select {",".join(partitions)} from "{schema}"."{table}$partitions" where {part_concat} = (select max({part_concat}) from "{schema}"."{table}$partitions")' ret = self.cursor.execute(max_partition_query) max_partition: Dict[str, str] = {} @@ -500,9 +523,8 @@ def get_partitions( partitions=partitions, max_partition=max_partition, ) - return partitions - return [] + return partitions # Overwrite to modify the creation of schema fields def get_schema_fields_for_column( @@ -551,7 +573,9 @@ def generate_partition_profiler_query( if partition and partition.max_partition: max_partition_filters = [] for key, value in partition.max_partition.items(): - max_partition_filters.append(f"CAST({key} as VARCHAR) = '{value}'") + max_partition_filters.append( + f"{self._casted_partition_key(key)} = '{value}'" + ) max_partition = str(partition.max_partition) return ( max_partition, diff --git a/metadata-ingestion/tests/unit/test_athena_source.py b/metadata-ingestion/tests/unit/test_athena_source.py index 875cf3800daf88..f8b6220d182735 100644 --- a/metadata-ingestion/tests/unit/test_athena_source.py +++ b/metadata-ingestion/tests/unit/test_athena_source.py @@ -93,7 +93,8 @@ def test_athena_get_table_properties(): "CreateTime": datetime.now(), "LastAccessTime": datetime.now(), "PartitionKeys": [ - {"Name": "testKey", "Type": "string", "Comment": "testComment"} + {"Name": "year", "Type": "string", "Comment": "testComment"}, + {"Name": "month", "Type": "string", "Comment": "testComment"}, ], "Parameters": { "comment": "testComment", @@ -112,8 +113,18 @@ def test_athena_get_table_properties(): response=table_metadata ) + # Mock partition query results + mock_cursor.execute.return_value.description = [ + ["year"], + ["month"], + ] + mock_cursor.execute.return_value.__iter__.return_value = [["2023", "12"]] + ctx = PipelineContext(run_id="test") source = AthenaSource(config=config, ctx=ctx) + source.cursor = mock_cursor + + # Test table properties description, custom_properties, location = source.get_table_properties( inspector=mock_inspector, table=table, schema=schema ) @@ -124,13 +135,35 @@ def test_athena_get_table_properties(): "last_access_time": "2020-04-14 07:00:00", "location": "s3://testLocation", "outputformat": "testOutputFormat", - "partition_keys": '[{"name": "testKey", "type": "string", "comment": "testComment"}]', + "partition_keys": '[{"name": "year", "type": "string", "comment": "testComment"}, {"name": "month", "type": "string", "comment": "testComment"}]', "serde.serialization.lib": "testSerde", "table_type": "testType", } - assert location == make_s3_urn("s3://testLocation", "PROD") + # Test partition functionality + partitions = source.get_partitions( + inspector=mock_inspector, schema=schema, table=table + ) + assert partitions == ["year", "month"] + + # Verify the correct SQL query was generated for partitions + expected_query = """\ +select year,month from "test_schema"."test_table$partitions" \ +where CAST(year as VARCHAR) || '-' || CAST(month as VARCHAR) = \ +(select max(CAST(year as VARCHAR) || '-' || CAST(month as VARCHAR)) \ +from "test_schema"."test_table$partitions")""" + mock_cursor.execute.assert_called_once() + actual_query = mock_cursor.execute.call_args[0][0] + assert actual_query == expected_query + + # Verify partition cache was populated correctly + assert source.table_partition_cache[schema][table].partitions == partitions + assert source.table_partition_cache[schema][table].max_partition == { + "year": "2023", + "month": "12", + } + def test_get_column_type_simple_types(): assert isinstance( @@ -214,3 +247,9 @@ def test_column_type_complex_combination(): assert isinstance( result._STRUCT_fields[2][1].item_type._STRUCT_fields[1][1], types.String ) + + +def test_casted_partition_key(): + from datahub.ingestion.source.sql.athena import AthenaSource + + assert AthenaSource._casted_partition_key("test_col") == "CAST(test_col as VARCHAR)"