diff --git a/posthog/temporal/batch_exports/redshift_batch_export.py b/posthog/temporal/batch_exports/redshift_batch_export.py
index e45e4df5cbc15..9a2ad891d2e1b 100644
--- a/posthog/temporal/batch_exports/redshift_batch_export.py
+++ b/posthog/temporal/batch_exports/redshift_batch_export.py
@@ -1,3 +1,4 @@
+import asyncio
 import collections.abc
 import contextlib
 import dataclasses
@@ -7,6 +8,7 @@
 
 import psycopg
 import pyarrow as pa
+import structlog
 from psycopg import sql
 from temporalio import activity, workflow
 from temporalio.common import RetryPolicy
@@ -26,8 +28,9 @@
     default_fields,
     execute_batch_export_insert_activity,
     get_data_interval,
-    iter_model_records,
+    raise_on_produce_task_failure,
     start_batch_export_run,
+    start_produce_batch_export_record_batches,
 )
 from posthog.temporal.batch_exports.metrics import get_rows_exported_metric
 from posthog.temporal.batch_exports.postgres_batch_export import (
@@ -36,10 +39,15 @@
     PostgreSQLClient,
     PostgreSQLField,
 )
-from posthog.temporal.batch_exports.utils import JsonType, apeek_first_and_rewind, set_status_to_running_task
+from posthog.temporal.batch_exports.utils import (
+    JsonType,
+    apeek_first_and_rewind,
+    set_status_to_running_task,
+)
 from posthog.temporal.common.clickhouse import get_client
 from posthog.temporal.common.heartbeat import Heartbeater
-from posthog.temporal.common.logger import bind_temporal_worker_logger
+from posthog.temporal.common.logger import configure_temporal_worker_logger
+from posthog.temporal.common.utils import BatchExportHeartbeatDetails, should_resume_from_activity_heartbeat
 
 
 def remove_escaped_whitespace_recursive(value):
@@ -221,6 +229,9 @@ def get_redshift_fields_from_record_schema(
     pg_schema: list[PostgreSQLField] = []
 
     for name in record_schema.names:
+        if name == "_inserted_at":
+            continue
+
         pa_field = record_schema.field(name)
 
         if pa.types.is_string(pa_field.type) or isinstance(pa_field.type, JsonType):
@@ -261,11 +272,19 @@ def get_redshift_fields_from_record_schema(
     return pg_schema
 
 
+@dataclasses.dataclass
+class RedshiftHeartbeatDetails(BatchExportHeartbeatDetails):
+    """The Redshift batch export details included in every heartbeat."""
+
+    pass
+
+
 async def insert_records_to_redshift(
-    records: collections.abc.AsyncGenerator[dict[str, typing.Any], None],
+    records: collections.abc.AsyncGenerator[tuple[dict[str, typing.Any], dt.datetime], None],
     redshift_client: RedshiftClient,
     schema: str | None,
     table: str,
+    heartbeater: Heartbeater,
     batch_size: int = 100,
     use_super: bool = False,
     known_super_columns: list[str] | None = None,
@@ -289,10 +308,11 @@ async def insert_records_to_redshift(
             make us go OOM or exceed Redshift's SQL statement size limit (16MB). Setting this too low
            can significantly affect performance due to Redshift's poor handling of INSERTs.
     """
-    first_record_batch, records_iterator = await apeek_first_and_rewind(records)
-    if first_record_batch is None:
+    first_value, records_iterator = await apeek_first_and_rewind(records)
+    if first_value is None:
         return 0
 
+    first_record_batch, _inserted_at = first_value
     columns = first_record_batch.keys()
 
     if schema:
@@ -332,7 +352,7 @@ async def flush_to_redshift(batch):
     # the byte size of each batch the way things are currently written. We can revisit this
     # in the future if we decide it's useful enough.
 
-    async for record in records_iterator:
+    async for record, _inserted_at in records_iterator:
         for column in columns:
             if known_super_columns is not None and column in known_super_columns:
                 record[column] = json.dumps(record[column], ensure_ascii=False)
@@ -342,10 +362,12 @@ async def flush_to_redshift(batch):
             continue
 
         await flush_to_redshift(batch)
+        heartbeater.details = (str(_inserted_at),)
         batch = []
 
     if len(batch) > 0:
         await flush_to_redshift(batch)
+        heartbeater.details = (str(_inserted_at),)
 
     return total_rows_exported
 
@@ -378,7 +400,9 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records
     the Redshift-specific properties_data_type to indicate the type of JSON-like fields.
     """
-    logger = await bind_temporal_worker_logger(team_id=inputs.team_id, destination="Redshift")
+    logger = await configure_temporal_worker_logger(
+        logger=structlog.get_logger(), team_id=inputs.team_id, destination="Redshift"
+    )
     await logger.ainfo(
         "Batch exporting range %s - %s to Redshift: %s.%s.%s",
         inputs.data_interval_start or "START",
@@ -389,35 +413,72 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records
     )
 
     async with (
-        Heartbeater(),
+        Heartbeater() as heartbeater,
         set_status_to_running_task(run_id=inputs.run_id, logger=logger),
         get_client(team_id=inputs.team_id) as client,
     ):
         if not await client.is_alive():
             raise ConnectionError("Cannot establish connection to ClickHouse")
 
+        should_resume, details = await should_resume_from_activity_heartbeat(activity, RedshiftHeartbeatDetails, logger)
+
+        if should_resume is True and details is not None:
+            data_interval_start: str | None = details.last_inserted_at.isoformat()
+        else:
+            data_interval_start = inputs.data_interval_start
+
         model: BatchExportModel | BatchExportSchema | None = None
         if inputs.batch_export_schema is None and "batch_export_model" in {
            field.name for field in dataclasses.fields(inputs)
         }:
             model = inputs.batch_export_model
-
+            if model is not None:
+                model_name = model.name
+                extra_query_parameters = model.schema["values"] if model.schema is not None else None
+                fields = model.schema["fields"] if model.schema is not None else None
+            else:
+                model_name = "events"
+                extra_query_parameters = None
+                fields = None
         else:
             model = inputs.batch_export_schema
+            model_name = "custom"
+            extra_query_parameters = model["values"] if model is not None else {}
+            fields = model["fields"] if model is not None else None
 
-        record_iterator = iter_model_records(
+        queue, produce_task = start_produce_batch_export_record_batches(
             client=client,
-            model=model,
+            model_name=model_name,
+            is_backfill=inputs.is_backfill,
             team_id=inputs.team_id,
-            interval_start=inputs.data_interval_start,
+            interval_start=data_interval_start,
             interval_end=inputs.data_interval_end,
             exclude_events=inputs.exclude_events,
             include_events=inputs.include_events,
+            fields=fields,
             destination_default_fields=redshift_default_fields(),
-            is_backfill=inputs.is_backfill,
+            extra_query_parameters=extra_query_parameters,
         )
-        first_record_batch, record_iterator = await apeek_first_and_rewind(record_iterator)
-        if first_record_batch is None:
+
+        get_schema_task = asyncio.create_task(queue.get_schema())
+        await asyncio.wait(
+            [get_schema_task, produce_task],
+            return_when=asyncio.FIRST_COMPLETED,
+        )
+
+        # Finishing producing happens sequentially after putting to queue and setting the schema.
+        # So, either we finished producing and setting the schema tasks, or we finished without
+        # putting anything in the queue.
+        if get_schema_task.done():
+            # In the first case, we'll land here.
+            # The schema is available, and the queue is not empty, so we can start the batch export.
+            record_batch_schema = get_schema_task.result()
+        else:
+            # In the second case, we'll land here: We finished producing without putting anything.
+            # Since we finished producing with an empty queue, there is nothing to batch export.
+            # We could have also failed, so we need to re-raise that exception to allow a retry if
+            # that's the case.
+            await raise_on_produce_task_failure(produce_task)
             return 0
 
         known_super_columns = ["properties", "set", "set_once", "person_properties"]
@@ -442,10 +503,8 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records
                 ("timestamp", "TIMESTAMP WITH TIME ZONE"),
             ]
         else:
-            column_names = [column for column in first_record_batch.schema.names if column != "_inserted_at"]
-            record_schema = first_record_batch.select(column_names).schema
             table_fields = get_redshift_fields_from_record_schema(
-                record_schema, known_super_columns=known_super_columns, use_super=properties_type == "SUPER"
+                record_batch_schema, known_super_columns=known_super_columns, use_super=properties_type == "SUPER"
             )
 
         requires_merge = (
@@ -477,7 +536,7 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records
         ):
             schema_columns = {field[0] for field in table_fields}
 
-            def map_to_record(row: dict) -> dict:
+            def map_to_record(row: dict) -> tuple[dict, dt.datetime]:
                 """Map row to a record to insert to Redshift."""
                 record = {k: v for k, v in row.items() if k in schema_columns}
 
@@ -486,10 +545,24 @@ def map_to_record(row: dict) -> dict:
                         # TODO: We should be able to save a json.loads here.
                         record[column] = remove_escaped_whitespace_recursive(json.loads(record[column]))
 
-                return record
+                return record, row["_inserted_at"]
+
+            async def record_generator() -> (
+                collections.abc.AsyncGenerator[tuple[dict[str, typing.Any], dt.datetime], None]
+            ):
+                while not queue.empty() or not produce_task.done():
+                    try:
+                        record_batch = queue.get_nowait()
+                    except asyncio.QueueEmpty:
+                        if produce_task.done():
+                            await logger.adebug(
+                                "Empty queue with no more events being produced, closing consumer loop"
+                            )
+                            return
+                        else:
+                            await asyncio.sleep(0.1)
+                            continue
 
-            async def record_generator() -> collections.abc.AsyncGenerator[dict[str, typing.Any], None]:
-                async for record_batch in record_iterator:
                     for record in record_batch.to_pylist():
                         yield map_to_record(record)
 
@@ -498,6 +571,7 @@ async def record_generator() -> collections.abc.AsyncGenerator[dict[str, typing.
                 redshift_client,
                 inputs.schema,
                 redshift_stage_table if requires_merge else redshift_table,
+                heartbeater=heartbeater,
                 use_super=properties_type == "SUPER",
                 known_super_columns=known_super_columns,
             )
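
The new record_generator() replaces the old "async for record_batch in record_iterator" loop with a consumer that drains a queue while a producer task fills it in the background. A minimal, self-contained sketch of that consumption pattern, assuming a plain asyncio.Queue and toy produce()/consume() helpers in place of PostHog's RecordBatchQueue and start_produce_batch_export_record_batches:

# Sketch of the queue-draining consumer pattern used by record_generator() above.
# The produce() function here is hypothetical and only stands in for the real producer.
import asyncio


async def produce(queue: asyncio.Queue) -> None:
    # Toy producer: puts a couple of "record batches" and finishes.
    for batch in ([{"event": "a"}], [{"event": "b"}]):
        await queue.put(batch)


async def consume(queue: asyncio.Queue, produce_task: asyncio.Task) -> list[dict]:
    consumed: list[dict] = []
    # Keep going while there is queued work or the producer may still add more.
    while not queue.empty() or not produce_task.done():
        try:
            batch = queue.get_nowait()
        except asyncio.QueueEmpty:
            if produce_task.done():
                # Queue is empty and nothing more is coming: close the consumer loop.
                break
            await asyncio.sleep(0.1)  # Yield control and poll again.
            continue

        consumed.extend(batch)
    return consumed


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    produce_task = asyncio.create_task(produce(queue))
    print(await consume(queue, produce_task))
    await produce_task  # Re-raise any producer failure so the activity can be retried.


asyncio.run(main())

Polling with get_nowait() plus a short sleep, rather than awaiting queue.get(), matters here: a blocking get() would hang forever when the producer finishes without ever putting anything on the queue, which is exactly the empty-interval case the diff handles by checking produce_task.done().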
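
The other change is resumability: insert_records_to_redshift now records the last flushed _inserted_at in the heartbeat details, and on retry the activity narrows the export range to start from that point instead of the original interval start. A minimal sketch of that resume decision, assuming last_inserted_at is a datetime recovered from the previous attempt's heartbeat (the role BatchExportHeartbeatDetails plays in the real code), not the Temporal plumbing itself:

# Sketch of the resume-from-heartbeat decision; HeartbeatDetails is a simplified stand-in.
import datetime as dt
from dataclasses import dataclass


@dataclass
class HeartbeatDetails:
    # Assumed shape: the last _inserted_at value flushed by a previous attempt.
    last_inserted_at: dt.datetime


def pick_interval_start(details: HeartbeatDetails | None, inputs_start: str | None) -> str | None:
    # Resume from the last flushed _inserted_at if a previous attempt made progress;
    # otherwise export the full interval from the activity inputs.
    if details is not None:
        return details.last_inserted_at.isoformat()
    return inputs_start


full_start = "2024-01-01T12:00:00+00:00"
resumed = HeartbeatDetails(last_inserted_at=dt.datetime(2024, 1, 1, 12, 40, tzinfo=dt.timezone.utc))
print(pick_interval_start(resumed, full_start))  # 2024-01-01T12:40:00+00:00, resumes mid-interval
print(pick_interval_start(None, full_start))     # 2024-01-01T12:00:00+00:00, full interval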