Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add EventTimeFilter and BaseRelation.render_event_time_filtered #285

Merged
merged 10 commits into from
Sep 10, 2024
7 changes: 7 additions & 0 deletions .changes/unreleased/Features-20240905-180956.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
kind: Features
body: Add EventTimeFilter to BaseRelation, which renders a filtered relation when
  start or end is set
time: 2024-09-05T18:09:56.159385-04:00
custom:
  Author: 'michelleark QMalcolm'
  Issue: "294"
54 changes: 51 additions & 3 deletions dbt/adapters/base/relation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Hashable
from dataclasses import dataclass, field
from datetime import datetime
from typing import (
Any,
Dict,
Expand Down Expand Up @@ -36,6 +37,13 @@
SerializableIterable = Union[Tuple, FrozenSet]


@dataclass(frozen=True, eq=False, repr=False)
class EventTimeFilter(FakeAPIObject, Hashable):
    """Optional time window applied to a relation's event-time column.

    Consumed by ``BaseRelation._render_event_time_filtered``, which turns the
    bounds into a SQL condition. Either bound may be left as ``None``; when
    both are ``None`` no condition is rendered.
    """

    # Column the bounds are compared against; interpolated into the SQL as-is.
    field_name: str
    # Lower bound, rendered with `>=`. None means no lower bound.
    start: Optional[datetime] = None
    # Upper bound, rendered with `<`. None means no upper bound.
    end: Optional[datetime] = None


@dataclass(frozen=True, eq=False, repr=False)
class BaseRelation(FakeAPIObject, Hashable):
path: Path
Expand All @@ -47,6 +55,7 @@ class BaseRelation(FakeAPIObject, Hashable):
quote_policy: Policy = field(default_factory=lambda: Policy())
dbt_created: bool = False
limit: Optional[int] = None
event_time_filter: Optional[EventTimeFilter] = None
require_alias: bool = (
True # used to govern whether to add an alias when render_limited is called
)
Expand Down Expand Up @@ -208,14 +217,19 @@ def render(self) -> str:
# if there is nothing set, this will return the empty string.
return ".".join(part for _, part in self._render_iterator() if part is not None)

def _render_subquery_alias(self, namespace: str) -> str:
    """Return the alias suffix to append to a rendered subquery.

    Some databases require an alias for subqueries (postgres, mysql); for all
    others we want to avoid adding an alias, as it has the potential to
    introduce issues with the query if the user also defines an alias.

    Args:
        namespace: Short tag identifying the kind of subquery (for example
            ``"limit"`` or ``"et_filter"``). It is embedded in the alias so
            that differently-wrapped subqueries get distinct names.

    Returns:
        ``" _dbt_{namespace}_subq_{table}"`` (note the leading space) when
        ``self.require_alias`` is truthy, otherwise the empty string.
    """
    if self.require_alias:
        return f" _dbt_{namespace}_subq_{self.table}"
    return ""

def _render_limited_alias(self) -> str:
    """Return the subquery alias suffix used when rendering a limited relation.

    Thin wrapper that delegates to ``_render_subquery_alias`` with the
    ``"limit"`` namespace.
    """
    limit_alias = self._render_subquery_alias(namespace="limit")
    return limit_alias

def render_limited(self) -> str:
rendered = self.render()
if self.limit is None:
Expand All @@ -225,6 +239,31 @@ def render_limited(self) -> str:
else:
return f"(select * from {rendered} limit {self.limit}){self._render_limited_alias()}"

def render_event_time_filtered(self, rendered: Optional[str] = None) -> str:
    """Render the relation wrapped in a subquery restricted by its event-time filter.

    Args:
        rendered: Already-rendered SQL to wrap (for example the output of
            ``render_limited``). When omitted or empty, ``self.render()`` is
            used instead.

    Returns:
        ``rendered`` unchanged when no event-time filter is set or the filter
        yields no condition; otherwise
        ``(select * from {rendered} where {condition})`` followed by the
        subquery alias from ``_render_subquery_alias``.
    """
    rendered = rendered or self.render()
    if self.event_time_filter is None:
        return rendered

    # Named `condition` rather than `filter` so the builtin is not shadowed.
    condition = self._render_event_time_filtered(self.event_time_filter)
    if not condition:
        # The filter produced no condition, so there is nothing to wrap.
        return rendered

    alias = self._render_subquery_alias(namespace="et_filter")
    return f"(select * from {rendered} where {condition}){alias}"

def _render_event_time_filtered(self, event_time_filter: EventTimeFilter) -> str:
    """Build the SQL condition for an event-time filter.

    The lower bound is rendered with ``>=`` and the upper bound with ``<``.
    Each bound is interpolated via its ``str()`` form inside single quotes,
    and the two conditions are combined with ``and`` when both are present.

    Args:
        event_time_filter: The filter whose ``field_name``, ``start`` and
            ``end`` are rendered.

    Returns:
        The condition text, or "" if start and end are both None.
    """
    field_name = event_time_filter.field_name

    # Collect one clause per bound that is set; joining handles the
    # start-only, end-only and both-bounds cases uniformly.
    clauses = []
    if event_time_filter.start is not None:
        clauses.append(f"{field_name} >= '{event_time_filter.start}'")
    if event_time_filter.end is not None:
        clauses.append(f"{field_name} < '{event_time_filter.end}'")

    return " and ".join(clauses)

def quoted(self, identifier):
return "{quote_char}{identifier}{quote_char}".format(
quote_char=self.quote_character,
Expand All @@ -240,6 +279,7 @@ def create_ephemeral_from(
cls: Type[Self],
relation_config: RelationConfig,
limit: Optional[int] = None,
event_time_filter: Optional[EventTimeFilter] = None,
MichelleArk marked this conversation as resolved.
Show resolved Hide resolved
) -> Self:
# Note that ephemeral models are based on the identifier, which will
# point to the model's alias if one exists and otherwise fall back to
Expand All @@ -250,6 +290,7 @@ def create_ephemeral_from(
type=cls.CTE,
identifier=identifier,
limit=limit,
event_time_filter=event_time_filter,
).quote(identifier=False)

@classmethod
Expand Down Expand Up @@ -315,7 +356,14 @@ def __hash__(self) -> int:
return hash(self.render())

def __str__(self) -> str:
    """Render the relation as SQL, applying the limit and event-time filter when set.

    Returns:
        ``self.render()`` (or ``self.render_limited()`` when ``limit`` is not
        ``None``), passed through ``render_event_time_filtered`` when an
        event-time filter is present.
    """
    rendered = self.render() if self.limit is None else self.render_limited()

    # Limited subquery is wrapped by the event time filter subquery, and not the other way around.
    # This is because in the context of resolving limited refs, we care more about performance than reliably producing a sample of a certain size.
    if self.event_time_filter:
        rendered = self.render_event_time_filtered(rendered)

    return rendered

@property
def database(self) -> Optional[str]:
Expand Down
77 changes: 76 additions & 1 deletion tests/unit/test_relation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from dataclasses import dataclass, replace

from datetime import datetime
import pytest

from dbt.adapters.base import BaseRelation
from dbt.adapters.base.relation import EventTimeFilter
from dbt.adapters.contracts.relation import RelationType


Expand Down Expand Up @@ -81,6 +82,80 @@ def test_render_limited(limit, require_alias, expected_result):
assert str(my_relation) == expected_result


@pytest.mark.parametrize(
    "event_time_filter,require_alias,expected_result",
    [
        # No filter, or a filter with neither bound: the bare relation is rendered.
        (None, False, '"test_database"."test_schema"."test_identifier"'),
        (
            EventTimeFilter(field_name="column"),
            False,
            '"test_database"."test_schema"."test_identifier"',
        ),
        (None, True, '"test_database"."test_schema"."test_identifier"'),
        (
            EventTimeFilter(field_name="column"),
            True,
            '"test_database"."test_schema"."test_identifier"',
        ),
        # Start bound only, without and with the subquery alias.
        (
            EventTimeFilter(field_name="column", start=datetime(2020, 1, 1)),
            False,
            """(select * from "test_database"."test_schema"."test_identifier" where column >= '2020-01-01 00:00:00')""",
        ),
        (
            EventTimeFilter(field_name="column", start=datetime(2020, 1, 1)),
            True,
            """(select * from "test_database"."test_schema"."test_identifier" where column >= '2020-01-01 00:00:00') _dbt_et_filter_subq_test_identifier""",
        ),
        # End bound only.
        (
            EventTimeFilter(field_name="column", end=datetime(2020, 1, 1)),
            False,
            """(select * from "test_database"."test_schema"."test_identifier" where column < '2020-01-01 00:00:00')""",
        ),
        # Both bounds.
        (
            EventTimeFilter(
                field_name="column",
                start=datetime(2020, 1, 1),
                end=datetime(2020, 1, 2),
            ),
            False,
            """(select * from "test_database"."test_schema"."test_identifier" where column >= '2020-01-01 00:00:00' and column < '2020-01-02 00:00:00')""",
        ),
    ],
)
def test_render_event_time_filtered(event_time_filter, require_alias, expected_result):
    """Check event-time filter rendering via both the explicit method and ``str()``."""
    relation = BaseRelation.create(
        database="test_database",
        schema="test_schema",
        identifier="test_identifier",
        event_time_filter=event_time_filter,
        require_alias=require_alias,
    )

    assert relation.render_event_time_filtered() == expected_result
    assert str(relation) == expected_result


def test_render_event_time_filtered_and_limited():
    """Check that the event-time filter subquery wraps the limited subquery."""
    one_day_window = EventTimeFilter(
        field_name="column",
        start=datetime(2020, 1, 1),
        end=datetime(2020, 1, 2),
    )
    relation = BaseRelation.create(
        database="test_database",
        schema="test_schema",
        identifier="test_identifier",
        event_time_filter=one_day_window,
        limit=0,
        require_alias=False,
    )
    expected_result = """(select * from (select * from "test_database"."test_schema"."test_identifier" where false limit 0) where column >= '2020-01-01 00:00:00' and column < '2020-01-02 00:00:00')"""

    # The limited rendering is passed in explicitly so the filter wraps it.
    limited_sql = relation.render_limited()
    assert relation.render_event_time_filtered(limited_sql) == expected_result
    assert str(relation) == expected_result


def test_create_ephemeral_from_uses_identifier():
@dataclass
class Node:
Expand Down
Loading