refactor(batch-exports): Use async producer in Redshift export #25872

Merged 5 commits on Nov 1, 2024

Changes from all commits

120 changes: 97 additions & 23 deletions posthog/temporal/batch_exports/redshift_batch_export.py
@@ -1,3 +1,4 @@
import asyncio
import collections.abc
import contextlib
import dataclasses
@@ -7,6 +8,7 @@

import psycopg
import pyarrow as pa
import structlog
from psycopg import sql
from temporalio import activity, workflow
from temporalio.common import RetryPolicy
@@ -26,8 +28,9 @@
default_fields,
execute_batch_export_insert_activity,
get_data_interval,
iter_model_records,
raise_on_produce_task_failure,
start_batch_export_run,
start_produce_batch_export_record_batches,
)
from posthog.temporal.batch_exports.metrics import get_rows_exported_metric
from posthog.temporal.batch_exports.postgres_batch_export import (
@@ -36,10 +39,15 @@
PostgreSQLClient,
PostgreSQLField,
)
from posthog.temporal.batch_exports.utils import JsonType, apeek_first_and_rewind, set_status_to_running_task
from posthog.temporal.batch_exports.utils import (
JsonType,
apeek_first_and_rewind,
set_status_to_running_task,
)
from posthog.temporal.common.clickhouse import get_client
from posthog.temporal.common.heartbeat import Heartbeater
from posthog.temporal.common.logger import bind_temporal_worker_logger
from posthog.temporal.common.logger import configure_temporal_worker_logger
from posthog.temporal.common.utils import BatchExportHeartbeatDetails, should_resume_from_activity_heartbeat


def remove_escaped_whitespace_recursive(value):
@@ -221,6 +229,9 @@ def get_redshift_fields_from_record_schema(
pg_schema: list[PostgreSQLField] = []

for name in record_schema.names:
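# Skip _inserted_at: it is used to track export progress and is not part of the destination table schema.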
if name == "_inserted_at":
continue

pa_field = record_schema.field(name)

if pa.types.is_string(pa_field.type) or isinstance(pa_field.type, JsonType):
@@ -261,11 +272,19 @@ def get_redshift_fields_from_record_schema(
return pg_schema


@dataclasses.dataclass
class RedshiftHeartbeatDetails(BatchExportHeartbeatDetails):
"""The Redshift batch export details included in every heartbeat."""

pass


async def insert_records_to_redshift(
records: collections.abc.AsyncGenerator[dict[str, typing.Any], None],
records: collections.abc.AsyncGenerator[tuple[dict[str, typing.Any], dt.datetime], None],
redshift_client: RedshiftClient,
schema: str | None,
table: str,
heartbeater: Heartbeater,
batch_size: int = 100,
use_super: bool = False,
known_super_columns: list[str] | None = None,
@@ -289,10 +308,11 @@ async def insert_records_to_redshift(
make us go OOM or exceed Redshift's SQL statement size limit (16MB). Setting this too low
can significantly affect performance due to Redshift's poor handling of INSERTs.
"""
first_record_batch, records_iterator = await apeek_first_and_rewind(records)
if first_record_batch is None:
first_value, records_iterator = await apeek_first_and_rewind(records)
if first_value is None:
return 0

first_record_batch, _inserted_at = first_value
columns = first_record_batch.keys()

if schema:
@@ -332,7 +352,7 @@ async def flush_to_redshift(batch):
# the byte size of each batch the way things are currently written. We can revisit this
# in the future if we decide it's useful enough.

async for record in records_iterator:
async for record, _inserted_at in records_iterator:
for column in columns:
if known_super_columns is not None and column in known_super_columns:
record[column] = json.dumps(record[column], ensure_ascii=False)
@@ -342,10 +362,12 @@ async def flush_to_redshift(batch):
continue

await flush_to_redshift(batch)
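# Record the timestamp of the last flushed row in the heartbeat so a retried activity can resume from it.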
heartbeater.details = (str(_inserted_at),)
batch = []

if len(batch) > 0:
await flush_to_redshift(batch)
heartbeater.details = (str(_inserted_at),)

return total_rows_exported

@@ -378,7 +400,9 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records
the Redshift-specific properties_data_type to indicate the type of JSON-like
fields.
"""
logger = await bind_temporal_worker_logger(team_id=inputs.team_id, destination="Redshift")
logger = await configure_temporal_worker_logger(
logger=structlog.get_logger(), team_id=inputs.team_id, destination="Redshift"
)
await logger.ainfo(
"Batch exporting range %s - %s to Redshift: %s.%s.%s",
inputs.data_interval_start or "START",
@@ -389,35 +413,72 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records
)

async with (
Heartbeater(),
Heartbeater() as heartbeater,
set_status_to_running_task(run_id=inputs.run_id, logger=logger),
get_client(team_id=inputs.team_id) as client,
):
if not await client.is_alive():
raise ConnectionError("Cannot establish connection to ClickHouse")

should_resume, details = await should_resume_from_activity_heartbeat(activity, RedshiftHeartbeatDetails, logger)

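# If a previous attempt reported progress via heartbeat, resume from the last inserted timestamp instead of the start of the interval.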
if should_resume is True and details is not None:
data_interval_start: str | None = details.last_inserted_at.isoformat()
else:
data_interval_start = inputs.data_interval_start

model: BatchExportModel | BatchExportSchema | None = None
if inputs.batch_export_schema is None and "batch_export_model" in {
field.name for field in dataclasses.fields(inputs)
}:
model = inputs.batch_export_model

if model is not None:
model_name = model.name
extra_query_parameters = model.schema["values"] if model.schema is not None else None
fields = model.schema["fields"] if model.schema is not None else None
else:
model_name = "events"
extra_query_parameters = None
fields = None
else:
model = inputs.batch_export_schema
model_name = "custom"
extra_query_parameters = model["values"] if model is not None else {}
fields = model["fields"] if model is not None else None

record_iterator = iter_model_records(
queue, produce_task = start_produce_batch_export_record_batches(
client=client,
model=model,
model_name=model_name,
is_backfill=inputs.is_backfill,
team_id=inputs.team_id,
interval_start=inputs.data_interval_start,
interval_start=data_interval_start,
interval_end=inputs.data_interval_end,
exclude_events=inputs.exclude_events,
include_events=inputs.include_events,
fields=fields,
destination_default_fields=redshift_default_fields(),
is_backfill=inputs.is_backfill,
extra_query_parameters=extra_query_parameters,
)
first_record_batch, record_iterator = await apeek_first_and_rewind(record_iterator)
if first_record_batch is None:

get_schema_task = asyncio.create_task(queue.get_schema())
await asyncio.wait(
[get_schema_task, produce_task],
return_when=asyncio.FIRST_COMPLETED,
)

# The producer only finishes after putting record batches in the queue and setting the schema.
# So, either both producing and setting the schema finished, or producing finished without
# putting anything in the queue.
if get_schema_task.done():
# In the first case, we'll land here.
# The schema is available, and the queue is not empty, so we can start the batch export.
record_batch_schema = get_schema_task.result()
else:
# In the second case, we'll land here: We finished producing without putting anything.
# Since we finished producing with an empty queue, there is nothing to batch export.
# We could have also failed, so we need to re-raise that exception to allow a retry if
# that's the case.
await raise_on_produce_task_failure(produce_task)
return 0

known_super_columns = ["properties", "set", "set_once", "person_properties"]
@@ -442,10 +503,8 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records
("timestamp", "TIMESTAMP WITH TIME ZONE"),
]
else:
column_names = [column for column in first_record_batch.schema.names if column != "_inserted_at"]
record_schema = first_record_batch.select(column_names).schema
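# Derive Redshift column types from the record batch schema reported by the producer.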
table_fields = get_redshift_fields_from_record_schema(
record_schema, known_super_columns=known_super_columns, use_super=properties_type == "SUPER"
record_batch_schema, known_super_columns=known_super_columns, use_super=properties_type == "SUPER"
)

requires_merge = (
@@ -477,7 +536,7 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records
):
schema_columns = {field[0] for field in table_fields}

def map_to_record(row: dict) -> dict:
def map_to_record(row: dict) -> tuple[dict, dt.datetime]:
"""Map row to a record to insert to Redshift."""
record = {k: v for k, v in row.items() if k in schema_columns}

@@ -486,10 +545,24 @@ def map_to_record(row: dict) -> dict:
# TODO: We should be able to save a json.loads here.
record[column] = remove_escaped_whitespace_recursive(json.loads(record[column]))

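# Return the row's _inserted_at alongside the record so the writer can report progress in heartbeats.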
return record
return record, row["_inserted_at"]

async def record_generator() -> (
collections.abc.AsyncGenerator[tuple[dict[str, typing.Any], dt.datetime], None]
):
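# Consume record batches from the queue until the producer is done and the queue has been drained.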
while not queue.empty() or not produce_task.done():
try:
record_batch = queue.get_nowait()
except asyncio.QueueEmpty:
if produce_task.done():
await logger.adebug(
"Empty queue with no more events being produced, closing consumer loop"
)
return
else:
await asyncio.sleep(0.1)
continue

async def record_generator() -> collections.abc.AsyncGenerator[dict[str, typing.Any], None]:
async for record_batch in record_iterator:
for record in record_batch.to_pylist():
yield map_to_record(record)

@@ -498,6 +571,7 @@ async def record_generator() -> collections.abc.AsyncGenerator[dict[str, typing.
redshift_client,
inputs.schema,
redshift_stage_table if requires_merge else redshift_table,
heartbeater=heartbeater,
use_super=properties_type == "SUPER",
known_super_columns=known_super_columns,
)