Skip to content

Commit

Permalink
Add new dataflow plan nodes for custom offset windows
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Jan 22, 2025
1 parent 6f6b5bf commit 53f6507
Show file tree
Hide file tree
Showing 13 changed files with 590 additions and 8 deletions.
2 changes: 2 additions & 0 deletions metricflow-semantics/metricflow_semantics/dag/id_prefix.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class StaticIdPrefix(IdPrefix, Enum, metaclass=EnumMetaClassHelper):
DATAFLOW_NODE_JOIN_CONVERSION_EVENTS_PREFIX = "jce"
DATAFLOW_NODE_WINDOW_REAGGREGATION_ID_PREFIX = "wr"
DATAFLOW_NODE_ALIAS_SPECS_ID_PREFIX = "as"
DATAFLOW_NODE_CUSTOM_GRANULARITY_BOUNDS_ID_PREFIX = "cgb"
DATAFLOW_NODE_OFFSET_BY_CUSTOM_GRANULARITY_ID_PREFIX = "obcg"

SQL_EXPR_COLUMN_REFERENCE_ID_PREFIX = "cr"
SQL_EXPR_COMPARISON_ID_PREFIX = "cmp"
Expand Down
4 changes: 3 additions & 1 deletion metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1900,7 +1900,9 @@ def _build_time_spine_node(
parent_node=read_node,
change_specs=tuple(
SpecToAlias(
input_spec=time_spine_data_set.instance_from_time_dimension_grain_and_date_part(required_spec).spec,
input_spec=time_spine_data_set.instance_from_time_dimension_grain_and_date_part(
time_granularity_name=required_spec.time_granularity_name, date_part=required_spec.date_part
).spec,
output_spec=required_spec,
)
for required_spec in required_time_spine_specs
Expand Down
22 changes: 22 additions & 0 deletions metricflow/dataflow/dataflow_plan_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
Expand All @@ -23,6 +24,7 @@
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
Expand Down Expand Up @@ -126,6 +128,16 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
def visit_alias_specs_node(self, node: AliasSpecsNode) -> VisitorOutputT: # noqa: D102
raise NotImplementedError

@abstractmethod
def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> VisitorOutputT: # noqa: D102
raise NotImplementedError

@abstractmethod
def visit_offset_by_custom_granularity_node( # noqa: D102
self, node: OffsetByCustomGranularityNode
) -> VisitorOutputT:
raise NotImplementedError


class DataflowPlanNodeVisitorWithDefaultHandler(DataflowPlanNodeVisitor[VisitorOutputT], Generic[VisitorOutputT]):
"""Similar to `DataflowPlanNodeVisitor`, but with an abstract default handler that gets called for each node.
Expand Down Expand Up @@ -222,3 +234,13 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
@override
def visit_alias_specs_node(self, node: AliasSpecsNode) -> VisitorOutputT: # noqa: D102
return self._default_handler(node)

@override
def visit_custom_granularity_bounds_node(self, node: CustomGranularityBoundsNode) -> VisitorOutputT: # noqa: D102
return self._default_handler(node)

@override
def visit_offset_by_custom_granularity_node( # noqa: D102
self, node: OffsetByCustomGranularityNode
) -> VisitorOutputT:
return self._default_handler(node)
64 changes: 64 additions & 0 deletions metricflow/dataflow/nodes/custom_granularity_bounds.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from __future__ import annotations

from abc import ABC
from dataclasses import dataclass
from typing import Sequence

from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DisplayedProperty
from metricflow_semantics.visitor import VisitorOutputT

from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor


@dataclass(frozen=True, eq=False)
class CustomGranularityBoundsNode(DataflowPlanNode, ABC):
"""Calculate the start and end of a custom granularity period and each row number within that period."""

custom_granularity_name: str

def __post_init__(self) -> None: # noqa: D105
super().__post_init__()
assert len(self.parent_nodes) == 1

@staticmethod
def create( # noqa: D102
parent_node: DataflowPlanNode, custom_granularity_name: str
) -> CustomGranularityBoundsNode:
return CustomGranularityBoundsNode(parent_nodes=(parent_node,), custom_granularity_name=custom_granularity_name)

@classmethod
def id_prefix(cls) -> IdPrefix: # noqa: D102
return StaticIdPrefix.DATAFLOW_NODE_CUSTOM_GRANULARITY_BOUNDS_ID_PREFIX

def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_custom_granularity_bounds_node(self)

@property
def description(self) -> str: # noqa: D102
return """Calculate Custom Granularity Bounds"""

@property
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
return tuple(super().displayed_properties) + (
DisplayedProperty("custom_granularity_name", self.custom_granularity_name),
)

@property
def parent_node(self) -> DataflowPlanNode: # noqa: D102
return self.parent_nodes[0]

def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102
return (
isinstance(other_node, self.__class__)
and other_node.custom_granularity_name == self.custom_granularity_name
)

def with_new_parents( # noqa: D102
self, new_parent_nodes: Sequence[DataflowPlanNode]
) -> CustomGranularityBoundsNode:
assert len(new_parent_nodes) == 1
return CustomGranularityBoundsNode.create(
parent_node=new_parent_nodes[0], custom_granularity_name=self.custom_granularity_name
)
95 changes: 95 additions & 0 deletions metricflow/dataflow/nodes/offset_by_custom_granularity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from __future__ import annotations

from abc import ABC
from dataclasses import dataclass
from typing import Optional, Sequence

from dbt_semantic_interfaces.protocols.metric import MetricTimeWindow
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DisplayedProperty
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
from metricflow_semantics.visitor import VisitorOutputT

from metricflow.dataflow.dataflow_plan import DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitor
from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode


@dataclass(frozen=True, eq=False)
class OffsetByCustomGranularityNode(DataflowPlanNode, ABC):
"""For a given custom grain, offset its base grain by the requested number of custom grain periods.
Only accepts CustomGranularityBoundsNode as parent node.
"""

offset_window: MetricTimeWindow
required_time_spine_specs: Sequence[TimeDimensionSpec]
custom_granularity_bounds_node: CustomGranularityBoundsNode
filter_elements_node: FilterElementsNode

def __post_init__(self) -> None: # noqa: D105
super().__post_init__()

@staticmethod
def create( # noqa: D102
custom_granularity_bounds_node: CustomGranularityBoundsNode,
filter_elements_node: FilterElementsNode,
offset_window: MetricTimeWindow,
required_time_spine_specs: Sequence[TimeDimensionSpec],
) -> OffsetByCustomGranularityNode:
return OffsetByCustomGranularityNode(
parent_nodes=(custom_granularity_bounds_node, filter_elements_node),
custom_granularity_bounds_node=custom_granularity_bounds_node,
filter_elements_node=filter_elements_node,
offset_window=offset_window,
required_time_spine_specs=required_time_spine_specs,
)

@classmethod
def id_prefix(cls) -> IdPrefix: # noqa: D102
return StaticIdPrefix.DATAFLOW_NODE_OFFSET_BY_CUSTOM_GRANULARITY_ID_PREFIX

def accept(self, visitor: DataflowPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_offset_by_custom_granularity_node(self)

@property
def description(self) -> str: # noqa: D102
return """Offset Base Granularity By Custom Granularity Period(s)"""

@property
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
return tuple(super().displayed_properties) + (
DisplayedProperty("offset_window", self.offset_window),
DisplayedProperty("required_time_spine_specs", self.required_time_spine_specs),
)

def functionally_identical(self, other_node: DataflowPlanNode) -> bool: # noqa: D102
return (
isinstance(other_node, self.__class__)
and other_node.offset_window == self.offset_window
and other_node.required_time_spine_specs == self.required_time_spine_specs
)

def with_new_parents( # noqa: D102
self, new_parent_nodes: Sequence[DataflowPlanNode]
) -> OffsetByCustomGranularityNode:
custom_granularity_bounds_node: Optional[CustomGranularityBoundsNode] = None
filter_elements_node: Optional[FilterElementsNode] = None
for parent_node in new_parent_nodes:
if isinstance(parent_node, CustomGranularityBoundsNode):
custom_granularity_bounds_node = parent_node
elif isinstance(parent_node, FilterElementsNode):
filter_elements_node = parent_node
assert custom_granularity_bounds_node and filter_elements_node, (
"Can't rewrite OffsetByCustomGranularityNode because the node requires a CustomGranularityBoundsNode and a "
f"FilterElementsNode as parents. Instead, got: {new_parent_nodes}"
)

return OffsetByCustomGranularityNode(
parent_nodes=tuple(new_parent_nodes),
custom_granularity_bounds_node=custom_granularity_bounds_node,
filter_elements_node=filter_elements_node,
offset_window=self.offset_window,
required_time_spine_specs=self.required_time_spine_specs,
)
12 changes: 12 additions & 0 deletions metricflow/dataflow/optimizer/predicate_pushdown_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
Expand All @@ -31,6 +32,7 @@
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
Expand Down Expand Up @@ -472,6 +474,16 @@ def visit_join_to_custom_granularity_node( # noqa: D102
def visit_alias_specs_node(self, node: AliasSpecsNode) -> OptimizeBranchResult: # noqa: D102
raise NotImplementedError

def visit_custom_granularity_bounds_node( # noqa: D102
self, node: CustomGranularityBoundsNode
) -> OptimizeBranchResult:
raise NotImplementedError

def visit_offset_by_custom_granularity_node( # noqa: D102
self, node: OffsetByCustomGranularityNode
) -> OptimizeBranchResult:
raise NotImplementedError

def visit_join_on_entities_node(self, node: JoinOnEntitiesNode) -> OptimizeBranchResult:
"""Handles pushdown state propagation for the standard join node type.
Expand Down
14 changes: 14 additions & 0 deletions metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
Expand All @@ -25,6 +26,7 @@
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
Expand Down Expand Up @@ -472,3 +474,15 @@ def visit_min_max_node(self, node: MinMaxNode) -> ComputeMetricsBranchCombinerRe
def visit_alias_specs_node(self, node: AliasSpecsNode) -> ComputeMetricsBranchCombinerResult: # noqa: D102
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_custom_granularity_bounds_node( # noqa: D102
self, node: CustomGranularityBoundsNode
) -> ComputeMetricsBranchCombinerResult:
self._log_visit_node_type(node)
return self._default_handler(node)

def visit_offset_by_custom_granularity_node( # noqa: D102
self, node: OffsetByCustomGranularityNode
) -> ComputeMetricsBranchCombinerResult:
self._log_visit_node_type(node)
return self._default_handler(node)
14 changes: 14 additions & 0 deletions metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from metricflow.dataflow.nodes.combine_aggregated_outputs import CombineAggregatedOutputsNode
from metricflow.dataflow.nodes.compute_metrics import ComputeMetricsNode
from metricflow.dataflow.nodes.constrain_time import ConstrainTimeRangeNode
from metricflow.dataflow.nodes.custom_granularity_bounds import CustomGranularityBoundsNode
from metricflow.dataflow.nodes.filter_elements import FilterElementsNode
from metricflow.dataflow.nodes.join_conversion_events import JoinConversionEventsNode
from metricflow.dataflow.nodes.join_over_time import JoinOverTimeRangeNode
Expand All @@ -27,6 +28,7 @@
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataflow.nodes.metric_time_transform import MetricTimeDimensionTransformNode
from metricflow.dataflow.nodes.min_max import MinMaxNode
from metricflow.dataflow.nodes.offset_by_custom_granularity import OffsetByCustomGranularityNode
from metricflow.dataflow.nodes.order_by_limit import OrderByLimitNode
from metricflow.dataflow.nodes.read_sql_source import ReadSqlSourceNode
from metricflow.dataflow.nodes.semi_additive_join import SemiAdditiveJoinNode
Expand Down Expand Up @@ -363,3 +365,15 @@ def visit_min_max_node(self, node: MinMaxNode) -> OptimizeBranchResult: # noqa:
def visit_alias_specs_node(self, node: AliasSpecsNode) -> OptimizeBranchResult: # noqa: D102
self._log_visit_node_type(node)
return self._default_base_output_handler(node)

def visit_custom_granularity_bounds_node( # noqa: D102
self, node: CustomGranularityBoundsNode
) -> OptimizeBranchResult:
self._log_visit_node_type(node)
return self._default_base_output_handler(node)

def visit_offset_by_custom_granularity_node( # noqa: D102
self, node: OffsetByCustomGranularityNode
) -> OptimizeBranchResult:
self._log_visit_node_type(node)
return self._default_base_output_handler(node)
24 changes: 19 additions & 5 deletions metricflow/dataset/sql_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import List, Optional, Sequence, Tuple

from dbt_semantic_interfaces.references import SemanticModelReference
from dbt_semantic_interfaces.type_enums import DatePart
from metricflow_semantics.assert_one_arg import assert_exactly_one_arg_set
from metricflow_semantics.instances import EntityInstance, InstanceSet, MdoInstance, TimeDimensionInstance
from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
Expand All @@ -12,6 +13,7 @@
from metricflow_semantics.specs.entity_spec import EntitySpec
from metricflow_semantics.specs.instance_spec import InstanceSpec
from metricflow_semantics.specs.time_dimension_spec import TimeDimensionSpec
from metricflow_semantics.sql.sql_exprs import SqlWindowFunction
from typing_extensions import override

from metricflow.dataset.dataset_classes import DataSet
Expand Down Expand Up @@ -181,18 +183,30 @@ def instance_for_column_name(self, column_name: str) -> MdoInstance:
)

def instance_from_time_dimension_grain_and_date_part(
self, time_dimension_spec: TimeDimensionSpec
self, time_granularity_name: Optional[str] = None, date_part: Optional[DatePart] = None
) -> TimeDimensionInstance:
"""Find instance in dataset that matches the grain and date part of the given time dimension spec."""
"""Find instance in dataset that matches the given grain and date part."""
for time_dimension_instance in self.instance_set.time_dimension_instances:
if (
time_dimension_instance.spec.time_granularity == time_dimension_spec.time_granularity
and time_dimension_instance.spec.date_part == time_dimension_spec.date_part
time_dimension_instance.spec.time_granularity_name == time_granularity_name
and time_dimension_instance.spec.date_part == date_part
and time_dimension_instance.spec.window_function is None
):
return time_dimension_instance

raise RuntimeError(
f"Did not find a time dimension instance with matching grain and date part for spec: {time_dimension_spec}\n"
f"Did not find a time dimension instance with grain '{time_granularity_name}' and date part {date_part}\n"
f"Instances available: {self.instance_set.time_dimension_instances}"
)

def instance_from_window_function(self, window_function: SqlWindowFunction) -> TimeDimensionInstance:
"""Find instance in dataset that matches the given window function."""
for time_dimension_instance in self.instance_set.time_dimension_instances:
if time_dimension_instance.spec.window_function is window_function:
return time_dimension_instance

raise RuntimeError(
f"Did not find a time dimension instance with window function {window_function}.\n"
f"Instances available: {self.instance_set.time_dimension_instances}"
)

Expand Down
Loading

0 comments on commit 53f6507

Please sign in to comment.