From d13553f53ad9e7592256cd88e78eef0ca95832e4 Mon Sep 17 00:00:00 2001 From: Harshal Sheth Date: Tue, 24 Oct 2023 12:24:50 -0700 Subject: [PATCH] feat(sqlparser): extract CLL from `update`s (#9078) --- .../src/datahub/utilities/sqlglot_lineage.py | 68 +++++++++++-- .../test_snowflake_update_from_table.json | 56 +++++++++++ .../test_snowflake_update_hardcoded.json | 35 +++++++ .../unit/sql_parsing/test_sqlglot_lineage.py | 96 +++++++++++++++++++ 4 files changed, 246 insertions(+), 9 deletions(-) create mode 100644 metadata-ingestion/tests/unit/sql_parsing/goldens/test_snowflake_update_from_table.json create mode 100644 metadata-ingestion/tests/unit/sql_parsing/goldens/test_snowflake_update_hardcoded.json diff --git a/metadata-ingestion/src/datahub/utilities/sqlglot_lineage.py b/metadata-ingestion/src/datahub/utilities/sqlglot_lineage.py index 97121b368f5078..526d90b2a1bfab 100644 --- a/metadata-ingestion/src/datahub/utilities/sqlglot_lineage.py +++ b/metadata-ingestion/src/datahub/utilities/sqlglot_lineage.py @@ -745,6 +745,47 @@ def _extract_select_from_create( return statement +_UPDATE_ARGS_NOT_SUPPORTED_BY_SELECT: Set[str] = set( + sqlglot.exp.Update.arg_types.keys() +) - set(sqlglot.exp.Select.arg_types.keys()) + + +def _extract_select_from_update( + statement: sqlglot.exp.Update, +) -> sqlglot.exp.Select: + statement = statement.copy() + + # The "SET" expressions need to be converted. + # For the update command, it'll be a list of EQ expressions, but the select + # should contain aliased columns. + new_expressions = [] + for expr in statement.expressions: + if isinstance(expr, sqlglot.exp.EQ) and isinstance( + expr.left, sqlglot.exp.Column + ): + new_expressions.append( + sqlglot.exp.Alias( + this=expr.right, + alias=expr.left.this, + ) + ) + else: + # If we don't know how to convert it, just leave it as-is. If this causes issues, + # they'll get caught later. + new_expressions.append(expr) + + return sqlglot.exp.Select( + **{ + **{ + k: v + for k, v in statement.args.items() + if k not in _UPDATE_ARGS_NOT_SUPPORTED_BY_SELECT + }, + "expressions": new_expressions, + } + ) + + def _is_create_table_ddl(statement: sqlglot.exp.Expression) -> bool: return isinstance(statement, sqlglot.exp.Create) and isinstance( statement.this, sqlglot.exp.Schema @@ -767,6 +808,9 @@ def _try_extract_select( elif isinstance(statement, sqlglot.exp.Insert): # TODO Need to map column renames in the expressions part of the statement. statement = statement.expression + elif isinstance(statement, sqlglot.exp.Update): + # Assumption: the output table is already captured in the modified tables list. + statement = _extract_select_from_update(statement) elif isinstance(statement, sqlglot.exp.Create): # TODO May need to map column renames. # Assumption: the output table is already captured in the modified tables list. @@ -942,19 +986,25 @@ def _sqlglot_lineage_inner( ) # Simplify the input statement for column-level lineage generation. - select_statement = _try_extract_select(statement) + try: + select_statement = _try_extract_select(statement) + except Exception as e: + logger.debug(f"Failed to extract select from statement: {e}", exc_info=True) + debug_info.column_error = e + select_statement = None # Generate column-level lineage. column_lineage: Optional[List[_ColumnLineageInfo]] = None try: - column_lineage = _column_level_lineage( - select_statement, - dialect=dialect, - input_tables=table_name_schema_mapping, - output_table=downstream_table, - default_db=default_db, - default_schema=default_schema, - ) + if select_statement is not None: + column_lineage = _column_level_lineage( + select_statement, + dialect=dialect, + input_tables=table_name_schema_mapping, + output_table=downstream_table, + default_db=default_db, + default_schema=default_schema, + ) except UnsupportedStatementTypeError as e: # Inject details about the outer statement type too. e.args = (f"{e.args[0]} (outer statement type: {type(statement)})",) diff --git a/metadata-ingestion/tests/unit/sql_parsing/goldens/test_snowflake_update_from_table.json b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_snowflake_update_from_table.json new file mode 100644 index 00000000000000..e2baa34e7fe287 --- /dev/null +++ b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_snowflake_update_from_table.json @@ -0,0 +1,56 @@ +{ + "query_type": "UPDATE", + "in_tables": [ + "urn:li:dataset:(urn:li:dataPlatform:snowflake,my_db.my_schema.table1,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:snowflake,my_db.my_schema.table2,PROD)" + ], + "out_tables": [ + "urn:li:dataset:(urn:li:dataPlatform:snowflake,my_db.my_schema.my_table,PROD)" + ], + "column_lineage": [ + { + "downstream": { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,my_db.my_schema.my_table,PROD)", + "column": "col1", + "column_type": { + "type": { + "com.linkedin.pegasus2avro.schema.StringType": {} + } + }, + "native_column_type": "VARCHAR" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,my_db.my_schema.table1,PROD)", + "column": "col1" + }, + { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,my_db.my_schema.table1,PROD)", + "column": "col2" + } + ] + }, + { + "downstream": { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,my_db.my_schema.my_table,PROD)", + "column": "col2", + "column_type": { + "type": { + "com.linkedin.pegasus2avro.schema.StringType": {} + } + }, + "native_column_type": "VARCHAR" + }, + "upstreams": [ + { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,my_db.my_schema.table1,PROD)", + "column": "col1" + }, + { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,my_db.my_schema.table2,PROD)", + "column": "col2" + } + ] + } + ] +} \ No newline at end of file diff --git a/metadata-ingestion/tests/unit/sql_parsing/goldens/test_snowflake_update_hardcoded.json b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_snowflake_update_hardcoded.json new file mode 100644 index 00000000000000..b41ed61b37cdbd --- /dev/null +++ b/metadata-ingestion/tests/unit/sql_parsing/goldens/test_snowflake_update_hardcoded.json @@ -0,0 +1,35 @@ +{ + "query_type": "UPDATE", + "in_tables": [], + "out_tables": [ + "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)" + ], + "column_lineage": [ + { + "downstream": { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)", + "column": "orderkey", + "column_type": { + "type": { + "com.linkedin.pegasus2avro.schema.NumberType": {} + } + }, + "native_column_type": "INT" + }, + "upstreams": [] + }, + { + "downstream": { + "table": "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)", + "column": "totalprice", + "column_type": { + "type": { + "com.linkedin.pegasus2avro.schema.NumberType": {} + } + }, + "native_column_type": "INT" + }, + "upstreams": [] + } + ] +} \ No newline at end of file diff --git a/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_lineage.py b/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_lineage.py index 059add8db67e48..dfc5b486abd35f 100644 --- a/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_lineage.py +++ b/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_lineage.py @@ -3,6 +3,7 @@ import pytest from datahub.testing.check_sql_parser_result import assert_sql_result +from datahub.utilities.sqlglot_lineage import _UPDATE_ARGS_NOT_SUPPORTED_BY_SELECT RESOURCE_DIR = pathlib.Path(__file__).parent / "goldens" @@ -672,3 +673,98 @@ def test_teradata_default_normalization(): }, expected_file=RESOURCE_DIR / "test_teradata_default_normalization.json", ) + + +def test_snowflake_update_hardcoded(): + assert_sql_result( + """ +UPDATE snowflake_sample_data.tpch_sf1.orders +SET orderkey = 1, totalprice = 2 +WHERE orderkey = 3 +""", + dialect="snowflake", + schemas={ + "urn:li:dataset:(urn:li:dataPlatform:snowflake,snowflake_sample_data.tpch_sf1.orders,PROD)": { + "orderkey": "NUMBER(38,0)", + "totalprice": "NUMBER(12,2)", + }, + }, + expected_file=RESOURCE_DIR / "test_snowflake_update_hardcoded.json", + ) + + +def test_update_from_select(): + assert _UPDATE_ARGS_NOT_SUPPORTED_BY_SELECT == {"returning", "this"} + + +def test_snowflake_update_from_table(): + # Can create these tables with the following SQL: + """ + -- Create or replace my_table + CREATE OR REPLACE TABLE my_table ( + id INT IDENTITY PRIMARY KEY, + col1 VARCHAR(50), + col2 VARCHAR(50) + ); + + -- Create or replace table1 + CREATE OR REPLACE TABLE table1 ( + id INT IDENTITY PRIMARY KEY, + col1 VARCHAR(50), + col2 VARCHAR(50) + ); + + -- Create or replace table2 + CREATE OR REPLACE TABLE table2 ( + id INT IDENTITY PRIMARY KEY, + col2 VARCHAR(50) + ); + + -- Insert data into my_table + INSERT INTO my_table (col1, col2) + VALUES ('foo', 'bar'), + ('baz', 'qux'); + + -- Insert data into table1 + INSERT INTO table1 (col1, col2) + VALUES ('foo', 'bar'), + ('baz', 'qux'); + + -- Insert data into table2 + INSERT INTO table2 (col2) + VALUES ('bar'), + ('qux'); + """ + + assert_sql_result( + """ +UPDATE my_table +SET + col1 = t1.col1 || t1.col2, + col2 = t1.col1 || t2.col2 +FROM table1 t1 +JOIN table2 t2 ON t1.id = t2.id +WHERE my_table.id = t1.id; +""", + dialect="snowflake", + default_db="my_db", + default_schema="my_schema", + schemas={ + "urn:li:dataset:(urn:li:dataPlatform:snowflake,my_db.my_schema.my_table,PROD)": { + "id": "NUMBER(38,0)", + "col1": "VARCHAR(16777216)", + "col2": "VARCHAR(16777216)", + }, + "urn:li:dataset:(urn:li:dataPlatform:snowflake,my_db.my_schema.table1,PROD)": { + "id": "NUMBER(38,0)", + "col1": "VARCHAR(16777216)", + "col2": "VARCHAR(16777216)", + }, + "urn:li:dataset:(urn:li:dataPlatform:snowflake,my_db.my_schema.table2,PROD)": { + "id": "NUMBER(38,0)", + "col1": "VARCHAR(16777216)", + "col2": "VARCHAR(16777216)", + }, + }, + expected_file=RESOURCE_DIR / "test_snowflake_update_from_table.json", + )