diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery.py index 3c6202cc7cbfa..0e986acc81add 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery.py @@ -272,7 +272,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: self.report.set_ingestion_stage("*", QUERIES_EXTRACTION) - queries_extractor = BigQueryQueriesExtractor( + with BigQueryQueriesExtractor( connection=self.config.get_bigquery_client(), schema_api=self.bq_schema_extractor.schema_api, config=BigQueryQueriesExtractorConfig( @@ -288,9 +288,10 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: identifiers=self.identifiers, schema_resolver=self.sql_parser_schema_resolver, discovered_tables=self.bq_schema_extractor.table_refs, - ) - self.report.queries_extractor = queries_extractor.report - yield from queries_extractor.get_workunits_internal() + ) as queries_extractor: + self.report.queries_extractor = queries_extractor.report + yield from queries_extractor.get_workunits_internal() + else: if self.config.include_usage_statistics: yield from self.usage_extractor.get_usage_workunits( diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_queries.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_queries.py index ed27aae19ce96..47f21c9f32353 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_queries.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_queries.py @@ -88,3 +88,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: def get_report(self) -> BigQueryQueriesSourceReport: return self.report + + def close(self) -> None: + self.queries_extractor.close() + self.connection.close() diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/queries_extractor.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/queries_extractor.py index b4a443673b9a9..afaaaf51964f8 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/queries_extractor.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/queries_extractor.py @@ -13,6 +13,7 @@ BaseTimeWindowConfig, get_time_bucket, ) +from datahub.ingestion.api.closeable import Closeable from datahub.ingestion.api.source import SourceReport from datahub.ingestion.api.source_helpers import auto_workunit from datahub.ingestion.api.workunit import MetadataWorkUnit @@ -114,7 +115,7 @@ class BigQueryQueriesExtractorConfig(BigQueryBaseConfig): ) -class BigQueryQueriesExtractor: +class BigQueryQueriesExtractor(Closeable): """ Extracts query audit log and generates usage/lineage/operation workunits. @@ -181,6 +182,7 @@ def __init__( is_allowed_table=self.is_allowed_table, format_queries=False, ) + self.report.sql_aggregator = self.aggregator.report self.report.num_discovered_tables = ( len(self.discovered_tables) if self.discovered_tables else None @@ -273,12 +275,14 @@ def get_workunits_internal( self.report.num_unique_queries = len(queries_deduped) logger.info(f"Found {self.report.num_unique_queries} unique queries") - with self.report.audit_log_load_timer: + with self.report.audit_log_load_timer, queries_deduped: i = 0 for _, query_instances in queries_deduped.items(): for query in query_instances.values(): if i > 0 and i % 10000 == 0: - logger.info(f"Added {i} query log entries to SQL aggregator") + logger.info( + f"Added {i} query log equeries_dedupedntries to SQL aggregator" + ) if self.report.sql_aggregator: logger.info(self.report.sql_aggregator.as_string()) @@ -287,6 +291,11 @@ def get_workunits_internal( yield from auto_workunit(self.aggregator.gen_metadata()) + if not use_cached_audit_log: + queries.close() + shared_connection.close() + audit_log_file.unlink(missing_ok=True) + def deduplicate_queries( self, queries: FileBackedList[ObservedQuery] ) -> FileBackedDict[Dict[int, ObservedQuery]]: @@ -404,6 +413,9 @@ def _parse_audit_log_row(self, row: BigQueryJob) -> ObservedQuery: return entry + def close(self) -> None: + self.aggregator.close() + def _extract_query_text(row: BigQueryJob) -> str: # We wrap select statements in a CTE to make them parseable as DML statement. diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/lineage_v2.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/lineage_v2.py index 4df64c80bad8a..53f9383ec02a7 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/lineage_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/lineage_v2.py @@ -5,6 +5,7 @@ import redshift_connector from datahub.emitter import mce_builder +from datahub.ingestion.api.closeable import Closeable from datahub.ingestion.api.common import PipelineContext from datahub.ingestion.api.workunit import MetadataWorkUnit from datahub.ingestion.source.redshift.config import LineageMode, RedshiftConfig @@ -39,7 +40,7 @@ logger = logging.getLogger(__name__) -class RedshiftSqlLineageV2: +class RedshiftSqlLineageV2(Closeable): # does lineage and usage based on SQL parsing. def __init__( @@ -56,6 +57,7 @@ def __init__( self.context = context self.database = database + self.aggregator = SqlParsingAggregator( platform=self.platform, platform_instance=self.config.platform_instance, @@ -436,3 +438,6 @@ def generate(self) -> Iterable[MetadataWorkUnit]: message="Unexpected error(s) while attempting to extract lineage from SQL queries. See the full logs for more details.", context=f"Query Parsing Failures: {self.aggregator.report.observed_query_parse_failures}", ) + + def close(self) -> None: + self.aggregator.close() diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift.py index a9fc9ab8f3e99..76030cea98494 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift.py @@ -451,24 +451,23 @@ def _extract_metadata( ) if self.config.use_lineage_v2: - lineage_extractor = RedshiftSqlLineageV2( + with RedshiftSqlLineageV2( config=self.config, report=self.report, context=self.ctx, database=database, redundant_run_skip_handler=self.redundant_lineage_run_skip_handler, - ) - - yield from lineage_extractor.aggregator.register_schemas_from_stream( - self.process_schemas(connection, database) - ) + ) as lineage_extractor: + yield from lineage_extractor.aggregator.register_schemas_from_stream( + self.process_schemas(connection, database) + ) - self.report.report_ingestion_stage_start(LINEAGE_EXTRACTION) - yield from self.extract_lineage_v2( - connection=connection, - database=database, - lineage_extractor=lineage_extractor, - ) + self.report.report_ingestion_stage_start(LINEAGE_EXTRACTION) + yield from self.extract_lineage_v2( + connection=connection, + database=database, + lineage_extractor=lineage_extractor, + ) all_tables = self.get_all_tables() else: diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_connection.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_connection.py index d39e95a884dbc..a9f454cfd3cdb 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_connection.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_connection.py @@ -18,6 +18,7 @@ from datahub.configuration.connection_resolver import auto_connection_resolver from datahub.configuration.oauth import OAuthConfiguration, OAuthIdentityProvider from datahub.configuration.validate_field_rename import pydantic_renamed_field +from datahub.ingestion.api.closeable import Closeable from datahub.ingestion.source.snowflake.constants import ( CLIENT_PREFETCH_THREADS, CLIENT_SESSION_KEEP_ALIVE, @@ -364,7 +365,7 @@ def get_connection(self) -> "SnowflakeConnection": ) from e -class SnowflakeConnection: +class SnowflakeConnection(Closeable): _connection: NativeSnowflakeConnection def __init__(self, connection: NativeSnowflakeConnection): diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py index 1445d02aa49db..e11073d77b46e 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py @@ -1,3 +1,4 @@ +import contextlib import dataclasses import functools import json @@ -17,6 +18,7 @@ BaseTimeWindowConfig, BucketDuration, ) +from datahub.ingestion.api.closeable import Closeable from datahub.ingestion.api.common import PipelineContext from datahub.ingestion.api.report import Report from datahub.ingestion.api.source import Source, SourceReport @@ -121,7 +123,7 @@ class SnowflakeQueriesSourceReport(SourceReport): queries_extractor: Optional[SnowflakeQueriesExtractorReport] = None -class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin): +class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable): def __init__( self, connection: SnowflakeConnection, @@ -143,28 +145,33 @@ def __init__( self._structured_report = structured_report - self.aggregator = SqlParsingAggregator( - platform=self.identifiers.platform, - platform_instance=self.identifiers.identifier_config.platform_instance, - env=self.identifiers.identifier_config.env, - schema_resolver=schema_resolver, - graph=graph, - eager_graph_load=False, - generate_lineage=self.config.include_lineage, - generate_queries=self.config.include_queries, - generate_usage_statistics=self.config.include_usage_statistics, - generate_query_usage_statistics=self.config.include_query_usage_statistics, - usage_config=BaseUsageConfig( - bucket_duration=self.config.window.bucket_duration, - start_time=self.config.window.start_time, - end_time=self.config.window.end_time, - user_email_pattern=self.config.user_email_pattern, - # TODO make the rest of the fields configurable - ), - generate_operations=self.config.include_operations, - is_temp_table=self.is_temp_table, - is_allowed_table=self.is_allowed_table, - format_queries=False, + # The exit stack helps ensure that we close all the resources we open. + self._exit_stack = contextlib.ExitStack() + + self.aggregator: SqlParsingAggregator = self._exit_stack.enter_context( + SqlParsingAggregator( + platform=self.identifiers.platform, + platform_instance=self.identifiers.identifier_config.platform_instance, + env=self.identifiers.identifier_config.env, + schema_resolver=schema_resolver, + graph=graph, + eager_graph_load=False, + generate_lineage=self.config.include_lineage, + generate_queries=self.config.include_queries, + generate_usage_statistics=self.config.include_usage_statistics, + generate_query_usage_statistics=self.config.include_query_usage_statistics, + usage_config=BaseUsageConfig( + bucket_duration=self.config.window.bucket_duration, + start_time=self.config.window.start_time, + end_time=self.config.window.end_time, + user_email_pattern=self.config.user_email_pattern, + # TODO make the rest of the fields configurable + ), + generate_operations=self.config.include_operations, + is_temp_table=self.is_temp_table, + is_allowed_table=self.is_allowed_table, + format_queries=False, + ) ) self.report.sql_aggregator = self.aggregator.report @@ -248,6 +255,10 @@ def get_workunits_internal( self.aggregator.add(query) yield from auto_workunit(self.aggregator.gen_metadata()) + if not use_cached_audit_log: + queries.close() + shared_connection.close() + audit_log_file.unlink(missing_ok=True) def fetch_copy_history(self) -> Iterable[KnownLineageMapping]: # Derived from _populate_external_lineage_from_copy_history. @@ -426,6 +437,9 @@ def _parse_audit_log_row(self, row: Dict[str, Any]) -> PreparsedQuery: ) return entry + def close(self) -> None: + self._exit_stack.close() + class SnowflakeQueriesSource(Source): def __init__(self, ctx: PipelineContext, config: SnowflakeQueriesSourceConfig): @@ -468,6 +482,10 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: def get_report(self) -> SnowflakeQueriesSourceReport: return self.report + def close(self) -> None: + self.connection.close() + self.queries_extractor.close() + # Make sure we don't try to generate too much info for a single query. _MAX_TABLES_PER_QUERY = 20 diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py index 0d7881f36554d..dd7f73268fdc4 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py @@ -1,3 +1,4 @@ +import contextlib import functools import json import logging @@ -149,7 +150,12 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): cached_domains=[k for k in self.config.domain], graph=self.ctx.graph ) - self.connection = self.config.get_connection() + # The exit stack helps ensure that we close all the resources we open. + self._exit_stack = contextlib.ExitStack() + + self.connection: SnowflakeConnection = self._exit_stack.enter_context( + self.config.get_connection() + ) # For database, schema, tables, views, etc self.data_dictionary = SnowflakeDataDictionary(connection=self.connection) @@ -157,25 +163,27 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): self.aggregator: Optional[SqlParsingAggregator] = None if self.config.use_queries_v2 or self.config.include_table_lineage: - self.aggregator = SqlParsingAggregator( - platform=self.identifiers.platform, - platform_instance=self.config.platform_instance, - env=self.config.env, - graph=self.ctx.graph, - eager_graph_load=( - # If we're ingestion schema metadata for tables/views, then we will populate - # schemas into the resolver as we go. We only need to do a bulk fetch - # if we're not ingesting schema metadata as part of ingestion. - not ( - self.config.include_technical_schema - and self.config.include_tables - and self.config.include_views - ) - and not self.config.lazy_schema_resolver - ), - generate_usage_statistics=False, - generate_operations=False, - format_queries=self.config.format_sql_queries, + self.aggregator = self._exit_stack.enter_context( + SqlParsingAggregator( + platform=self.identifiers.platform, + platform_instance=self.config.platform_instance, + env=self.config.env, + graph=self.ctx.graph, + eager_graph_load=( + # If we're ingestion schema metadata for tables/views, then we will populate + # schemas into the resolver as we go. We only need to do a bulk fetch + # if we're not ingesting schema metadata as part of ingestion. + not ( + self.config.include_technical_schema + and self.config.include_tables + and self.config.include_views + ) + and not self.config.lazy_schema_resolver + ), + generate_usage_statistics=False, + generate_operations=False, + format_queries=self.config.format_sql_queries, + ) ) self.report.sql_aggregator = self.aggregator.report @@ -191,14 +199,16 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): pipeline_name=self.ctx.pipeline_name, run_id=self.ctx.run_id, ) - self.lineage_extractor = SnowflakeLineageExtractor( - config, - self.report, - connection=self.connection, - filters=self.filters, - identifiers=self.identifiers, - redundant_run_skip_handler=redundant_lineage_run_skip_handler, - sql_aggregator=self.aggregator, + self.lineage_extractor = self._exit_stack.enter_context( + SnowflakeLineageExtractor( + config, + self.report, + connection=self.connection, + filters=self.filters, + identifiers=self.identifiers, + redundant_run_skip_handler=redundant_lineage_run_skip_handler, + sql_aggregator=self.aggregator, + ) ) self.usage_extractor: Optional[SnowflakeUsageExtractor] = None @@ -213,13 +223,15 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): pipeline_name=self.ctx.pipeline_name, run_id=self.ctx.run_id, ) - self.usage_extractor = SnowflakeUsageExtractor( - config, - self.report, - connection=self.connection, - filter=self.filters, - identifiers=self.identifiers, - redundant_run_skip_handler=redundant_usage_run_skip_handler, + self.usage_extractor = self._exit_stack.enter_context( + SnowflakeUsageExtractor( + config, + self.report, + connection=self.connection, + filter=self.filters, + identifiers=self.identifiers, + redundant_run_skip_handler=redundant_usage_run_skip_handler, + ) ) self.profiling_state_handler: Optional[ProfilingHandler] = None @@ -444,10 +456,6 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: self._snowflake_clear_ocsp_cache() - self.connection = self.config.get_connection() - if self.connection is None: - return - self.inspect_session_metadata(self.connection) snowsight_url_builder = None @@ -513,7 +521,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: schema_resolver = self.aggregator._schema_resolver - queries_extractor = SnowflakeQueriesExtractor( + queries_extractor: SnowflakeQueriesExtractor = SnowflakeQueriesExtractor( connection=self.connection, config=SnowflakeQueriesExtractorConfig( window=self.config, @@ -535,6 +543,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: # it should be pretty straightforward to refactor this and only initialize the aggregator once. self.report.queries_extractor = queries_extractor.report yield from queries_extractor.get_workunits_internal() + queries_extractor.close() else: if self.config.include_table_lineage and self.lineage_extractor: @@ -723,7 +732,4 @@ def _snowflake_clear_ocsp_cache(self) -> None: def close(self) -> None: super().close() StatefulIngestionSourceBase.close(self) - if self.lineage_extractor: - self.lineage_extractor.close() - if self.usage_extractor: - self.usage_extractor.close() + self._exit_stack.close() diff --git a/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py b/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py index 5f2709fe42660..0b7ad14a8c1b4 100644 --- a/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py +++ b/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py @@ -275,6 +275,8 @@ class SqlAggregatorReport(Report): tool_meta_report: Optional[ToolMetaExtractorReport] = None def compute_stats(self) -> None: + if self._aggregator._closed: + return self.schema_resolver_count = self._aggregator._schema_resolver.schema_count() self.num_unique_query_fingerprints = len(self._aggregator._query_map) @@ -345,6 +347,7 @@ def __init__( # The exit stack helps ensure that we close all the resources we open. self._exit_stack = contextlib.ExitStack() + self._closed: bool = False # Set up the schema resolver. self._schema_resolver: SchemaResolver @@ -456,12 +459,16 @@ def __init__( shared_connection=self._shared_connection, tablename="query_usage_counts", ) + self._exit_stack.push(self._query_usage_counts) # Tool Extractor self._tool_meta_extractor = ToolMetaExtractor() self.report.tool_meta_report = self._tool_meta_extractor.report def close(self) -> None: + # Compute stats once before closing connections + self.report.compute_stats() + self._closed = True self._exit_stack.close() @property diff --git a/metadata-ingestion/tests/integration/bigquery_v2/test_bigquery_queries.py b/metadata-ingestion/tests/integration/bigquery_v2/test_bigquery_queries.py index 9290100b0c521..ef846f698f156 100644 --- a/metadata-ingestion/tests/integration/bigquery_v2/test_bigquery_queries.py +++ b/metadata-ingestion/tests/integration/bigquery_v2/test_bigquery_queries.py @@ -1,4 +1,5 @@ import json +import os from datetime import datetime from pathlib import Path from unittest.mock import patch @@ -6,7 +7,9 @@ import pytest from freezegun import freeze_time +from datahub.ingestion.api.common import PipelineContext from datahub.ingestion.source.bigquery_v2.bigquery_queries import ( + BigQueryQueriesSource, BigQueryQueriesSourceReport, ) from datahub.metadata.urns import CorpUserUrn @@ -93,3 +96,16 @@ def test_queries_ingestion(project_client, client, pytestconfig, monkeypatch, tm output_path=mcp_output_path, golden_path=mcp_golden_path, ) + + +@patch("google.cloud.bigquery.Client") +@patch("google.cloud.resourcemanager_v3.ProjectsClient") +def test_source_close_cleans_tmp(projects_client, client, tmp_path): + with patch("tempfile.tempdir", str(tmp_path)): + source = BigQueryQueriesSource.create( + {"project_ids": ["project1"]}, PipelineContext("run-id") + ) + assert len(os.listdir(tmp_path)) > 0 + # This closes QueriesExtractor which in turn closes SqlParsingAggregator + source.close() + assert len(os.listdir(tmp_path)) == 0 diff --git a/metadata-ingestion/tests/integration/snowflake/test_snowflake_queries.py b/metadata-ingestion/tests/integration/snowflake/test_snowflake_queries.py new file mode 100644 index 0000000000000..82f5691bcee3d --- /dev/null +++ b/metadata-ingestion/tests/integration/snowflake/test_snowflake_queries.py @@ -0,0 +1,24 @@ +import os +from unittest.mock import patch + +from datahub.ingestion.api.common import PipelineContext +from datahub.ingestion.source.snowflake.snowflake_queries import SnowflakeQueriesSource + + +@patch("snowflake.connector.connect") +def test_source_close_cleans_tmp(snowflake_connect, tmp_path): + with patch("tempfile.tempdir", str(tmp_path)): + source = SnowflakeQueriesSource.create( + { + "connection": { + "account_id": "ABC12345.ap-south-1.aws", + "username": "TST_USR", + "password": "TST_PWD", + } + }, + PipelineContext("run-id"), + ) + assert len(os.listdir(tmp_path)) > 0 + # This closes QueriesExtractor which in turn closes SqlParsingAggregator + source.close() + assert len(os.listdir(tmp_path)) == 0 diff --git a/metadata-ingestion/tests/unit/sql_parsing/test_sql_aggregator.py b/metadata-ingestion/tests/unit/sql_parsing/test_sql_aggregator.py index 0d21936a74d07..849d550ef69c5 100644 --- a/metadata-ingestion/tests/unit/sql_parsing/test_sql_aggregator.py +++ b/metadata-ingestion/tests/unit/sql_parsing/test_sql_aggregator.py @@ -1,5 +1,7 @@ +import os import pathlib from datetime import datetime, timezone +from unittest.mock import patch import pytest from freezegun import freeze_time @@ -661,3 +663,23 @@ def test_basic_usage(pytestconfig: pytest.Config) -> None: outputs=mcps, golden_path=RESOURCE_DIR / "test_basic_usage.json", ) + + +def test_sql_aggreator_close_cleans_tmp(tmp_path): + frozen_timestamp = parse_user_datetime(FROZEN_TIME) + with patch("tempfile.tempdir", str(tmp_path)): + aggregator = SqlParsingAggregator( + platform="redshift", + generate_lineage=False, + generate_usage_statistics=True, + generate_operations=False, + usage_config=BaseUsageConfig( + start_time=get_time_bucket(frozen_timestamp, BucketDuration.DAY), + end_time=frozen_timestamp, + ), + generate_queries=True, + generate_query_usage_statistics=True, + ) + assert len(os.listdir(tmp_path)) > 0 + aggregator.close() + assert len(os.listdir(tmp_path)) == 0