Skip to content

Commit

Permalink
Add new dataflow plan nodes for custom offset windows
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Jan 22, 2025
1 parent 6f6b5bf commit e3af638
Show file tree
Hide file tree
Showing 16 changed files with 468 additions and 20 deletions.
2 changes: 2 additions & 0 deletions metricflow-semantics/metricflow_semantics/dag/id_prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class StaticIdPrefix(IdPrefix, Enum, metaclass=EnumMetaClassHelper):
DATAFLOW_NODE_JOIN_CONVERSION_EVENTS_PREFIX = "jce"
DATAFLOW_NODE_WINDOW_REAGGREGATION_ID_PREFIX = "wr"
DATAFLOW_NODE_ALIAS_SPECS_ID_PREFIX = "as"
DATAFLOW_NODE_OFFSET_BY_CUSTOM_GRANULARITY_ID_PREFIX = "obcg"

SQL_EXPR_COLUMN_REFERENCE_ID_PREFIX = "cr"
SQL_EXPR_COMPARISON_ID_PREFIX = "cmp"
Expand Down Expand Up @@ -106,6 +107,7 @@ class StaticIdPrefix(IdPrefix, Enum, metaclass=EnumMetaClassHelper):

TIME_SPINE_SOURCE = "time_spine_src"
SUB_QUERY = "subq"
CTE = "cte"
NODE_RESOLVER_SUB_QUERY = "nr_subq"

@property
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def visit_time_dimension_spec(self, time_dimension_spec: TimeDimensionSpec) -> C
else ""
)
+ (
f"{DUNDER}{time_dimension_spec.window_function.value.lower()}"
if time_dimension_spec.window_function
f"{DUNDER}{DUNDER.join([window_function.value.lower() for window_function in time_dimension_spec.window_functions])}"
if time_dimension_spec.window_functions
else ""
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class TimeDimensionSpec(DimensionSpec): # noqa: D101
# Used for semi-additive joins. Some more thought is needed, but this may be useful in InstanceSpec.
aggregation_state: Optional[AggregationState] = None

window_function: Optional[SqlWindowFunction] = None
window_functions: Tuple[SqlWindowFunction, ...] = ()

def __post_init__(self) -> None:
"""Ensure that exactly one time granularity or date part is set."""
Expand Down Expand Up @@ -126,7 +126,7 @@ def without_first_entity_link(self) -> TimeDimensionSpec: # noqa: D102
time_granularity=self.time_granularity,
date_part=self.date_part,
aggregation_state=self.aggregation_state,
window_function=self.window_function,
window_functions=self.window_functions,
)

@property
Expand All @@ -137,7 +137,7 @@ def without_entity_links(self) -> TimeDimensionSpec: # noqa: D102
date_part=self.date_part,
entity_links=(),
aggregation_state=self.aggregation_state,
window_function=self.window_function,
window_functions=self.window_functions,
)

@property
Expand Down Expand Up @@ -177,7 +177,7 @@ def with_grain(self, time_granularity: ExpandedTimeGranularity) -> TimeDimension
entity_links=self.entity_links,
time_granularity=time_granularity,
aggregation_state=self.aggregation_state,
window_function=self.window_function,
window_functions=self.window_functions,
)

def with_base_grain(self) -> TimeDimensionSpec: # noqa: D102
Expand All @@ -189,7 +189,7 @@ def with_base_grain(self) -> TimeDimensionSpec: # noqa: D102
),
date_part=self.date_part,
aggregation_state=self.aggregation_state,
window_function=self.window_function,
window_functions=self.window_functions,
)

def with_aggregation_state(self, aggregation_state: AggregationState) -> TimeDimensionSpec: # noqa: D102
Expand All @@ -199,17 +199,17 @@ def with_aggregation_state(self, aggregation_state: AggregationState) -> TimeDim
time_granularity=self.time_granularity,
date_part=self.date_part,
aggregation_state=aggregation_state,
window_function=self.window_function,
window_functions=self.window_functions,
)

def with_window_function(self, window_function: SqlWindowFunction) -> TimeDimensionSpec: # noqa: D102
def with_window_functions(self, window_functions: Tuple[SqlWindowFunction, ...]) -> TimeDimensionSpec: # noqa: D102
return TimeDimensionSpec(
element_name=self.element_name,
entity_links=self.entity_links,
time_granularity=self.time_granularity,
date_part=self.date_part,
aggregation_state=self.aggregation_state,
window_function=window_function,
window_functions=window_functions,
)

def comparison_key(self, exclude_fields: Sequence[TimeDimensionSpecField] = ()) -> TimeDimensionSpecComparisonKey:
Expand Down Expand Up @@ -267,7 +267,7 @@ def with_entity_prefix(self, entity_prefix: EntityReference) -> TimeDimensionSpe
time_granularity=self.time_granularity,
date_part=self.date_part,
aggregation_state=self.aggregation_state,
window_function=self.window_function,
window_functions=self.window_functions,
)

