Skip to content

Commit

Permalink
feat(ingest): ensure sqlite file delete on clean exit (#11612)
Browse files Browse the repository at this point in the history
  • Loading branch information
mayurinehate authored Oct 18, 2024
1 parent 6e3724b commit 179a671
Show file tree
Hide file tree
Showing 12 changed files with 203 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:

self.report.set_ingestion_stage("*", QUERIES_EXTRACTION)

queries_extractor = BigQueryQueriesExtractor(
with BigQueryQueriesExtractor(
connection=self.config.get_bigquery_client(),
schema_api=self.bq_schema_extractor.schema_api,
config=BigQueryQueriesExtractorConfig(
Expand All @@ -288,9 +288,10 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
identifiers=self.identifiers,
schema_resolver=self.sql_parser_schema_resolver,
discovered_tables=self.bq_schema_extractor.table_refs,
)
self.report.queries_extractor = queries_extractor.report
yield from queries_extractor.get_workunits_internal()
) as queries_extractor:
self.report.queries_extractor = queries_extractor.report
yield from queries_extractor.get_workunits_internal()

else:
if self.config.include_usage_statistics:
yield from self.usage_extractor.get_usage_workunits(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:

def get_report(self) -> BigQueryQueriesSourceReport:
return self.report

def close(self) -> None:
self.queries_extractor.close()
self.connection.close()
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
BaseTimeWindowConfig,
get_time_bucket,
)
from datahub.ingestion.api.closeable import Closeable
from datahub.ingestion.api.source import SourceReport
from datahub.ingestion.api.source_helpers import auto_workunit
from datahub.ingestion.api.workunit import MetadataWorkUnit
Expand Down Expand Up @@ -114,7 +115,7 @@ class BigQueryQueriesExtractorConfig(BigQueryBaseConfig):
)


class BigQueryQueriesExtractor:
class BigQueryQueriesExtractor(Closeable):
"""
Extracts query audit log and generates usage/lineage/operation workunits.
Expand Down Expand Up @@ -181,6 +182,7 @@ def __init__(
is_allowed_table=self.is_allowed_table,
format_queries=False,
)

self.report.sql_aggregator = self.aggregator.report
self.report.num_discovered_tables = (
len(self.discovered_tables) if self.discovered_tables else None
Expand Down Expand Up @@ -273,12 +275,14 @@ def get_workunits_internal(
self.report.num_unique_queries = len(queries_deduped)
logger.info(f"Found {self.report.num_unique_queries} unique queries")

with self.report.audit_log_load_timer:
with self.report.audit_log_load_timer, queries_deduped:
i = 0
for _, query_instances in queries_deduped.items():
for query in query_instances.values():
if i > 0 and i % 10000 == 0:
logger.info(f"Added {i} query log entries to SQL aggregator")
logger.info(
f"Added {i} query log equeries_dedupedntries to SQL aggregator"
)
if self.report.sql_aggregator:
logger.info(self.report.sql_aggregator.as_string())

Expand All @@ -287,6 +291,11 @@ def get_workunits_internal(

yield from auto_workunit(self.aggregator.gen_metadata())

if not use_cached_audit_log:
queries.close()
shared_connection.close()
audit_log_file.unlink(missing_ok=True)

def deduplicate_queries(
self, queries: FileBackedList[ObservedQuery]
) -> FileBackedDict[Dict[int, ObservedQuery]]:
Expand Down Expand Up @@ -404,6 +413,9 @@ def _parse_audit_log_row(self, row: BigQueryJob) -> ObservedQuery:

return entry

def close(self) -> None:
self.aggregator.close()


def _extract_query_text(row: BigQueryJob) -> str:
# We wrap select statements in a CTE to make them parseable as DML statement.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import redshift_connector

from datahub.emitter import mce_builder
from datahub.ingestion.api.closeable import Closeable
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.redshift.config import LineageMode, RedshiftConfig
Expand Down Expand Up @@ -39,7 +40,7 @@
logger = logging.getLogger(__name__)


class RedshiftSqlLineageV2:
class RedshiftSqlLineageV2(Closeable):
# does lineage and usage based on SQL parsing.

def __init__(
Expand All @@ -56,6 +57,7 @@ def __init__(
self.context = context

self.database = database

self.aggregator = SqlParsingAggregator(
platform=self.platform,
platform_instance=self.config.platform_instance,
Expand Down Expand Up @@ -436,3 +438,6 @@ def generate(self) -> Iterable[MetadataWorkUnit]:
message="Unexpected error(s) while attempting to extract lineage from SQL queries. See the full logs for more details.",
context=f"Query Parsing Failures: {self.aggregator.report.observed_query_parse_failures}",
)

def close(self) -> None:
self.aggregator.close()
Original file line number Diff line number Diff line change
Expand Up @@ -451,24 +451,23 @@ def _extract_metadata(
)

if self.config.use_lineage_v2:
lineage_extractor = RedshiftSqlLineageV2(
with RedshiftSqlLineageV2(
config=self.config,
report=self.report,
context=self.ctx,
database=database,
redundant_run_skip_handler=self.redundant_lineage_run_skip_handler,
)

yield from lineage_extractor.aggregator.register_schemas_from_stream(
self.process_schemas(connection, database)
)
) as lineage_extractor:
yield from lineage_extractor.aggregator.register_schemas_from_stream(
self.process_schemas(connection, database)
)

self.report.report_ingestion_stage_start(LINEAGE_EXTRACTION)
yield from self.extract_lineage_v2(
connection=connection,
database=database,
lineage_extractor=lineage_extractor,
)
self.report.report_ingestion_stage_start(LINEAGE_EXTRACTION)
yield from self.extract_lineage_v2(
connection=connection,
database=database,
lineage_extractor=lineage_extractor,
)

all_tables = self.get_all_tables()
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from datahub.configuration.connection_resolver import auto_connection_resolver
from datahub.configuration.oauth import OAuthConfiguration, OAuthIdentityProvider
from datahub.configuration.validate_field_rename import pydantic_renamed_field
from datahub.ingestion.api.closeable import Closeable
from datahub.ingestion.source.snowflake.constants import (
CLIENT_PREFETCH_THREADS,
CLIENT_SESSION_KEEP_ALIVE,
Expand Down Expand Up @@ -364,7 +365,7 @@ def get_connection(self) -> "SnowflakeConnection":
) from e


class SnowflakeConnection:
class SnowflakeConnection(Closeable):
_connection: NativeSnowflakeConnection

def __init__(self, connection: NativeSnowflakeConnection):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import dataclasses
import functools
import json
Expand All @@ -17,6 +18,7 @@
BaseTimeWindowConfig,
BucketDuration,
)
from datahub.ingestion.api.closeable import Closeable
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.report import Report
from datahub.ingestion.api.source import Source, SourceReport
Expand Down Expand Up @@ -121,7 +123,7 @@ class SnowflakeQueriesSourceReport(SourceReport):
queries_extractor: Optional[SnowflakeQueriesExtractorReport] = None


class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin):
class SnowflakeQueriesExtractor(SnowflakeStructuredReportMixin, Closeable):
def __init__(
self,
connection: SnowflakeConnection,
Expand All @@ -143,28 +145,33 @@ def __init__(

self._structured_report = structured_report

self.aggregator = SqlParsingAggregator(
platform=self.identifiers.platform,
platform_instance=self.identifiers.identifier_config.platform_instance,
env=self.identifiers.identifier_config.env,
schema_resolver=schema_resolver,
graph=graph,
eager_graph_load=False,
generate_lineage=self.config.include_lineage,
generate_queries=self.config.include_queries,
generate_usage_statistics=self.config.include_usage_statistics,
generate_query_usage_statistics=self.config.include_query_usage_statistics,
usage_config=BaseUsageConfig(
bucket_duration=self.config.window.bucket_duration,
start_time=self.config.window.start_time,
end_time=self.config.window.end_time,
user_email_pattern=self.config.user_email_pattern,
# TODO make the rest of the fields configurable
),
generate_operations=self.config.include_operations,
is_temp_table=self.is_temp_table,
is_allowed_table=self.is_allowed_table,
format_queries=False,
# The exit stack helps ensure that we close all the resources we open.
self._exit_stack = contextlib.ExitStack()

self.aggregator: SqlParsingAggregator = self._exit_stack.enter_context(
SqlParsingAggregator(
platform=self.identifiers.platform,
platform_instance=self.identifiers.identifier_config.platform_instance,
env=self.identifiers.identifier_config.env,
schema_resolver=schema_resolver,
graph=graph,
eager_graph_load=False,
generate_lineage=self.config.include_lineage,
generate_queries=self.config.include_queries,
generate_usage_statistics=self.config.include_usage_statistics,
generate_query_usage_statistics=self.config.include_query_usage_statistics,
usage_config=BaseUsageConfig(
bucket_duration=self.config.window.bucket_duration,
start_time=self.config.window.start_time,
end_time=self.config.window.end_time,
user_email_pattern=self.config.user_email_pattern,
# TODO make the rest of the fields configurable
),
generate_operations=self.config.include_operations,
is_temp_table=self.is_temp_table,
is_allowed_table=self.is_allowed_table,
format_queries=False,
)
)
self.report.sql_aggregator = self.aggregator.report

Expand Down Expand Up @@ -248,6 +255,10 @@ def get_workunits_internal(
self.aggregator.add(query)

yield from auto_workunit(self.aggregator.gen_metadata())
if not use_cached_audit_log:
queries.close()
shared_connection.close()
audit_log_file.unlink(missing_ok=True)

def fetch_copy_history(self) -> Iterable[KnownLineageMapping]:
# Derived from _populate_external_lineage_from_copy_history.
Expand Down Expand Up @@ -426,6 +437,9 @@ def _parse_audit_log_row(self, row: Dict[str, Any]) -> PreparsedQuery:
)
return entry

def close(self) -> None:
self._exit_stack.close()


class SnowflakeQueriesSource(Source):
def __init__(self, ctx: PipelineContext, config: SnowflakeQueriesSourceConfig):
Expand Down Expand Up @@ -468,6 +482,10 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
def get_report(self) -> SnowflakeQueriesSourceReport:
return self.report

def close(self) -> None:
self.connection.close()
self.queries_extractor.close()


# Make sure we don't try to generate too much info for a single query.
_MAX_TABLES_PER_QUERY = 20
Expand Down
Loading

0 comments on commit 179a671

Please sign in to comment.