diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py index 2b2dcf860cdb07..12e5fb72b00de8 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py @@ -138,12 +138,20 @@ class SnowflakeIdentifierConfig( description="Whether to convert dataset urns to lowercase.", ) - -class SnowflakeUsageConfig(BaseUsageConfig): email_domain: Optional[str] = pydantic.Field( default=None, description="Email domain of your organization so users can be displayed on UI appropriately.", ) + + email_as_user_identifier: bool = Field( + default=True, + description="Format user urns as an email, if the snowflake user's email is set. If `email_domain` is " + "provided, generates email addresses for snowflake users with unset emails, based on their " + "username.", + ) + + +class SnowflakeUsageConfig(BaseUsageConfig): apply_view_usage_to_tables: bool = pydantic.Field( default=False, description="Whether to apply view's usage to its base tables. If set to True, usage is applied to base tables only.", @@ -267,13 +275,6 @@ class SnowflakeV2Config( " Map of share name -> details of share.", ) - email_as_user_identifier: bool = Field( - default=True, - description="Format user urns as an email, if the snowflake user's email is set. If `email_domain` is " - "provided, generates email addresses for snowflake users with unset emails, based on their " - "username.", - ) - include_assertion_results: bool = Field( default=False, description="Whether to ingest assertion run results for assertions created using Datahub" diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py index 174aad0bddd4a8..36825dc33fe7dc 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py @@ -66,6 +66,11 @@ logger = logging.getLogger(__name__) +# Define a type alias +UserName = str +UserEmail = str +UsersMapping = Dict[UserName, UserEmail] + class SnowflakeQueriesExtractorConfig(ConfigModel): # TODO: Support stateful ingestion for the time windows. @@ -114,11 +119,13 @@ class SnowflakeQueriesSourceConfig( class SnowflakeQueriesExtractorReport(Report): copy_history_fetch_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer) query_log_fetch_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer) + users_fetch_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer) audit_log_load_timer: PerfTimer = dataclasses.field(default_factory=PerfTimer) sql_aggregator: Optional[SqlAggregatorReport] = None num_ddl_queries_dropped: int = 0 + num_users: int = 0 @dataclass @@ -225,6 +232,9 @@ def is_allowed_table(self, name: str) -> bool: def get_workunits_internal( self, ) -> Iterable[MetadataWorkUnit]: + with self.report.users_fetch_timer: + users = self.fetch_users() + # TODO: Add some logic to check if the cached audit log is stale or not. audit_log_file = self.local_temp_path / "audit_log.sqlite" use_cached_audit_log = audit_log_file.exists() @@ -248,7 +258,7 @@ def get_workunits_internal( queries.append(entry) with self.report.query_log_fetch_timer: - for entry in self.fetch_query_log(): + for entry in self.fetch_query_log(users): queries.append(entry) with self.report.audit_log_load_timer: @@ -263,6 +273,25 @@ def get_workunits_internal( shared_connection.close() audit_log_file.unlink(missing_ok=True) + def fetch_users(self) -> UsersMapping: + users: UsersMapping = dict() + with self.structured_reporter.report_exc("Error fetching users from Snowflake"): + logger.info("Fetching users from Snowflake") + query = SnowflakeQuery.get_all_users() + resp = self.connection.query(query) + + for row in resp: + try: + users[row["NAME"]] = row["EMAIL"] + self.report.num_users += 1 + except Exception as e: + self.structured_reporter.warning( + "Error parsing user row", + context=f"{row}", + exc=e, + ) + return users + def fetch_copy_history(self) -> Iterable[KnownLineageMapping]: # Derived from _populate_external_lineage_from_copy_history. @@ -298,7 +327,7 @@ def fetch_copy_history(self) -> Iterable[KnownLineageMapping]: yield result def fetch_query_log( - self, + self, users: UsersMapping ) -> Iterable[Union[PreparsedQuery, TableRename, TableSwap]]: query_log_query = _build_enriched_query_log_query( start_time=self.config.window.start_time, @@ -319,7 +348,7 @@ def fetch_query_log( assert isinstance(row, dict) try: - entry = self._parse_audit_log_row(row) + entry = self._parse_audit_log_row(row, users) except Exception as e: self.structured_reporter.warning( "Error parsing query log row", @@ -331,7 +360,7 @@ def fetch_query_log( yield entry def _parse_audit_log_row( - self, row: Dict[str, Any] + self, row: Dict[str, Any], users: UsersMapping ) -> Optional[Union[TableRename, TableSwap, PreparsedQuery]]: json_fields = { "DIRECT_OBJECTS_ACCESSED", @@ -430,9 +459,11 @@ def _parse_audit_log_row( ) ) - # TODO: Fetch email addresses from Snowflake to map user -> email - # TODO: Support email_domain fallback for generating user urns. - user = CorpUserUrn(self.identifiers.snowflake_identifier(res["user_name"])) + user = CorpUserUrn( + self.identifiers.get_user_identifier( + res["user_name"], users.get(res["user_name"]) + ) + ) timestamp: datetime = res["query_start_time"] timestamp = timestamp.astimezone(timezone.utc) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_query.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_query.py index a94b39476b2c22..40bcfb514efd23 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_query.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_query.py @@ -947,4 +947,8 @@ def dmf_assertion_results(start_time_millis: int, end_time_millis: int) -> str: AND METRIC_NAME ilike '{pattern}' escape '{escape_pattern}' ORDER BY MEASUREMENT_TIME ASC; -""" + """ + + @staticmethod + def get_all_users() -> str: + return """SELECT name as "NAME", email as "EMAIL" FROM SNOWFLAKE.ACCOUNT_USAGE.USERS""" diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py index aff15386c50833..4bdf559f293b51 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py @@ -342,10 +342,9 @@ def _map_user_counts( filtered_user_counts.append( DatasetUserUsageCounts( user=make_user_urn( - self.get_user_identifier( + self.identifiers.get_user_identifier( user_count["user_name"], user_email, - self.config.email_as_user_identifier, ) ), count=user_count["total"], @@ -453,9 +452,7 @@ def _get_operation_aspect_work_unit( reported_time: int = int(time.time() * 1000) last_updated_timestamp: int = int(start_time.timestamp() * 1000) user_urn = make_user_urn( - self.get_user_identifier( - user_name, user_email, self.config.email_as_user_identifier - ) + self.identifiers.get_user_identifier(user_name, user_email) ) # NOTE: In earlier `snowflake-usage` connector this was base_objects_accessed, which is incorrect diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py index 8e0c97aa135e84..885bee1ccdb908 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_utils.py @@ -300,6 +300,28 @@ def get_quoted_identifier_for_schema(db_name, schema_name): def get_quoted_identifier_for_table(db_name, schema_name, table_name): return f'"{db_name}"."{schema_name}"."{table_name}"' + # Note - decide how to construct user urns. + # Historically urns were created using part before @ from user's email. + # Users without email were skipped from both user entries as well as aggregates. + # However email is not mandatory field in snowflake user, user_name is always present. + def get_user_identifier( + self, + user_name: str, + user_email: Optional[str], + ) -> str: + if user_email: + return self.snowflake_identifier( + user_email + if self.identifier_config.email_as_user_identifier is True + else user_email.split("@")[0] + ) + return self.snowflake_identifier( + f"{user_name}@{self.identifier_config.email_domain}" + if self.identifier_config.email_as_user_identifier is True + and self.identifier_config.email_domain is not None + else user_name + ) + class SnowflakeCommonMixin(SnowflakeStructuredReportMixin): platform = "snowflake" @@ -315,24 +337,6 @@ def structured_reporter(self) -> SourceReport: def identifiers(self) -> SnowflakeIdentifierBuilder: return SnowflakeIdentifierBuilder(self.config, self.report) - # Note - decide how to construct user urns. - # Historically urns were created using part before @ from user's email. - # Users without email were skipped from both user entries as well as aggregates. - # However email is not mandatory field in snowflake user, user_name is always present. - def get_user_identifier( - self, - user_name: str, - user_email: Optional[str], - email_as_user_identifier: bool, - ) -> str: - if user_email: - return self.identifiers.snowflake_identifier( - user_email - if email_as_user_identifier is True - else user_email.split("@")[0] - ) - return self.identifiers.snowflake_identifier(user_name) - # TODO: Revisit this after stateful ingestion can commit checkpoint # for failures that do not affect the checkpoint # TODO: Add additional parameters to match the signature of the .warning and .failure methods diff --git a/metadata-ingestion/tests/integration/snowflake/test_snowflake_queries.py b/metadata-ingestion/tests/integration/snowflake/test_snowflake_queries.py index 82f5691bcee3de..ae0f23d93215d4 100644 --- a/metadata-ingestion/tests/integration/snowflake/test_snowflake_queries.py +++ b/metadata-ingestion/tests/integration/snowflake/test_snowflake_queries.py @@ -22,3 +22,58 @@ def test_source_close_cleans_tmp(snowflake_connect, tmp_path): # This closes QueriesExtractor which in turn closes SqlParsingAggregator source.close() assert len(os.listdir(tmp_path)) == 0 + + +@patch("snowflake.connector.connect") +def test_user_identifiers_email_as_identifier(snowflake_connect, tmp_path): + source = SnowflakeQueriesSource.create( + { + "connection": { + "account_id": "ABC12345.ap-south-1.aws", + "username": "TST_USR", + "password": "TST_PWD", + }, + "email_as_user_identifier": True, + "email_domain": "example.com", + }, + PipelineContext("run-id"), + ) + assert ( + source.identifiers.get_user_identifier("username", "username@example.com") + == "username@example.com" + ) + assert ( + source.identifiers.get_user_identifier("username", None) + == "username@example.com" + ) + + # We'd do best effort to use email as identifier, but would keep username as is, + # if email can't be formed. + source.identifiers.identifier_config.email_domain = None + + assert ( + source.identifiers.get_user_identifier("username", "username@example.com") + == "username@example.com" + ) + + assert source.identifiers.get_user_identifier("username", None) == "username" + + +@patch("snowflake.connector.connect") +def test_user_identifiers_username_as_identifier(snowflake_connect, tmp_path): + source = SnowflakeQueriesSource.create( + { + "connection": { + "account_id": "ABC12345.ap-south-1.aws", + "username": "TST_USR", + "password": "TST_PWD", + }, + "email_as_user_identifier": False, + }, + PipelineContext("run-id"), + ) + assert ( + source.identifiers.get_user_identifier("username", "username@example.com") + == "username" + ) + assert source.identifiers.get_user_identifier("username", None) == "username"