diff --git a/.gitignore b/.gitignore index ac41572..f1d1791 100644 --- a/.gitignore +++ b/.gitignore @@ -137,3 +137,4 @@ dmypy.json # Pyre type checker .pyre/ +.aider* diff --git a/setup.py b/setup.py index 1bcf0bb..2fc5ff1 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ ] EXTRAS = { "dev": [ - "black>=24.10.0", + "black>=24.8.0", "mypy>=1.14.1", "pytest>=8.3.4", "python-dotenv>=1.0.1", diff --git a/sqlalchemy_kusto/dbapi.py b/sqlalchemy_kusto/dbapi.py index 1b697a4..e8f2382 100644 --- a/sqlalchemy_kusto/dbapi.py +++ b/sqlalchemy_kusto/dbapi.py @@ -1,12 +1,12 @@ from collections import namedtuple from typing import Any -from azure.identity import WorkloadIdentityCredential from azure.kusto.data import ( ClientRequestProperties, KustoClient, KustoConnectionStringBuilder, ) +from azure.identity import DefaultAzureCredential from azure.kusto.data._models import KustoResultColumn from azure.kusto.data.exceptions import KustoAuthenticationError, KustoServiceError @@ -39,12 +39,12 @@ def connect( cluster: str, database: str, msi: bool = False, - workload_identity: bool = False, user_msi: str | None = None, + workload_identity: bool = False, azure_ad_client_id: str | None = None, azure_ad_client_secret: str | None = None, azure_ad_tenant_id: str | None = None, -): +): # pylint: disable=too-many-positional-arguments """Return a connection to the database.""" return Connection( cluster, @@ -71,7 +71,7 @@ def __init__( azure_ad_client_id: str | None = None, azure_ad_client_secret: str | None = None, azure_ad_tenant_id: str | None = None, - ): + ): # pylint: disable=too-many-positional-arguments self.closed = False self.cursors: list[Cursor] = [] kcsb = None @@ -85,9 +85,8 @@ def __init__( authority_id=azure_ad_tenant_id, ) elif workload_identity: - # Workload Identity kcsb = KustoConnectionStringBuilder.with_azure_token_credential( - cluster, WorkloadIdentityCredential() + cluster, DefaultAzureCredential() ) elif msi: # Managed Service Identity (MSI) diff --git a/sqlalchemy_kusto/dialect_kql.py b/sqlalchemy_kusto/dialect_kql.py index efbec5e..7b5f4b6 100644 --- a/sqlalchemy_kusto/dialect_kql.py +++ b/sqlalchemy_kusto/dialect_kql.py @@ -1,7 +1,7 @@ import logging import re -from sqlalchemy import Column, exc +from sqlalchemy import Column, exc, sql from sqlalchemy.sql import compiler, operators, selectable from sqlalchemy.sql.compiler import OPERATORS @@ -12,11 +12,22 @@ aggregates_sql_to_kql = { "count(*)": "count()", + "count": "count", + "count(distinct": "dcount", + "count(distinct(": "dcount", + "count_distinct": "dcount", + "sum": "sum", + "avg": "avg", + "min": "min", + "max": "max", } +AGGREGATE_PATTERN = ( + r"(\w+)\s*\(\s*(DISTINCT|distinct\s*)?\(?\s*(\*|\[?\"?\'?\w+\"?\]?)\s*\)?\s*\)" +) class UniversalSet: - def __contains__(self, item) -> bool: + def __contains__(self, item): return True @@ -24,17 +35,17 @@ class KustoKqlIdentifierPreparer(compiler.IdentifierPreparer): # We want to quote all table and column names to prevent unconventional names usage reserved_words = UniversalSet() - def __init__(self, dialect, **kw) -> None: + def __init__(self, dialect, **kw): super().__init__(dialect, initial_quote='["', final_quote='"]', **kw) class KustoKqlCompiler(compiler.SQLCompiler): OPERATORS[operators.and_] = " and " - delete_extra_from_clause = None update_from_clause = None visit_empty_set_expr = None visit_sequence = None + sort_with_clause_parts = 2 def visit_select( self, @@ -47,14 +58,12 @@ def visit_select( lateral=False, from_linter=None, **kwargs, - ) -> str: + ): # pylint: disable=too-many-positional-arguments logger.debug("Incoming query: %s", select_stmt) - if len(select_stmt.get_final_froms()) != 1: raise NotSupportedError( 'Only single "select from" query is supported in kql compiler' ) - compiled_query_lines = [] from_object = select_stmt.get_final_froms()[0] @@ -77,62 +86,271 @@ def visit_select( self._convert_schema_in_statement(from_object.text) ) + projections_parts_dict = self._get_projection_or_summarize(select_stmt) + if "extend" in projections_parts_dict: + compiled_query_lines.append(projections_parts_dict.pop("extend")) + if select_stmt._whereclause is not None: where_clause = select_stmt._whereclause._compiler_dispatch(self, **kwargs) if where_clause: - compiled_query_lines.append(f"| where {where_clause}") + converted_where_clause = self._sql_to_kql_where(where_clause) + compiled_query_lines.append(f"| where {converted_where_clause}") - projections = self._get_projection_or_summarize(select_stmt) - if projections: - compiled_query_lines.append(projections) + for statement_part in projections_parts_dict.values(): + if statement_part: + compiled_query_lines.append(statement_part) - if select_stmt._limit_clause is not None: + if select_stmt._limit_clause is not None: # pylint: disable=protected-access kwargs["literal_execute"] = True compiled_query_lines.append( f"| take {self.process(select_stmt._limit_clause, **kwargs)}" - ) - + ) # pylint: disable=protected-access compiled_query_lines = list(filter(None, compiled_query_lines)) - compiled_query = "\n".join(compiled_query_lines) - logger.debug("Compiled query: %s", compiled_query) + logger.warning("Compiled query: %s", compiled_query) return compiled_query def limit_clause(self, select, **kw): return "" - def _get_projection_or_summarize(self, select: selectable.Select) -> str: + def _get_projection_or_summarize(self, select: selectable.Select) -> dict[str, str]: """Builds the ending part of the query either project or summarize.""" columns = select.inner_columns + group_by_cols = select._group_by_clauses # pylint: disable=protected-access + order_by_cols = select._order_by_clauses # pylint: disable=protected-access + summarize_statement = "" + extend_statement = "" + project_statement = "" + has_aggregates = False + # The following is the logic + # With Columns : + # - Do we have a group by clause ? --Yes---> Do we have aggregate columns ? --Yes--> Summarize new column(s) + # | | with by clause + # N N --> Add to projection + # | + # | + # - Do the columns have aliases ? --Yes---> Extend with aliases + # | + # N---> Add to projection if columns is not None: - column_labels = [] - is_summarize = False + summarize_columns = set() + extend_columns = set() + projection_columns = [] for column in [c for c in columns if c.name != "*"]: column_name, column_alias = self._extract_column_name_and_alias(column) + column_alias = self._escape_and_quote_columns(column_alias, True) + # Do we have a group by clause ? + # Do we have aggregate columns ? + kql_agg = self._extract_maybe_agg_column_parts(column_name) + if kql_agg: + has_aggregates = True + summarize_columns.add( + self._build_column_projection(kql_agg, column_alias) + ) + # No group by clause + # Do the columns have aliases ? + # Add additional and to handle case where : SELECT column_name as column_name + elif column_alias and column_alias != column_name: + extend_columns.add( + self._build_column_projection(column_name, column_alias, True) + ) + if column_alias: + projection_columns.append( + self._escape_and_quote_columns(column_alias, True) + ) + else: + projection_columns.append( + self._escape_and_quote_columns(column_name) + ) + # group by columns + by_columns = self._group_by(group_by_cols) + if has_aggregates or bool( + by_columns + ): # Summarize can happen with or without aggregate being created + summarize_statement = f"| summarize {', '.join(summarize_columns)} " + if by_columns: + summarize_statement = ( + f"{summarize_statement} by {', '.join(by_columns)}" + ) + if extend_columns: + extend_statement = f"| extend {', '.join(sorted(extend_columns))}" + project_statement = ( + f"| project {', '.join(projection_columns)}" + if projection_columns + else "" + ) - if column_name in aggregates_sql_to_kql: - is_summarize = True - column_labels.append( - self._build_column_projection( - aggregates_sql_to_kql[column_name], column_alias - ) + unwrapped_order_by = self._get_order_by(order_by_cols) + sort_statement = ( + f"| order by {', '.join(unwrapped_order_by)}" if unwrapped_order_by else "" + ) + return { + "extend": extend_statement, + "summarize": summarize_statement, + "project": project_statement, + "sort": sort_statement, + } + + @staticmethod + def _extract_maybe_agg_column_parts(column_name) -> str | None: + match_agg_cols = re.match(AGGREGATE_PATTERN, column_name, re.IGNORECASE) + if match_agg_cols and match_agg_cols.groups(): + # Check if the aggregate function is count_distinct. This is case from superset + # where we can use count(distinct or count_distinct) + aggregate_func, distinct_keyword, agg_column_name = match_agg_cols.groups() + is_distinct = ( + bool(distinct_keyword) or aggregate_func.casefold() == "count_distinct" + ) + kql_agg = KustoKqlCompiler._sql_to_kql_aggregate( + aggregate_func.lower(), agg_column_name, is_distinct + ) + return kql_agg + return None + + def _get_order_by(self, order_by_cols): + unwrapped_order_by = [] + for elem in order_by_cols: + if isinstance(elem, sql.elements._label_reference): + nested_element = elem.element + unwrapped_order_by.append( + f"{self._escape_and_quote_columns(nested_element._order_by_label_element.name,is_alias=True)} " + f"{'desc' if (nested_element.modifier is operators.desc_op) else 'asc'}" + ) + elif isinstance(elem, sql.elements.TextClause): + sort_parts = elem.text.split() + if len(sort_parts) == self.sort_with_clause_parts: + unwrapped_order_by.append( + f"{self._escape_and_quote_columns(sort_parts[0],is_alias=True)} {sort_parts[1].lower()}" + ) + elif len(sort_parts) == 1: + unwrapped_order_by.append( + self._escape_and_quote_columns(sort_parts[0], is_alias=True) ) else: - column_labels.append( - self._build_column_projection(column_name, column_alias) + unwrapped_order_by.append( + elem.text.replace(" ASC", " asc").replace(" DESC", " desc") ) + else: + logger.warning( + "Unsupported order by clause: %s of type %s", elem, type(elem) + ) + return unwrapped_order_by + + def _group_by(self, group_by_cols): + by_columns = set() + for column in group_by_cols: + column_name, column_alias = self._extract_column_name_and_alias(column) + if column_alias: + by_columns.add(self._escape_and_quote_columns(column_alias)) + else: + by_columns.add(self._escape_and_quote_columns(column_name)) + return by_columns - if column_labels: - projection_type = "summarize" if is_summarize else "project" - return f"| {projection_type} {', '.join(column_labels)}" - return "" + @staticmethod + def _escape_and_quote_columns(name: str | None, is_alias=False) -> str: + if name is None: + return "" + name = name.strip() + if KustoKqlCompiler._is_kql_function(name) and not is_alias: + return name + if name.startswith('"') and name.endswith('"'): + name = name[1:-1] + # First, check if the name is already wrapped in ["ColumnName"] (escaped format) + if name.startswith('["') and name.endswith('"]'): + return name # Return as is if already properly escaped + # Remove surrounding spaces + # Handle mathematical operations (wrap only the column part before operators) + # Find the position of the first operator or space that separates the column name + for operator in ["/", "+", "-", "*"]: + if operator in name: + # Split the name at the first operator and wrap the left part + parts = name.split(operator, 1) + # Remove quotes if they exist at the edges + col_part = parts[0].strip() + if col_part.startswith('"') and col_part.endswith('"'): + return f'["{col_part[1:-1].strip()}"] {operator} {parts[1].strip()}' + return f'["{col_part}"] {operator} {parts[1].strip()}' # Wrap the column part + # No operators found, just wrap the entire name + return f'["{name}"]' + + @staticmethod + def _sql_to_kql_where(where_clause: str) -> str: + where_clause = where_clause.strip().replace("\n", "") + # Handle 'IS NULL' and 'IS NOT NULL' -> KQL equivalents + where_clause = re.sub( + r'(\["[^\]]+"\])\s*IS NULL', + r"isnull(\1)", + where_clause, + flags=re.IGNORECASE, + ) # IS NULL -> isnull(["FieldName"]) + where_clause = re.sub( + r'(\["[^\]]+"\])\s*IS NOT NULL', + r"isnotnull(\1)", + where_clause, + flags=re.IGNORECASE, + ) # IS NOT NULL -> isnotnull(["FieldName"]) + # Handle comparison operators + # Change '=' to '==' for equality comparisons + where_clause = re.sub( + r"(?<=[^=])=(?=\s|$|[^=])", r"==", where_clause, flags=re.IGNORECASE + ) + # Remove spaces in < = and > = operators + where_clause = re.sub(r"\s*<\s*=\s*", "<=", where_clause, flags=re.IGNORECASE) + where_clause = re.sub(r"\s*>\s*=\s*", ">=", where_clause, flags=re.IGNORECASE) + where_clause = where_clause.replace(">==", ">=") + where_clause = where_clause.replace("<==", "<=") + where_clause = re.sub( + r"(\s)(<>|!=)\s*", r" \2 ", where_clause, flags=re.IGNORECASE + ) # Handle '!=' and '<>' + where_clause = re.sub( + r"(\s)(<|<=|>|>=)\s*", r" \2 ", where_clause, flags=re.IGNORECASE + ) # Comparison operators: <, <=, >, >= + # Step 3: Handle 'LIKE' -> 'has' for substring matching + where_clause = re.sub( + r"(\s)LIKE\s*", r"\1has ", where_clause, flags=re.IGNORECASE + ) # Replace LIKE with has + # Step 4: Handle 'IN' and 'NOT IN' operators (with lists inside parentheses) + # We need to correctly handle multiple spaces around IN/NOT IN and lists inside parentheses + where_clause = re.sub( + r"(\s)NOT IN\s*\(([^)]+)\)", + r"\1not in (\2)", + where_clause, + flags=re.IGNORECASE, + ) # NOT IN operator (list of values) + where_clause = re.sub( + r"(\s)IN\s*\(([^)]+)\)", r"\1in (\2)", where_clause, flags=re.IGNORECASE + ) # IN operator (list of values) + # Handle BETWEEN operator (if needed) + + where_clause = re.sub( + r"(\w+|\[\"[A-Za-z0-9_]+\"\]) (BETWEEN|between) (\d) (AND|and) (\d)", + r"\1 between (\3..\5)", + where_clause, + flags=re.IGNORECASE, + ) + where_clause = re.sub( + r"(\w+) (BETWEEN|between) (\d) (AND|and) (\d)", + r"\1 between (\3..\5)", + where_clause, + flags=re.IGNORECASE, + ) + # Handle logical operators 'AND' and 'OR' to ensure the conditions are preserved + # Replace AND with 'and' in KQL + where_clause = re.sub(r"\s+AND\s+", r" and ", where_clause, flags=re.IGNORECASE) + # Replace OR with 'or' in KQL + where_clause = re.sub(r"\s+OR\s+", r" or ", where_clause, flags=re.IGNORECASE) + return where_clause + + @staticmethod + def _is_kql_function(name: str) -> bool: + pattern = r"^[a-zA-Z_][a-zA-Z0-9_]*\s*\(" + return bool(re.match(pattern, name)) def _get_most_inner_element(self, clause): """Finds the most nested element in clause.""" inner_element = getattr(clause, "element", None) if inner_element is not None: return self._get_most_inner_element(inner_element) - return clause @staticmethod @@ -150,15 +368,22 @@ def _extract_let_statements(clause) -> tuple[str, list[str]]: @staticmethod def _extract_column_name_and_alias(column: Column) -> tuple[str, str | None]: if hasattr(column, "element"): - return column.element.name, column.name - - return column.name, None + return str(column.element), column.name + if hasattr(column, "name"): + return str(column.name), None + return str(column), None @staticmethod def _build_column_projection( - column_name: str, column_alias: str | None = None + column_name: str, column_alias: str | None = None, is_extend: bool = False ) -> str: """Generates column alias semantic for project statement.""" + if is_extend: + return ( + f"{column_alias} = {KustoKqlCompiler._escape_and_quote_columns(column_name)}" + if column_alias + else KustoKqlCompiler._escape_and_quote_columns(column_name) + ) return f"{column_alias} = {column_name}" if column_alias else column_name @staticmethod @@ -193,6 +418,38 @@ def _convert_schema_in_statement(query: str) -> str: original, f'database("{unquoted_schema}").["{unquoted_table}"]', 1 ) + @staticmethod + def _sql_to_kql_aggregate( + sql_agg: str, column_name: str | None = None, is_distinct: bool = False + ) -> str | None: + """ + Converts SQL aggregate function to KQL equivalent. + If a column name is provided, applies it to the aggregate. + """ + has_column = column_name is not None and column_name.strip() != "" + column_name_escaped = ( + KustoKqlCompiler._escape_and_quote_columns(column_name) + if has_column + else "" + ) + return_value = None + # The count function is a special case because it can be used with or without a column name + # We can also use it in count(Distinct column_name) format. This has to be handled separately + if sql_agg and ("count" in sql_agg or "COUNT" in sql_agg): + if "*" in sql_agg or column_name == "*": + return_value = aggregates_sql_to_kql["count(*)"] + elif is_distinct: + return_value = f"dcount({column_name_escaped})" + else: + return_value = f"count({column_name_escaped})" + if return_value: + return return_value + # Other summarize operators have to be looked up + aggregate_function = aggregates_sql_to_kql.get(sql_agg.lower().split("(")[0]) + if aggregate_function: + return_value = f"{aggregate_function}({column_name_escaped})" + return return_value + class KustoKqlHttpsDialect(KustoBaseDialect): name = "kustokql" diff --git a/tests/integration/test_dbapi.py b/tests/integration/test_dbapi.py index 94c7d39..a5c577e 100644 --- a/tests/integration/test_dbapi.py +++ b/tests/integration/test_dbapi.py @@ -18,8 +18,8 @@ def test_execute() -> None: KUSTO_URL, DATABASE, False, - False, - None, + user_msi=None, + workload_identity=False, azure_ad_client_id=AZURE_AD_CLIENT_ID, azure_ad_client_secret=AZURE_AD_CLIENT_SECRET, azure_ad_tenant_id=AZURE_AD_TENANT_ID, diff --git a/tests/integration/test_dialect_kql.py b/tests/integration/test_dialect_kql.py new file mode 100644 index 0000000..f25bc4e --- /dev/null +++ b/tests/integration/test_dialect_kql.py @@ -0,0 +1,173 @@ +import logging +import uuid + +import pytest +from azure.kusto.data import ( + ClientRequestProperties, + KustoClient, + KustoConnectionStringBuilder, +) +from sqlalchemy import Column, MetaData, String, Table, create_engine, func, text +from sqlalchemy.orm import sessionmaker + +from tests.integration.conftest import ( + AZURE_AD_CLIENT_ID, + AZURE_AD_CLIENT_SECRET, + AZURE_AD_TENANT_ID, + DATABASE, + KUSTO_KQL_ALCHEMY_URL, + KUSTO_URL, +) + +logger = logging.getLogger(__name__) + +kql_engine = create_engine( + f"{KUSTO_KQL_ALCHEMY_URL}/{DATABASE}?" + f"msi=False&azure_ad_client_id={AZURE_AD_CLIENT_ID}&" + f"azure_ad_client_secret={AZURE_AD_CLIENT_SECRET}&" + f"azure_ad_tenant_id={AZURE_AD_TENANT_ID}" +) + +Session = sessionmaker(bind=kql_engine) +session = Session() +metadata = MetaData() + + +def test_group_by(temp_table_name): + table = Table( + temp_table_name, + metadata, + ) + query = ( + session.query(func.count(text("Id")).label("tag_count")) + .add_columns(Column("Text", String)) + .select_from(table) + .group_by(text("Text")) + .order_by("tag_count") + ) + query_compiled = str(query.statement.compile(kql_engine)).replace("\n", "") + with kql_engine.connect() as connection: + # SELECT count(distinct (case when Id%2=0 THEN 'Even' end)) as tag_count FROM {temp_table_name} + # convert the above query to using alchemy + result = connection.execute(text(query_compiled)) + # There is Even and Empty only for this test, 2 distinct values + assert {(x[0], x[1]) for x in result.fetchall()} == { + (5, "value_1"), + (4, "value_0"), + } + + +# Test without group +def test_count_by(temp_table_name): + # Convert the query: SELECT count(distinct (case when Id%2=0 THEN 'Even' end)) as tag_count FROM {temp_table_name} + table = Table( + temp_table_name, + metadata, + ) + query = session.query(func.count(text("Id")).label("tag_count")).select_from(table) + query_compiled = str(query.statement.compile(kql_engine)).replace("\n", "") + with kql_engine.connect() as connection: + result = connection.execute(text(query_compiled)) + # There is Even and Empty only for this test, 2 distinct values + assert {(x[0]) for x in result.fetchall()} == {9} + + +def test_distinct_counts_by(temp_table_name): + # Convert to : SELECT count(distinct (case when Id%2=0 THEN 'Even' end)) as tag_count FROM {temp_table_name} + table = Table( + temp_table_name, + metadata, + ) + query = session.query( + func.count(func.distinct(text("Text"))).label("tag_count") + ).select_from(table) + query_compiled = str(query.statement.compile(kql_engine)).replace("\n", "") + with kql_engine.connect() as connection: + result = connection.execute(text(query_compiled)) + # There is Even and Empty only for this test, 2 distinct values + assert {(x[0]) for x in result.fetchall()} == {2} + + +@pytest.mark.parametrize( + ("f", "label", "expected"), + [ + pytest.param(func.min(text("Id")), "Min", 1), + pytest.param(func.max(text("Id")), "Max", 9), + pytest.param(func.sum(text("Id")), "Sum", 45), + ], +) +def test_all_group_ops(f, label, expected, temp_table_name): + # Convert : SELECT count(distinct (case when Id%2=0 THEN 'Even' end)) as tag_count FROM {temp_table_name} + table = Table( + temp_table_name, + metadata, + ) + query = session.query(f.label(label)).select_from(table) + query_compiled = str(query.statement.compile(kql_engine)).replace("\n", "") + with kql_engine.connect() as connection: + result = connection.execute(text(query_compiled)) + # There is Even and Empty only for this test, 2 distinct values + assert {(x[0]) for x in result.fetchall()} == {expected} + + +def get_kcsb(): + return ( + KustoConnectionStringBuilder.with_az_cli_authentication(KUSTO_URL) + if not AZURE_AD_CLIENT_ID + and not AZURE_AD_CLIENT_SECRET + and not AZURE_AD_TENANT_ID + else KustoConnectionStringBuilder.with_aad_application_key_authentication( + KUSTO_URL, AZURE_AD_CLIENT_ID, AZURE_AD_CLIENT_SECRET, AZURE_AD_TENANT_ID + ) + ) + + +def _create_temp_table(table_name: str): + client = KustoClient(get_kcsb()) + client.execute( + DATABASE, + f".create table {table_name}(Id: int, Text: string)", + ClientRequestProperties(), + ) + + +def _create_temp_fn(fn_name: str): + client = KustoClient(get_kcsb()) + client.execute( + DATABASE, + f".create function {fn_name}() {{ print now()}}", + ClientRequestProperties(), + ) + + +def _ingest_data_to_table(table_name: str): + client = KustoClient(get_kcsb()) + data_to_ingest = {i: "value_" + str(i % 2) for i in range(1, 10)} + str_data = "\n".join("{},{}".format(*p) for p in data_to_ingest.items()) + ingest_query = f""".ingest inline into table {table_name} <| + {str_data}""" + client.execute(DATABASE, ingest_query, ClientRequestProperties()) + + +def _drop_table(table_name: str): + client = KustoClient(get_kcsb()) + + _ = client.execute(DATABASE, f".drop table {table_name}", ClientRequestProperties()) + _ = client.execute( + DATABASE, f".drop function {table_name}_fn", ClientRequestProperties() + ) + + +@pytest.fixture +def temp_table_name(): + return "_temp_" + uuid.uuid4().hex + "_kql" + + +@pytest.fixture(autouse=True) +def run_around_tests(temp_table_name): + _create_temp_table(temp_table_name) + _create_temp_fn(f"{temp_table_name}_fn") + _ingest_data_to_table(temp_table_name) + # A test function will be run at this point + yield temp_table_name + _drop_table(temp_table_name) diff --git a/tests/integration/test_dialect_sql.py b/tests/integration/test_dialect_sql.py index 4f26563..ed275df 100644 --- a/tests/integration/test_dialect_sql.py +++ b/tests/integration/test_dialect_sql.py @@ -1,5 +1,3 @@ -from collections.abc import Generator -from typing import Any import uuid import pytest @@ -27,31 +25,31 @@ ) -def test_ping() -> None: +def test_ping(): conn = engine.connect() result = engine.dialect.do_ping(conn) assert result is True -def test_get_table_names(temp_table_name: str) -> None: +def test_get_table_names(temp_table_name): conn = engine.connect() result = engine.dialect.get_table_names(conn) assert temp_table_name in result -def test_get_view_names(temp_table_name: str) -> None: +def test_get_view_names(temp_table_name): conn = engine.connect() result = engine.dialect.get_view_names(conn) assert f"{temp_table_name}_fn" in result -def test_get_columns(temp_table_name: str) -> None: +def test_get_columns(temp_table_name): conn = engine.connect() columns_result = engine.dialect.get_columns(conn, temp_table_name) assert {"Id", "Text"} == {c["name"] for c in columns_result} -def test_fetch_one(temp_table_name: str) -> None: +def test_fetch_one(temp_table_name): engine.connect() result = engine.execute(f"select top 2 * from {temp_table_name} order by Id") assert result.fetchone() == (1, "value_1") @@ -59,7 +57,7 @@ def test_fetch_one(temp_table_name: str) -> None: assert result.fetchone() is None -def test_fetch_many(temp_table_name: str) -> None: +def test_fetch_many(temp_table_name): engine.connect() result = engine.execute(f"select top 5 * from {temp_table_name} order by Id") @@ -74,7 +72,7 @@ def test_fetch_many(temp_table_name: str) -> None: } -def test_fetch_all(temp_table_name: str) -> None: +def test_fetch_all(temp_table_name): engine.connect() result = engine.execute(f"select top 3 * from {temp_table_name} order by Id") assert {(x[0], x[1]) for x in result.fetchall()} == { @@ -84,8 +82,8 @@ def test_fetch_all(temp_table_name: str) -> None: } -def test_limit(temp_table_name: str) -> None: - limit = 5 +def test_limit(temp_table_name): + limit_rec_count = 5 stream = Table( temp_table_name, MetaData(), @@ -93,15 +91,15 @@ def test_limit(temp_table_name: str) -> None: Column("Text", String), ) - query = stream.select().limit(limit) + query = stream.select().limit(limit_rec_count) engine.connect() result = engine.execute(query) result_length = len(result.fetchall()) - assert result_length == limit + assert result_length == limit_rec_count -def get_kcsb() -> Any: +def get_kcsb(): return ( KustoConnectionStringBuilder.with_az_cli_authentication(KUSTO_URL) if not AZURE_AD_CLIENT_ID @@ -113,7 +111,7 @@ def get_kcsb() -> Any: ) -def _create_temp_table(table_name: str) -> None: +def _create_temp_table(table_name: str): client = KustoClient(get_kcsb()) client.execute( DATABASE, @@ -122,7 +120,7 @@ def _create_temp_table(table_name: str) -> None: ) -def _create_temp_fn(fn_name: str) -> None: +def _create_temp_fn(fn_name: str): client = KustoClient(get_kcsb()) client.execute( DATABASE, @@ -131,7 +129,7 @@ def _create_temp_fn(fn_name: str) -> None: ) -def _ingest_data_to_table(table_name: str) -> None: +def _ingest_data_to_table(table_name: str): client = KustoClient(get_kcsb()) data_to_ingest = {i: "value_" + str(i) for i in range(1, 10)} str_data = "\n".join("{},{}".format(*p) for p in data_to_ingest.items()) @@ -140,7 +138,7 @@ def _ingest_data_to_table(table_name: str) -> None: client.execute(DATABASE, ingest_query, ClientRequestProperties()) -def _drop_table(table_name: str) -> None: +def _drop_table(table_name: str): client = KustoClient(get_kcsb()) _ = client.execute(DATABASE, f".drop table {table_name}", ClientRequestProperties()) @@ -150,12 +148,12 @@ def _drop_table(table_name: str) -> None: @pytest.fixture -def temp_table_name() -> str: +def temp_table_name(): return "_temp_" + uuid.uuid4().hex @pytest.fixture(autouse=True) -def run_around_tests(temp_table_name: str) -> Generator[str, None, None]: +def run_around_tests(temp_table_name): _create_temp_table(temp_table_name) _create_temp_fn(f"{temp_table_name}_fn") _ingest_data_to_table(temp_table_name) diff --git a/tests/unit/test_dialect_kql.py b/tests/unit/test_dialect_kql.py index 71824c5..3138951 100644 --- a/tests/unit/test_dialect_kql.py +++ b/tests/unit/test_dialect_kql.py @@ -2,21 +2,25 @@ import sqlalchemy as sa from sqlalchemy import ( Column, + Integer, MetaData, String, Table, column, create_engine, + distinct, literal_column, select, text, ) from sqlalchemy.sql.selectable import TextAsFrom +from sqlalchemy_kusto.dialect_kql import KustoKqlCompiler + engine = create_engine("kustokql+https://localhost/testdb") -def test_compiler_with_projection() -> None: +def test_compiler_with_projection(): statement_str = "logs | take 10" stmt = TextAsFrom(sa.text(statement_str), []).alias("virtual_table") query = sa.select( @@ -32,16 +36,17 @@ def test_compiler_with_projection() -> None: query_compiled = str(query.compile(engine)).replace("\n", "") query_expected = ( - 'let virtual_table = (["logs"] | take 10);' - "virtual_table" - "| project id = Id, tId = TypeId, Type" + 'let virtual_table = (["logs"] ' + "| take 10);virtual_table" + '| extend ["id"] = ["Id"], ["tId"] = ["TypeId"]' + '| project ["id"], ["tId"], ["Type"]' "| take __[POSTCOMPILE_param_1]" ) assert query_compiled == query_expected -def test_compiler_with_star() -> None: +def test_compiler_with_star(): statement_str = "logs | take 10" stmt = TextAsFrom(sa.text(statement_str), []).alias("virtual_table") query = sa.select( @@ -50,18 +55,16 @@ def test_compiler_with_star() -> None: ) query = query.select_from(stmt) query = query.limit(10) - query_compiled = str(query.compile(engine)).replace("\n", "") query_expected = ( 'let virtual_table = (["logs"] | take 10);' "virtual_table" "| take __[POSTCOMPILE_param_1]" ) - assert query_compiled == query_expected -def test_select_from_text() -> None: +def test_select_from_text(): query = ( select([column("Field1"), column("Field2")]) .select_from(text("logs")) @@ -70,12 +73,206 @@ def test_select_from_text() -> None: query_compiled = str( query.compile(engine, compile_kwargs={"literal_binds": True}) ).replace("\n", "") - query_expected = '["logs"]' "| project Field1, Field2" "| take 100" + query_expected = '["logs"]| project ["Field1"], ["Field2"]| take 100' + assert query_compiled == query_expected + +@pytest.mark.parametrize( + ("f", "expected"), + [ + pytest.param( + Column("Field1", String).in_(["1", "One"]), """["Field1"] in ('1', 'One')""" + ), + pytest.param( + Column("Field1", String).notin_(["1", "One"]), + """(["Field1"] not in ('1', 'One'))""", + ), + pytest.param(text("Field1 = '1'"), """Field1 == '1'"""), + pytest.param( + Column("Field2", Integer).between(2, 4), """["Field2"] between (2..4)""" + ), + pytest.param(Column("Field2", Integer).is_(None), """isnull(["Field2"])"""), + pytest.param( + Column("Field2", Integer).isnot(None), """isnotnull(["Field2"])""" + ), + pytest.param( + (Column("Field2", Integer).isnot(None)).__and__( + Column("Field1", String).notin_(["1", "One"]) + ), + """isnotnull(["Field2"]) and (["Field1"] not in ('1', 'One'))""", + ), + pytest.param( + (Column("Field2", Integer).isnot(None)).__or__( + Column("Field1", String).notin_(["1", "One"]) + ), + """isnotnull(["Field2"]) or (["Field1"] not in ('1', 'One'))""", + ), + ], +) +def test_where_predicates(f, expected): + query = ( + select([column("Field1"), column("Field2")]).select_from(text("logs")).where(f) + ).limit(100) + query_compiled = str( + query.compile(engine, compile_kwargs={"literal_binds": True}) + ).replace("\n", "") + query_expected = ( + f"""["logs"]| where {expected}| project ["Field1"], ["Field2"]| take 100""" + ) assert query_compiled == query_expected -def test_use_table() -> None: +def test_group_by_text(): + # create a query from select_query_text creating clause + event_col = literal_column('"EventInfo_Time" / time(1d)').label("EventInfo_Time") + active_users_col = literal_column("ActiveUsers").label("ActiveUserMetric") + query = ( + select([event_col, active_users_col]) + .select_from(text("ActiveUsersLastMonth")) + .group_by(literal_column('"EventInfo_Time" / time(1d)')) + .order_by(text("ActiveUserMetric DESC")) + ) + + query_compiled = str( + query.compile(engine, compile_kwargs={"literal_binds": True}) + ).replace("\n", "") + # raw query text from query + query_expected = ( + '["ActiveUsersLastMonth"]| extend ["ActiveUserMetric"] = ["ActiveUsers"], ' + '["EventInfo_Time"] = ["EventInfo_Time"] / time(1d)' + '| summarize by ["EventInfo_Time"] / time(1d)' + '| project ["EventInfo_Time"], ["ActiveUserMetric"]' + '| order by ["ActiveUserMetric"] desc' + ) + assert query_compiled == query_expected + + +def test_group_by_text_vaccine_dataset(): + # SQL: SELECT country_name AS country_name FROM superset."CovidVaccineData" GROUP BY country_name + # ORDER BY country_name ASC - this is a simple query to get distinct country names + query = ( + select([literal_column("country_name").label("country_name")]) + .select_from(text('superset."CovidVaccineData"')) + .group_by(literal_column("country_name")) + .order_by(text("country_name ASC")) + ) + query_compiled = str( + query.compile(engine, compile_kwargs={"literal_binds": True}) + ).replace("\n", "") + query_expected = ( + 'database("superset").["CovidVaccineData"]| ' + 'extend ["country_name"] = ["country_name"]| ' + 'summarize by ["country_name"]| ' + 'project ["country_name"]| order by ["country_name"] asc' + ) + assert query_compiled == query_expected + + +def test_is_kql_function(): + assert KustoKqlCompiler._is_kql_function( + """case(Size <= 3, "Small", + Size <= 10, "Medium", + "Large")""" + ) + assert KustoKqlCompiler._is_kql_function("""bin(time(16d), 7d)""") + assert KustoKqlCompiler._is_kql_function( + """iff((EventType in ("Heavy Rain", "Flash Flood", "Flood")), "Rain event", "Not rain event")""" + ) + + +def test_distinct_count_by_text(): + # create a query from select_query_text creating clause + # 'SELECT "EventInfo_Time" / time(1d) AS "EventInfo_Time", count(DISTINCT ActiveUsers) AS "DistinctUsers" + # FROM ActiveUsersLastMonth GROUP BY "EventInfo_Time" / time(1d) ORDER BY ActiveUserMetric DESC' + event_col = literal_column('"EventInfo_Time" / time(1d)').label("EventInfo_Time") + active_users_col = literal_column("ActiveUsers") + query = ( + select( + [ + event_col, + sa.func.count(distinct(active_users_col)).label("DistinctUsers"), + ] + ) + .select_from(text("ActiveUsersLastMonth")) + .group_by(literal_column('"EventInfo_Time" / time(1d)')) + .order_by(text("ActiveUserMetric DESC")) + ) + query_compiled = str( + query.compile(engine, compile_kwargs={"literal_binds": True}) + ).replace("\n", "") + # raw query text from query + query_expected = ( + '["ActiveUsersLastMonth"]' + '| extend ["EventInfo_Time"] = ["EventInfo_Time"] / time(1d)' + '| summarize ["DistinctUsers"] = dcount(["ActiveUsers"]) by ["EventInfo_Time"] / time(1d)' + '| project ["EventInfo_Time"], ["DistinctUsers"]' + '| order by ["ActiveUserMetric"] desc' + ) + assert query_compiled == query_expected + + +def test_distinct_count_alt_by_text(): + # create a query from select_query_text creating clause + # 'SELECT "EventInfo_Time" / time(1d) AS "EventInfo_Time", count_distinct(ActiveUsers) AS "DistinctUsers" + # FROM ActiveUsersLastMonth GROUP BY "EventInfo_Time" / time(1d) ORDER BY ActiveUserMetric DESC' + event_col = literal_column("EventInfo_Time / time(1d)").label("EventInfo_Time") + active_users_col = literal_column("COUNT_DISTINCT(ActiveUsers)") + query = ( + select([event_col, active_users_col.label("DistinctUsers")]) + .select_from(text("ActiveUsersLastMonth")) + .group_by(literal_column("EventInfo_Time / time(1d)")) + .order_by(text("ActiveUserMetric DESC")) + ) + query_compiled = str( + query.compile(engine, compile_kwargs={"literal_binds": True}) + ).replace("\n", "") + # raw query text from query + query_expected = ( + '["ActiveUsersLastMonth"]' + '| extend ["EventInfo_Time"] = ["EventInfo_Time"] / time(1d)' + '| summarize ["DistinctUsers"] = dcount(["ActiveUsers"]) by ["EventInfo_Time"] / time(1d)' + '| project ["EventInfo_Time"], ["DistinctUsers"]' + '| order by ["ActiveUserMetric"] desc' + ) + + assert query_compiled == query_expected + + +def test_escape_and_quote_columns(): + assert ( + KustoKqlCompiler._escape_and_quote_columns("EventInfo_Time") + == '["EventInfo_Time"]' + ) + assert KustoKqlCompiler._escape_and_quote_columns('["UserId"]') == '["UserId"]' + assert ( + KustoKqlCompiler._escape_and_quote_columns("EventInfo_Time / time(1d)") + == '["EventInfo_Time"] / time(1d)' + ) + + +@pytest.mark.parametrize( + ("sql_aggregate", "column_name", "is_distinct", "expected_kql"), + [ + ("count(*)", None, False, "count()"), + ("count", "UserId", False, 'count(["UserId"])'), + ("count(distinct", "CustomerId", True, 'dcount(["CustomerId"])'), + ("count_distinct", "CustomerId", True, 'dcount(["CustomerId"])'), + ("sum", "Sales", False, 'sum(["Sales"])'), + ("avg", "ResponseTime", False, 'avg(["ResponseTime"])'), + ("AVG", "ResponseTime", False, 'avg(["ResponseTime"])'), + ("min", "Size", False, 'min(["Size"])'), + ("max", "Area", False, 'max(["Area"])'), + ("unknown", "Column", False, None), + ], +) +def test_sql_to_kql_aggregate(sql_aggregate, column_name, is_distinct, expected_kql): + assert ( + KustoKqlCompiler._sql_to_kql_aggregate(sql_aggregate, column_name, is_distinct) + == expected_kql + ) + + +def test_use_table(): metadata = MetaData() stream = Table( "logs", @@ -88,12 +285,12 @@ def test_use_table() -> None: query_compiled = str(query.compile(engine)).replace("\n", "") query_expected = ( - '["logs"]' "| project Field1, Field2" "| take __[POSTCOMPILE_param_1]" + '["logs"]' '| project ["Field1"], ["Field2"]| take __[POSTCOMPILE_param_1]' ) assert query_compiled == query_expected -def test_limit() -> None: +def test_limit(): sql = "logs" limit = 5 query = ( @@ -101,17 +298,14 @@ def test_limit() -> None: .select_from(TextAsFrom(text(sql), ["*"]).alias("inner_qry")) .limit(limit) ) - query_compiled = str( query.compile(engine, compile_kwargs={"literal_binds": True}) ).replace("\n", "") - query_expected = 'let inner_qry = (["logs"]);' "inner_qry" "| take 5" - assert query_compiled == query_expected -def test_select_count() -> None: +def test_select_count(): kql_query = "logs" column_count = literal_column("count(*)").label("count") query = ( @@ -131,14 +325,16 @@ def test_select_count() -> None: 'let inner_qry = (["logs"]);' "inner_qry" "| where Field1 > 1 and Field2 < 2" - "| summarize count = count()" + '| summarize ["count"] = count() ' + '| project ["count"]' + '| order by ["count"] desc' "| take 5" ) assert query_compiled == query_expected -def test_select_with_let() -> None: +def test_select_with_let(): kql_query = "let x = 5; let y = 3; MyTable | where Field1 == x and Field2 == y" query = ( select("*") @@ -161,7 +357,7 @@ def test_select_with_let() -> None: assert query_compiled == query_expected -def test_quotes() -> None: +def test_quotes(): quote = engine.dialect.identifier_preparer.quote metadata = MetaData() stream = Table( @@ -199,20 +395,41 @@ def test_quotes() -> None: ) def test_schema_from_metadata( table_name: str, schema_name: str, expected_table_name: str -) -> None: +): metadata = MetaData(schema=schema_name) if schema_name else MetaData() stream = Table( table_name, metadata, ) query = stream.select().limit(5) - query_compiled = str(query.compile(engine)).replace("\n", "") - query_expected = f"{expected_table_name}| take __[POSTCOMPILE_param_1]" assert query_compiled == query_expected +@pytest.mark.parametrize( + ("column_name", "expected_aggregate"), + [ + ("AVG(Score)", 'avg(["Score"])'), + ('AVG("2014")', 'avg(["2014"])'), + ('sum("2014")', 'sum(["2014"])'), + ("SUM(scores)", 'sum(["scores"])'), + ('MIN("scores")', 'min(["scores"])'), + ('MIN(["scores"])', 'min(["scores"])'), + ("max(scores)", 'max(["scores"])'), + ("startofmonth(somedate)", None), + ("startofmonth(somedate)/time(1d)", None), + ], +) +def test_match_aggregates(column_name: str, expected_aggregate: str): + kql_agg = KustoKqlCompiler._extract_maybe_agg_column_parts(column_name) + if expected_aggregate: + assert kql_agg is not None + assert kql_agg == expected_aggregate + else: + assert kql_agg is None + + @pytest.mark.parametrize( ("query_table_name", "expected_table_name"), [ @@ -227,7 +444,7 @@ def test_schema_from_metadata( ('["table"]', '["table"]'), ], ) -def test_schema_from_query(query_table_name: str, expected_table_name: str) -> None: +def test_schema_from_query(query_table_name: str, expected_table_name: str): query = ( select("*") .select_from(TextAsFrom(text(query_table_name), ["*"]).alias("inner_qry"))