From cd61ca85de767ab5973ecf8b5e2f2b32e6691356 Mon Sep 17 00:00:00 2001 From: Courtney Holcomb Date: Wed, 22 Jan 2025 10:57:15 -0800 Subject: [PATCH] fixup! Add new dataflow plan nodes for custom offset windows --- metricflow/plan_conversion/dataflow_to_sql.py | 49 ++++++++++++------- 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index 12842b9fc8..b037fa6565 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -6,7 +6,6 @@ from typing import Callable, Dict, FrozenSet, List, Optional, Sequence, Set, Tuple, TypeVar from dbt_semantic_interfaces.enum_extension import assert_values_exhausted -from dbt_semantic_interfaces.naming.keywords import DUNDER from dbt_semantic_interfaces.protocols.metric import MetricInputMeasure, MetricType from dbt_semantic_interfaces.references import MetricModelReference, SemanticModelElementReference from dbt_semantic_interfaces.type_enums.aggregation_type import AggregationType @@ -2099,8 +2098,9 @@ def visit_offset_by_custom_granularity_node(self, node: OffsetByCustomGranularit # Build columns that get start and end of the custom grain period. # Ex: FIRST_VALUE(ds) OVER (PARTITION BY fiscal_quarter ORDER BY ds) AS ds__fiscal_quarter__first_value - new_select_columns: Tuple[SqlSelectColumn, ...] = tuple() + new_select_columns: Tuple[SqlSelectColumn, ...] = () bounds_columns: Tuple[SqlSelectColumn, ...] = () + bounds_instances: Tuple[TimeDimensionInstance, ...] = () custom_column_expr = SqlColumnReferenceExpression.from_table_and_column_names( table_alias=time_spine_alias, column_name=custom_grain_column_name ) @@ -2108,6 +2108,10 @@ def visit_offset_by_custom_granularity_node(self, node: OffsetByCustomGranularit table_alias=time_spine_alias, column_name=base_grain_column_name ) for window_func in (SqlWindowFunction.FIRST_VALUE, SqlWindowFunction.LAST_VALUE): + bounds_instance = custom_grain_instance.with_new_spec( + new_spec=custom_grain_instance.spec.with_window_functions((window_func,)), + column_association_resolver=self._column_association_resolver, + ) select_column = SqlSelectColumn( expr=SqlWindowFunctionExpression.create( sql_function=window_func, @@ -2115,10 +2119,9 @@ def visit_offset_by_custom_granularity_node(self, node: OffsetByCustomGranularit partition_by_args=(custom_column_expr,), order_by_args=(SqlWindowOrderByArgument(base_column_expr),), ), - column_alias=self._column_association_resolver.resolve_spec( - custom_grain_instance.spec.with_window_function(window_func) - ).column_name, + column_alias=bounds_instance.associated_column.column_name, ) + bounds_instances += (bounds_instance,) bounds_columns += (select_column,) new_select_columns += (select_column,) @@ -2132,7 +2135,7 @@ def visit_offset_by_custom_granularity_node(self, node: OffsetByCustomGranularit order_by_args=(SqlWindowOrderByArgument(base_column_expr),), ), column_alias=self._column_association_resolver.resolve_spec( - base_grain_instance.spec.with_window_function(SqlWindowFunction.ROW_NUMBER) + base_grain_instance.spec.with_window_functions((SqlWindowFunction.ROW_NUMBER,)) ).column_name, ) new_select_columns += (row_number_column,) @@ -2169,20 +2172,30 @@ def visit_offset_by_custom_granularity_node(self, node: OffsetByCustomGranularit custom_grain_column = SqlSelectColumn.from_table_and_column_names( column_name=custom_grain_column_name, table_alias=unique_rows_alias ) - first_value_offset_column, last_value_offset_column = tuple( - SqlSelectColumn( - expr=SqlWindowFunctionExpression.create( - sql_function=SqlWindowFunction.LEAD, - sql_function_args=( - bounds_column.ref_with_new_table_alias(unique_rows_alias), - SqlIntegerExpression.create(node.offset_window.count), + offset_bounds_columns: Tuple[SqlSelectColumn, ...] = () + for i in range(len(bounds_columns)): + bounds_instance = bounds_instances[i] + bounds_column = bounds_columns[i] + offset_bounds_instance = bounds_instance.with_new_spec( + bounds_instance.spec.with_window_functions( + (bounds_instance.spec.window_functions + (SqlWindowFunction.LEAD,)) + ), + column_association_resolver=self._column_association_resolver, + ) + offset_bounds_columns += ( + SqlSelectColumn( + expr=SqlWindowFunctionExpression.create( + sql_function=SqlWindowFunction.LEAD, + sql_function_args=( + bounds_column.ref_with_new_table_alias(unique_rows_alias), + SqlIntegerExpression.create(node.offset_window.count), + ), + order_by_args=(SqlWindowOrderByArgument(custom_grain_column.expr),), ), - order_by_args=(SqlWindowOrderByArgument(custom_grain_column.expr),), + column_alias=offset_bounds_instance.associated_column.column_name, ), - column_alias=f"{bounds_column.column_alias}{DUNDER}offset", ) - for bounds_column in bounds_columns - ) + first_value_offset_column, last_value_offset_column = offset_bounds_columns offset_bounds_subquery_alias = self._next_unique_table_alias() offset_bounds_subquery = SqlSelectStatementNode.create( description="Offset Custom Granularity Bounds", @@ -2213,7 +2226,7 @@ def visit_offset_by_custom_granularity_node(self, node: OffsetByCustomGranularit ) # LEAD isn't quite accurate here, but this will differentiate the offset instance (and column) from the original one. offset_base_column_name = self._column_association_resolver.resolve_spec( - base_grain_instance.spec.with_window_function(SqlWindowFunction.LEAD) + base_grain_instance.spec.with_window_functions((SqlWindowFunction.LEAD,)) ).column_name offset_base_column = SqlSelectColumn( expr=SqlCaseExpression.create(