@staticmethod
Expand Down
6 changes: 6 additions & 0 deletions metricflow-semantics/metricflow_semantics/sql/sql_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,12 @@ def matches(self, other: SqlExpressionNode) -> bool: # noqa: D102
def from_table_and_column_names(table_alias: str, column_name: str) -> SqlColumnReferenceExpression: # noqa: D102
return SqlColumnReferenceExpression.create(SqlColumnReference(table_alias=table_alias, column_name=column_name))

def with_new_table_alias(self, new_table_alias: str) -> SqlColumnReferenceExpression:
"""Returns a new column reference expression with the same column name but a new table alias."""
return SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=new_table_alias, column_name=self.col_ref.column_name
)


@dataclass(frozen=True, eq=False)
class SqlColumnAliasReferenceExpression(SqlExpressionNode):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_classes() -> None: # noqa: D103
time_granularity=ExpandedTimeGranularity(name='day', base_granularity=DAY),
date_part=None,
aggregation_state=None,
window_function=None,
window_functions=(),
)
"""
).rstrip()
Expand Down
4 changes: 3 additions & 1 deletion metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1900,7 +1900,9 @@ def _build_time_spine_node(
parent_node=read_node,
change_specs=tuple(
SpecToAlias(
input_spec=time_spine_data_set.instance_from_time_dimension_grain_and_date_part(required_spec).spec,
input_spec=time_spine_data_set.instance_from_time_dimension_grain_and_date_part(
time_granularity_name=required_spec.time_granularity_name, date_part=required_spec.date_part
).spec,
output_spec=required_spec,
)
for required_spec in required_time_spine_specs
Expand Down
13 changes: 13 additions & 0 deletions metricflow/dataflow/dataflow_plan_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
Expand Down Expand Up @@ -126,6 +127,12 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
def visit_alias_specs_node(self, node: AliasSpecsNode) -> VisitorOutputT: # noqa: D102
raise NotImplementedError

@abstractmethod
def visit_offset_by_custom_granularity_node( # noqa: D102
self, node: OffsetByCustomGranularityNode
) -> VisitorOutputT:
raise NotImplementedError


class DataflowPlanNodeVisitorWithDefaultHandler(DataflowPlanNodeVisitor[VisitorOutputT], Generic[VisitorOutputT]):
"""Similar to `DataflowPlanNodeVisitor`, but with an abstract default handler that gets called for each node.
Expand Down Expand Up @@ -222,3 +229,9 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
@override
def visit_alias_specs_node(self, node: AliasSpecsNode) -> VisitorOutputT: # noqa: D102
return self._default_handler(node)

@override
def visit_offset_by_custom_granularity_node( # noqa: D102
self, node: OffsetByCustomGranularityNode
) -> VisitorOutputT:
return self._default_handler(node)
75 changes: 75 additions & 0 deletions metricflow/dataflow/nodes/offset_by_custom_granularity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from __future__ import annotations

from abc import ABC
from dataclasses import dataclass
from typing import Sequence

from dbt_semantic_interfaces.protocols.metric import MetricTimeWindow
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DisplayedProperty
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
from metricflow_semantics.visitor import VisitorOutputT

from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor


@dataclass(frozen=True, eq=False)
class OffsetByCustomGranularityNode(DataflowPlanNode, ABC):
"""For a given custom grain, offset its base grain by the requested number of custom grain periods."""

offset_window: MetricTimeWindow
required_time_spine_specs: Sequence[TimeDimensionSpec]
time_spine_node: DataflowPlanNode

def __post_init__(self) -> None: # noqa: D105
super().__post_init__()

@staticmethod
def create( # noqa: D102
time_spine_node: DataflowPlanNode,
offset_window: MetricTimeWindow,
required_time_spine_specs: Sequence[TimeDimensionSpec],
) -> OffsetByCustomGranularityNode:
return OffsetByCustomGranularityNode(
parent_nodes=(time_spine_node,),
time_spine_node=time_spine_node,
offset_window=offset_window,
required_time_spine_specs=required_time_spine_specs,
)

@classmethod
def id_prefix(cls) -> IdPrefix: # noqa: D102
return StaticIdPrefix.DATAFLOW_NODE_OFFSET_BY_CUSTOM_GRANULARITY_ID_PREFIX

def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_offset_by_custom_granularity_node(self)

@property
def description(self) -> str: # noqa: D102
return """Offset Base Granularity By Custom Granularity Period(s)"""

@property
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
return tuple(super().displayed_properties) + (
DisplayedProperty("offset_window", self.offset_window),
DisplayedProperty("required_time_spine_specs", self.required_time_spine_specs),
)

def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102
return (
isinstance(other_node, self.__class__)
and other_node.offset_window == self.offset_window
and other_node.required_time_spine_specs == self.required_time_spine_specs
)

