Skip to content

Commit

Permalink
feat(ingest): add query formatting to sql aggregator (datahub-project…
Browse files Browse the repository at this point in the history
  • Loading branch information
hsheth2 authored Mar 11, 2024
1 parent 2fe3583 commit 7e2076e
Show file tree
Hide file tree
Showing 14 changed files with 116 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@
infer_output_schema,
sqlglot_lineage,
)
from datahub.sql_parsing.sqlglot_utils import generate_hash, get_query_fingerprint
from datahub.sql_parsing.sqlglot_utils import (
generate_hash,
get_query_fingerprint,
try_format_query,
)
from datahub.utilities.cooperative_timeout import CooperativeTimeoutError
from datahub.utilities.file_backed_collections import (
ConnectionWrapper,
Expand Down Expand Up @@ -180,6 +184,7 @@ def __init__(
generate_operations: bool = False,
usage_config: Optional[BaseUsageConfig] = None,
is_temp_table: Optional[Callable[[UrnStr], bool]] = None,
format_queries: bool = True,
query_log: QueryLogSetting = QueryLogSetting.DISABLED,
) -> None:
self.platform = DataPlatformUrn(platform)
Expand All @@ -202,6 +207,7 @@ def __init__(
# can be used by BQ where we have a "temp_table_dataset_prefix"
self.is_temp_table = is_temp_table

self.format_queries = format_queries
self.query_log = query_log

# Set up the schema resolver.
Expand Down Expand Up @@ -328,6 +334,11 @@ def _initialize_schema_resolver_from_graph(self, graph: DataHubGraph) -> None:
env=self.env,
)

def _maybe_format_query(self, query: str) -> str:
if self.format_queries:
return try_format_query(query, self.platform.platform_name)
return query

def add_known_query_lineage(
self, known_query_lineage: KnownQueryLineageInfo, merge_lineage: bool = False
) -> None:
Expand All @@ -342,21 +353,23 @@ def add_known_query_lineage(
Args:
known_query_lineage: The known query lineage information.
merge_lineage: Whether to merge the lineage with any existing lineage
for the query ID.
"""

self.report.num_known_query_lineage += 1

# Generate a fingerprint for the query.
query_fingerprint = get_query_fingerprint(
known_query_lineage.query_text, self.platform.platform_name
known_query_lineage.query_text, platform=self.platform.platform_name
)
# TODO format the query text?
formatted_query = self._maybe_format_query(known_query_lineage.query_text)

# Register the query.
self._add_to_query_map(
QueryMetadata(
query_id=query_fingerprint,
formatted_query_string=known_query_lineage.query_text,
formatted_query_string=formatted_query,
session_id=known_query_lineage.session_id or _MISSING_SESSION_ID,
query_type=known_query_lineage.query_type,
lineage_type=models.DatasetLineageTypeClass.TRANSFORMED,
Expand Down Expand Up @@ -499,6 +512,9 @@ def add_observed_query(
elif parsed.debug_info.column_error:
self.report.num_observed_queries_column_failed += 1

# Format the query.
formatted_query = self._maybe_format_query(query)

# Register the query's usage.
if not self._usage_aggregator:
pass # usage is not enabled
Expand All @@ -518,7 +534,7 @@ def add_observed_query(
self._usage_aggregator.aggregate_event(
resource=upstream_urn,
start_time=query_timestamp,
query=query,
query=formatted_query,
user=user.urn() if user else None,
fields=sorted(upstream_fields.get(upstream_urn, [])),
count=usage_multiplier,
Expand All @@ -540,7 +556,7 @@ def add_observed_query(
self._add_to_query_map(
QueryMetadata(
query_id=query_fingerprint,
formatted_query_string=query, # TODO replace with formatted query string
formatted_query_string=formatted_query,
session_id=session_id,
query_type=parsed.query_type,
lineage_type=models.DatasetLineageTypeClass.TRANSFORMED,
Expand Down Expand Up @@ -655,12 +671,15 @@ def _process_view_definition(
self.report.num_views_column_failed += 1

query_fingerprint = self._view_query_id(view_urn)
formatted_view_definition = self._maybe_format_query(
view_definition.view_definition
)

# Register the query.
self._add_to_query_map(
QueryMetadata(
query_id=query_fingerprint,
formatted_query_string=view_definition.view_definition,
formatted_query_string=formatted_view_definition,
session_id=_MISSING_SESSION_ID,
query_type=QueryType.CREATE_VIEW,
lineage_type=models.DatasetLineageTypeClass.VIEW,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -928,7 +928,7 @@ def _sqlglot_lineage_inner(
original_statement, dialect=dialect
)
query_fingerprint, debug_info.generalized_statement = get_query_fingerprint_debug(
original_statement, dialect=dialect
original_statement, platform=dialect
)
return SqlParsingResult(
query_type=query_type,
Expand Down
50 changes: 42 additions & 8 deletions metadata-ingestion/src/datahub/sql_parsing/sqlglot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@ def parse_statement(
return statement


def _expression_to_string(
expression: sqlglot.exp.ExpOrStr, platform: DialectOrStr
) -> str:
if isinstance(expression, str):
return expression
return expression.sql(dialect=get_dialect(platform))


def generalize_query(expression: sqlglot.exp.ExpOrStr, dialect: DialectOrStr) -> str:
"""
Generalize/normalize a SQL query.
Expand Down Expand Up @@ -121,24 +129,28 @@ def generate_hash(text: str) -> str:


def get_query_fingerprint_debug(
expression: sqlglot.exp.ExpOrStr, dialect: DialectOrStr
) -> Tuple[str, str]:
expression: sqlglot.exp.ExpOrStr, platform: DialectOrStr
) -> Tuple[str, Optional[str]]:
try:
dialect = get_dialect(dialect)
dialect = get_dialect(platform)
expression_sql = generalize_query(expression, dialect=dialect)
except (ValueError, sqlglot.errors.SqlglotError) as e:
if not isinstance(expression, str):
raise

logger.debug("Failed to generalize query for fingerprinting: %s", e)
expression_sql = expression
expression_sql = None

fingerprint = generate_hash(expression_sql)
fingerprint = generate_hash(
expression_sql
if expression_sql is not None
else _expression_to_string(expression, platform=platform)
)
return fingerprint, expression_sql


def get_query_fingerprint(
expression: sqlglot.exp.ExpOrStr, dialect: DialectOrStr
expression: sqlglot.exp.ExpOrStr, platform: DialectOrStr
) -> str:
"""Get a fingerprint for a SQL query.
Expand All @@ -154,13 +166,35 @@ def get_query_fingerprint(
Args:
expression: The SQL query to fingerprint.
dialect: The SQL dialect to use.
platform: The SQL dialect to use.
Returns:
The fingerprint for the SQL query.
"""

return get_query_fingerprint_debug(expression, dialect)[0]
return get_query_fingerprint_debug(expression, platform)[0]


def try_format_query(expression: sqlglot.exp.ExpOrStr, platform: DialectOrStr) -> str:
"""Format a SQL query.
If the query cannot be formatted, the original query is returned unchanged.
Args:
expression: The SQL query to format.
platform: The SQL dialect to use.
Returns:
The formatted SQL query.
"""

try:
dialect = get_dialect(platform)
expression = parse_statement(expression, dialect=dialect)
return expression.sql(dialect=dialect, pretty=True)
except Exception as e:
logger.debug("Failed to format query: %s", e)
return _expression_to_string(expression, platform=platform)


def detach_ctes(
Expand Down
3 changes: 3 additions & 0 deletions metadata-ingestion/src/datahub/utilities/sql_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

logger = logging.getLogger(__name__)

# TODO: The sql query formatting functionality is duplicated by the try_format_query method,
# which is powered by sqlglot instead of sqlparse.


def format_sql_query(query: str, **options: Any) -> str:
try:
Expand Down
Loading

0 comments on commit 7e2076e

Please sign in to comment.