Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ingest/snowflake): support lineage via rename and swap using que… #11600

Merged
merged 9 commits into from
Oct 23, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
)
from datahub.metadata.urns import DatasetUrn
from datahub.sql_parsing.schema_resolver import SchemaResolver
from datahub.sql_parsing.sql_parsing_aggregator import TableRename
from datahub.sql_parsing.sqlglot_utils import get_dialect, parse_statement
from datahub.utilities import memory_footprint
from datahub.utilities.dedup_list import deduplicate_list
Expand Down Expand Up @@ -504,21 +505,21 @@ def _populate_lineage_map(
self.report_status(f"extract-{lineage_type.name}", False)

def _update_lineage_map_for_table_renames(
self, table_renames: Dict[str, str]
self, table_renames: Dict[str, TableRename]
) -> None:
if not table_renames:
return

logger.info(f"Updating lineage map for {len(table_renames)} table renames")
for new_table_urn, prev_table_urn in table_renames.items():
for entry in table_renames.values():
# This table was renamed from some other name, copy in the lineage
# for the previous name as well.
prev_table_lineage = self._lineage_map.get(prev_table_urn)
prev_table_lineage = self._lineage_map.get(entry.original_urn)
if prev_table_lineage:
logger.debug(
f"including lineage for {prev_table_urn} in {new_table_urn} due to table rename"
f"including lineage for {entry.original_urn} in {entry.new_urn} due to table rename"
)
self._lineage_map[new_table_urn].merge_lineage(
self._lineage_map[entry.new_urn].merge_lineage(
upstreams=prev_table_lineage.upstreams,
cll=prev_table_lineage.cll,
)
Expand Down Expand Up @@ -672,7 +673,7 @@ def populate_lineage(
for db, schemas in all_tables.items()
}

table_renames: Dict[str, str] = {}
table_renames: Dict[str, TableRename] = {}
if self.config.include_table_rename_lineage:
table_renames, all_tables_set = self._process_table_renames(
database=database,
Expand Down Expand Up @@ -851,11 +852,11 @@ def _process_table_renames(
database: str,
connection: redshift_connector.Connection,
all_tables: Dict[str, Dict[str, Set[str]]],
) -> Tuple[Dict[str, str], Dict[str, Dict[str, Set[str]]]]:
) -> Tuple[Dict[str, TableRename], Dict[str, Dict[str, Set[str]]]]:
logger.info(f"Processing table renames for db {database}")

# new urn -> prev urn
table_renames: Dict[str, str] = {}
table_renames: Dict[str, TableRename] = {}

query = self.queries.alter_table_rename_query(
db_name=database,
Expand Down Expand Up @@ -893,7 +894,7 @@ def _process_table_renames(
env=self.config.env,
)

table_renames[new_urn] = prev_urn
table_renames[new_urn] = TableRename(prev_urn, new_urn, query_text)

# We want to generate lineage for the previous name too.
all_tables[database][schema].add(prev_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,8 @@ def build(
lambda: collections.defaultdict(set)
),
)
for new_urn, original_urn in table_renames.items():
self.aggregator.add_table_rename(
original_urn=original_urn, new_urn=new_urn
)
for entry in table_renames.values():
self.aggregator.add_table_rename(entry)

if self.config.table_lineage_mode in {
LineageMode.SQL_BASED,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@
PreparsedQuery,
SqlAggregatorReport,
SqlParsingAggregator,
TableRename,
TableSwap,
)
from datahub.sql_parsing.sql_parsing_common import QueryType
from datahub.sql_parsing.sqlglot_lineage import (
Expand Down Expand Up @@ -116,6 +118,8 @@ class SnowflakeQueriesExtractorReport(Report):
audit_log_load_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer)
sql_aggregator: Optional[SqlAggregatorReport] = None

num_ddl_queries_dropped: int = 0


@dataclass
class SnowflakeQueriesSourceReport(SourceReport):
Expand Down Expand Up @@ -225,7 +229,9 @@ def get_workunits_internal(
audit_log_file = self.local_temp_path / "audit_log.sqlite"
use_cached_audit_log = audit_log_file.exists()

queries: FileBackedList[Union[KnownLineageMapping, PreparsedQuery]]
queries: FileBackedList[
Union[KnownLineageMapping, PreparsedQuery, TableRename, TableSwap]
]
if use_cached_audit_log:
logger.info("Using cached audit log")
shared_connection = ConnectionWrapper(audit_log_file)
Expand All @@ -235,7 +241,7 @@ def get_workunits_internal(

shared_connection = ConnectionWrapper(audit_log_file)
queries = FileBackedList(shared_connection)
entry: Union[KnownLineageMapping, PreparsedQuery]
entry: Union[KnownLineageMapping, PreparsedQuery, TableRename, TableSwap]

with self.report.copy_history_fetch_timer:
for entry in self.fetch_copy_history():
Expand Down Expand Up @@ -296,7 +302,7 @@ def fetch_copy_history(self) -> Iterable[KnownLineageMapping]:

def fetch_query_log(
self,
) -> Iterable[PreparsedQuery]:
) -> Iterable[Union[PreparsedQuery, TableRename, TableSwap]]:
query_log_query = _build_enriched_query_log_query(
start_time=self.config.window.start_time,
end_time=self.config.window.end_time,
Expand Down Expand Up @@ -324,12 +330,16 @@ def fetch_query_log(
exc=e,
)
else:
yield entry
if entry:
yield entry

def _parse_audit_log_row(self, row: Dict[str, Any]) -> PreparsedQuery:
def _parse_audit_log_row(
self, row: Dict[str, Any]
) -> Optional[Union[TableRename, TableSwap, PreparsedQuery]]:
json_fields = {
"DIRECT_OBJECTS_ACCESSED",
"OBJECTS_MODIFIED",
"OBJECT_MODIFIED_BY_DDL",
}

res = {}
Expand All @@ -341,6 +351,17 @@ def _parse_audit_log_row(self, row: Dict[str, Any]) -> PreparsedQuery:

direct_objects_accessed = res["direct_objects_accessed"]
objects_modified = res["objects_modified"]
object_modified_by_ddl = res["object_modified_by_ddl"]

if object_modified_by_ddl and not objects_modified:
ddl_entry: Optional[Union[TableRename, TableSwap]] = None
with self.structured_reporter.report_exc(
"Error fetching ddl lineage from Snowflake"
):
ddl_entry = self.parse_ddl_query(
res["query_text"], object_modified_by_ddl
)
return ddl_entry

upstreams = []
column_usage = {}
Expand Down Expand Up @@ -437,6 +458,45 @@ def _parse_audit_log_row(self, row: Dict[str, Any]) -> PreparsedQuery:
)
return entry

def parse_ddl_query(
self, query: str, object_modified_by_ddl: dict
) -> Optional[Union[TableRename, TableSwap]]:
if object_modified_by_ddl[
"operationType"
] == "ALTER" and object_modified_by_ddl["properties"].get("swapTargetName"):
urn1 = self.identifiers.gen_dataset_urn(
self.identifiers.get_dataset_identifier_from_qualified_name(
object_modified_by_ddl["objectName"]
)
)

urn2 = self.identifiers.gen_dataset_urn(
self.identifiers.get_dataset_identifier_from_qualified_name(
object_modified_by_ddl["properties"]["swapTargetName"]["value"]
)
)

return TableSwap(urn1, urn2, query)
elif object_modified_by_ddl[
"operationType"
] == "RENAME_TABLE" and object_modified_by_ddl["properties"].get("objectName"):
original_un = self.identifiers.gen_dataset_urn(
self.identifiers.get_dataset_identifier_from_qualified_name(
object_modified_by_ddl["objectName"]
)
)

new_urn = self.identifiers.gen_dataset_urn(
self.identifiers.get_dataset_identifier_from_qualified_name(
object_modified_by_ddl["properties"]["objectName"]["value"]
)
)

return TableRename(original_un, new_urn, query)
else:
self.report.num_ddl_queries_dropped += 1
return None

def close(self) -> None:
self._exit_stack.close()

Expand Down Expand Up @@ -542,6 +602,7 @@ def _build_enriched_query_log_query(
user_name,
direct_objects_accessed,
objects_modified,
object_modified_by_ddl
FROM
snowflake.account_usage.access_history
WHERE
Expand All @@ -563,8 +624,9 @@ def _build_enriched_query_log_query(
) as direct_objects_accessed,
-- TODO: Drop the columns.baseSources subfield.
FILTER(objects_modified, o -> o:objectDomain IN {SnowflakeQuery.ACCESS_HISTORY_TABLE_VIEW_DOMAINS_FILTER}) as objects_modified,
case when object_modified_by_ddl:objectDomain IN {SnowflakeQuery.ACCESS_HISTORY_TABLE_VIEW_DOMAINS_FILTER} then object_modified_by_ddl else null end as object_modified_by_ddl
FROM raw_access_history
WHERE ( array_size(direct_objects_accessed) > 0 or array_size(objects_modified) > 0 )
WHERE ( array_size(direct_objects_accessed) > 0 or array_size(objects_modified) > 0 or object_modified_by_ddl is not null )
)
, query_access_history AS (
SELECT
Expand All @@ -586,6 +648,7 @@ def _build_enriched_query_log_query(
q.role_name AS "ROLE_NAME",
a.direct_objects_accessed,
a.objects_modified,
a.object_modified_by_ddl
FROM deduplicated_queries q
JOIN filtered_access_history a USING (query_id)
)
Expand Down
Loading
Loading