def with_new_parents( # noqa: D102
self, new_parent_nodes: Sequence[DataflowPlanNode]
) -> OffsetByCustomGranularityNode:
assert len(new_parent_nodes) == 1
return OffsetByCustomGranularityNode(
parent_nodes=tuple(new_parent_nodes),
time_spine_node=new_parent_nodes[0],
offset_window=self.offset_window,
required_time_spine_specs=self.required_time_spine_specs,
)
6 changes: 6 additions & 0 deletions metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
Expand Down Expand Up @@ -472,6 +473,11 @@ def visit_join_to_custom_granularity_node( # noqa: D102
def visit_alias_specs_node(self, node: AliasSpecsNode) -> OptimizeBranchResult: # noqa: D102
raise NotImplementedError

def visit_offset_by_custom_granularity_node( # noqa: D102
self, node: OffsetByCustomGranularityNode
) -> OptimizeBranchResult:
raise NotImplementedError

def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranchResult:
"""Handles pushdown state propagation for the standard join node type.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
Expand Down Expand Up @@ -472,3 +473,9 @@ def visit_min_max_node(self, node: MinMaxNode) -> ComputeMetricsBranchCombinerRe
def visit_alias_specs_node(self, node: AliasSpecsNode) -> ComputeMetricsBranchCombinerResult: # noqa: D102
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_offset_by_custom_granularity_node( # noqa: D102
self, node: OffsetByCustomGranularityNode
) -> ComputeMetricsBranchCombinerResult:
self._log_visit_node_type(node)
return self._default_handler(node)
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
Expand Down Expand Up @@ -363,3 +364,9 @@ def visit_min_max_node(self, node: MinMaxNode) -> OptimizeBranchResult: # noqa:
def visit_alias_specs_node(self, node: AliasSpecsNode) -> OptimizeBranchResult: # noqa: D102
self._log_visit_node_type(node)
return self._default_base_output_handler(node)

def visit_offset_by_custom_granularity_node( # noqa: D102
self, node: OffsetByCustomGranularityNode
) -> OptimizeBranchResult:
self._log_visit_node_type(node)
return self._default_base_output_handler(node)
24 changes: 19 additions & 5 deletions metricflow/dataset/sql_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import List, Optional, Sequence, Tuple

from dbt_semantic_interfaces.references import SemanticModelReference
from dbt_semantic_interfaces.type_enums import DatePart
from metricflow_semantics.assert_one_arg import assert_exactly_one_arg_set
from metricflow_semantics.instances import EntityInstance, InstanceSet, MdoInstance, TimeDimensionInstance
from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
Expand All @@ -12,6 +13,7 @@
from metricflow_semantics.specs.entity_spec import EntitySpec
from metricflow_semantics.specs.instance_spec import InstanceSpec
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
from metricflow_semantics.sql.sql_exprs import SqlWindowFunction
from typing_extensions import override

from metricflow.dataset.dataset_classes import DataSet
Expand Down Expand Up @@ -181,18 +183,30 @@ def instance_for_column_name(self, column_name: str) -> MdoInstance:
)

def instance_from_time_dimension_grain_and_date_part(
self, time_dimension_spec: TimeDimensionSpec
self, time_granularity_name: Optional[str] = None, date_part: Optional[DatePart] = None
) -> TimeDimensionInstance:
"""Find instance in dataset that matches the grain and date part of the given time dimension spec."""
"""Find instance in dataset that matches the given grain and date part."""
for time_dimension_instance in self.instance_set.time_dimension_instances:
if (
time_dimension_instance.spec.time_granularity == time_dimension_spec.time_granularity
and time_dimension_instance.spec.date_part == time_dimension_spec.date_part
time_dimension_instance.spec.time_granularity_name == time_granularity_name
and time_dimension_instance.spec.date_part == date_part
and not time_dimension_instance.spec.window_functions
):
return time_dimension_instance

raise RuntimeError(
f"Did not find a time dimension instance with matching grain and date part for spec: {time_dimension_spec}\n"
f"Did not find a time dimension instance with grain '{time_granularity_name}' and date part {date_part}\n"
f"Instances available: {self.instance_set.time_dimension_instances}"
)

def instance_from_window_function(self, window_function: SqlWindowFunction) -> TimeDimensionInstance:
"""Find instance in dataset that matches the given window function."""
for time_dimension_instance in self.instance_set.time_dimension_instances:
if window_function in time_dimension_instance.spec.window_functions:
return time_dimension_instance

raise RuntimeError(
f"Did not find a time dimension instance with window function {window_function}.\n"
f"Instances available: {self.instance_set.time_dimension_instances}"
)

Expand Down
7 changes: 7 additions & 0 deletions metricflow/execution/dataflow_to_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
Expand Down Expand Up @@ -205,3 +206,9 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
@override
def visit_alias_specs_node(self, node: AliasSpecsNode) -> ConvertToExecutionPlanResult:
raise NotImplementedError

@override
def visit_offset_by_custom_granularity_node(
self, node: OffsetByCustomGranularityNode
) -> ConvertToExecutionPlanResult:
raise NotImplementedError
Loading

0 comments on commit e3af638

Please sign in to comment.