refactor(batch-exports): Use async producer in Redshift export (#25872)
Co-authored-by: Ross <[email protected]>
tomasfarias and rossgray authored Nov 1, 2024
1 parent 4722eec commit 0f3fb72
Showing 1 changed file with 97 additions and 23 deletions: posthog/temporal/batch_exports/redshift_batch_export.py
@@ -1,3 +1,4 @@
import asyncio
import collections.abc
import contextlib
import dataclasses
@@ -7,6 +8,7 @@

import psycopg
import pyarrow as pa
import structlog
from psycopg import sql
from temporalio import activity, workflow
from temporalio.common import RetryPolicy
@@ -26,8 +28,9 @@
default_fields,
execute_batch_export_insert_activity,
get_data_interval,
iter_model_records,
raise_on_produce_task_failure,
start_batch_export_run,
start_produce_batch_export_record_batches,
)
from posthog.temporal.batch_exports.metrics import get_rows_exported_metric
from posthog.temporal.batch_exports.postgres_batch_export import (
@@ -36,10 +39,15 @@
PostgreSQLClient,
PostgreSQLField,
)
from posthog.temporal.batch_exports.utils import JsonType, apeek_first_and_rewind, set_status_to_running_task
from posthog.temporal.batch_exports.utils import (
JsonType,
apeek_first_and_rewind,
set_status_to_running_task,
)
from posthog.temporal.common.clickhouse import get_client
from posthog.temporal.common.heartbeat import Heartbeater
from posthog.temporal.common.logger import bind_temporal_worker_logger
from posthog.temporal.common.logger import configure_temporal_worker_logger
from posthog.temporal.common.utils import BatchExportHeartbeatDetails, should_resume_from_activity_heartbeat


def remove_escaped_whitespace_recursive(value):
@@ -221,6 +229,9 @@ def get_redshift_fields_from_record_schema(
pg_schema: list[PostgreSQLField] = []

for name in record_schema.names:
if name == "_inserted_at":
continue

pa_field = record_schema.field(name)

if pa.types.is_string(pa_field.type) or isinstance(pa_field.type, JsonType):
@@ -261,11 +272,19 @@ def get_redshift_fields_from_record_schema(
return pg_schema


@dataclasses.dataclass
class RedshiftHeartbeatDetails(BatchExportHeartbeatDetails):
"""The Redshift batch export details included in every heartbeat."""

pass


async def insert_records_to_redshift(
records: collections.abc.AsyncGenerator[dict[str, typing.Any], None],
records: collections.abc.AsyncGenerator[tuple[dict[str, typing.Any], dt.datetime], None],
redshift_client: RedshiftClient,
schema: str | None,
table: str,
heartbeater: Heartbeater,
batch_size: int = 100,
use_super: bool = False,
known_super_columns: list[str] | None = None,
@@ -289,10 +308,11 @@ async def insert_records_to_redshift(
make us go OOM or exceed Redshift's SQL statement size limit (16MB). Setting this too low
can significantly affect performance due to Redshift's poor handling of INSERTs.
"""
first_record_batch, records_iterator = await apeek_first_and_rewind(records)
if first_record_batch is None:
first_value, records_iterator = await apeek_first_and_rewind(records)
if first_value is None:
return 0

first_record_batch, _inserted_at = first_value
columns = first_record_batch.keys()

if schema:
@@ -332,7 +352,7 @@ async def flush_to_redshift(batch):
# the byte size of each batch the way things are currently written. We can revisit this
# in the future if we decide it's useful enough.

async for record in records_iterator:
async for record, _inserted_at in records_iterator:
for column in columns:
if known_super_columns is not None and column in known_super_columns:
record[column] = json.dumps(record[column], ensure_ascii=False)
@@ -342,10 +362,12 @@ async def flush_to_redshift(batch):
continue

await flush_to_redshift(batch)
heartbeater.details = (str(_inserted_at),)
batch = []

if len(batch) > 0:
await flush_to_redshift(batch)
heartbeater.details = (str(_inserted_at),)

return total_rows_exported
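
The loop above flushes records in batches of batch_size and, after each flush, stores the last _inserted_at in the heartbeater's details so a retried activity can resume from that point. A minimal standalone sketch of the same pattern, assuming a hypothetical insert_in_batches function and flush callable rather than the PostHog helpers:

import collections.abc
import datetime as dt
import typing


async def insert_in_batches(
    records: collections.abc.AsyncIterator[tuple[dict[str, typing.Any], dt.datetime]],
    flush: collections.abc.Callable[[list[dict[str, typing.Any]]], collections.abc.Awaitable[None]],
    batch_size: int = 100,
) -> dt.datetime | None:
    """Flush records in fixed-size batches; return the _inserted_at of the last flushed record."""
    batch: list[dict[str, typing.Any]] = []
    pending_inserted_at: dt.datetime | None = None
    last_flushed_at: dt.datetime | None = None

    async for record, inserted_at in records:
        batch.append(record)
        pending_inserted_at = inserted_at

        if len(batch) >= batch_size:
            await flush(batch)
            # Only advance the resume marker once the batch has actually been written.
            last_flushed_at = pending_inserted_at
            batch = []

    if batch:
        await flush(batch)
        last_flushed_at = pending_inserted_at

    return last_flushed_at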

@@ -378,7 +400,9 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records
the Redshift-specific properties_data_type to indicate the type of JSON-like
fields.
"""
logger = await bind_temporal_worker_logger(team_id=inputs.team_id, destination="Redshift")
logger = await configure_temporal_worker_logger(
logger=structlog.get_logger(), team_id=inputs.team_id, destination="Redshift"
)
await logger.ainfo(
"Batch exporting range %s - %s to Redshift: %s.%s.%s",
inputs.data_interval_start or "START",
@@ -389,35 +413,72 @@
)

async with (
Heartbeater(),
Heartbeater() as heartbeater,
set_status_to_running_task(run_id=inputs.run_id, logger=logger),
get_client(team_id=inputs.team_id) as client,
):
if not await client.is_alive():
raise ConnectionError("Cannot establish connection to ClickHouse")

should_resume, details = await should_resume_from_activity_heartbeat(activity, RedshiftHeartbeatDetails, logger)

if should_resume is True and details is not None:
data_interval_start: str | None = details.last_inserted_at.isoformat()
else:
data_interval_start = inputs.data_interval_start

model: BatchExportModel | BatchExportSchema | None = None
if inputs.batch_export_schema is None and "batch_export_model" in {
field.name for field in dataclasses.fields(inputs)
}:
model = inputs.batch_export_model

if model is not None:
model_name = model.name
extra_query_parameters = model.schema["values"] if model.schema is not None else None
fields = model.schema["fields"] if model.schema is not None else None
else:
model_name = "events"
extra_query_parameters = None
fields = None
else:
model = inputs.batch_export_schema
model_name = "custom"
extra_query_parameters = model["values"] if model is not None else {}
fields = model["fields"] if model is not None else None

record_iterator = iter_model_records(
queue, produce_task = start_produce_batch_export_record_batches(
client=client,
model=model,
model_name=model_name,
is_backfill=inputs.is_backfill,
team_id=inputs.team_id,
interval_start=inputs.data_interval_start,
interval_start=data_interval_start,
interval_end=inputs.data_interval_end,
exclude_events=inputs.exclude_events,
include_events=inputs.include_events,
fields=fields,
destination_default_fields=redshift_default_fields(),
is_backfill=inputs.is_backfill,
extra_query_parameters=extra_query_parameters,
)
first_record_batch, record_iterator = await apeek_first_and_rewind(record_iterator)
if first_record_batch is None:

get_schema_task = asyncio.create_task(queue.get_schema())
await asyncio.wait(
[get_schema_task, produce_task],
return_when=asyncio.FIRST_COMPLETED,
)

# Producing finishes only after record batches have been put in the queue and the schema set.
# So either both the schema task and the produce task finished, or producing finished
# without putting anything in the queue.
if get_schema_task.done():
# In the first case, we'll land here.
# The schema is available, and the queue is not empty, so we can start the batch export.
record_batch_schema = get_schema_task.result()
else:
# In the second case, we'll land here: We finished producing without putting anything.
# Since we finished producing with an empty queue, there is nothing to batch export.
# We could have also failed, so we need to re-raise that exception to allow a retry if
# that's the case.
await raise_on_produce_task_failure(produce_task)
return 0

known_super_columns = ["properties", "set", "set_once", "person_properties"]
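
The block above races two awaitables: a task waiting for the producer to publish the record batch schema, and the producer task itself. If the schema task finishes first, there is data to export; if the producer finishes first without ever publishing a schema, either nothing was produced or the producer failed, and raise_on_produce_task_failure re-raises the failure. A small sketch of the same shape, where wait_for_schema_or_empty_producer is a hypothetical wrapper around the calls shown in the diff:

import asyncio


async def wait_for_schema_or_empty_producer(record_batch_queue, produce_task):
    """Return the schema once the producer publishes it, or None if production ends without one."""
    get_schema_task = asyncio.create_task(record_batch_queue.get_schema())

    await asyncio.wait([get_schema_task, produce_task], return_when=asyncio.FIRST_COMPLETED)

    if get_schema_task.done():
        # The producer has published a schema, so record batches are on their way.
        return get_schema_task.result()

    # The producer finished (or failed) before a schema appeared: re-raise any
    # failure, otherwise there is simply nothing to export.
    produce_task.result()
    get_schema_task.cancel()
    return None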
@@ -442,10 +503,8 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records
("timestamp", "TIMESTAMP WITH TIME ZONE"),
]
else:
column_names = [column for column in first_record_batch.schema.names if column != "_inserted_at"]
record_schema = first_record_batch.select(column_names).schema
table_fields = get_redshift_fields_from_record_schema(
record_schema, known_super_columns=known_super_columns, use_super=properties_type == "SUPER"
record_batch_schema, known_super_columns=known_super_columns, use_super=properties_type == "SUPER"
)

requires_merge = (
@@ -477,7 +536,7 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records
):
schema_columns = {field[0] for field in table_fields}

def map_to_record(row: dict) -> dict:
def map_to_record(row: dict) -> tuple[dict, dt.datetime]:
"""Map row to a record to insert to Redshift."""
record = {k: v for k, v in row.items() if k in schema_columns}

@@ -486,10 +545,24 @@ def map_to_record(row: dict) -> dict:
# TODO: We should be able to save a json.loads here.
record[column] = remove_escaped_whitespace_recursive(json.loads(record[column]))

return record
return record, row["_inserted_at"]

async def record_generator() -> (
collections.abc.AsyncGenerator[tuple[dict[str, typing.Any], dt.datetime], None]
):
while not queue.empty() or not produce_task.done():
try:
record_batch = queue.get_nowait()
except asyncio.QueueEmpty:
if produce_task.done():
await logger.adebug(
"Empty queue with no more events being produced, closing consumer loop"
)
return
else:
await asyncio.sleep(0.1)
continue

async def record_generator() -> collections.abc.AsyncGenerator[dict[str, typing.Any], None]:
async for record_batch in record_iterator:
for record in record_batch.to_pylist():
yield map_to_record(record)

@@ -498,6 +571,7 @@ async def record_generator() -> collections.abc.AsyncGenerator[dict[str, typing.
redshift_client,
inputs.schema,
redshift_stage_table if requires_merge else redshift_table,
heartbeater=heartbeater,
use_super=properties_type == "SUPER",
known_super_columns=known_super_columns,
)
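
The new record_generator consumes the producer's queue without blocking on get(): it drains the queue with get_nowait() and, when the queue is momentarily empty, exits only if the producer task is already done, otherwise it sleeps briefly and retries. A self-contained sketch of that consumer loop against a plain asyncio.Queue (illustrative only, not the batch-exports queue class):

import asyncio
import collections.abc
import typing


async def consume_queue(
    queue: asyncio.Queue, produce_task: asyncio.Task
) -> collections.abc.AsyncGenerator[typing.Any, None]:
    """Yield items until the queue is drained and the producer has finished."""
    while not queue.empty() or not produce_task.done():
        try:
            item = queue.get_nowait()
        except asyncio.QueueEmpty:
            if produce_task.done():
                # The producer is finished and the queue is drained: stop consuming.
                return
            # The producer is still running; the queue is just momentarily empty.
            await asyncio.sleep(0.1)
            continue

        yield item

In the activity itself, each item pulled this way is a pyarrow record batch whose rows are passed through map_to_record before being handed to insert_records_to_redshift.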