From c53f9316f34be67036f2264702b8893740b2115c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Far=C3=ADas=20Santana?= Date: Thu, 2 Jan 2025 15:30:01 +0100 Subject: [PATCH] refactor: Redshift batch export uses spmc consumer (#26897) --- posthog/settings/temporal.py | 4 + .../batch_exports/bigquery_batch_export.py | 32 +- .../batch_exports/redshift_batch_export.py | 313 ++++++------------ .../temporal/batch_exports/s3_batch_export.py | 25 +- posthog/temporal/batch_exports/spmc.py | 106 +++--- .../temporal/batch_exports/temporary_file.py | 154 +++++++++ posthog/temporal/common/clickhouse.py | 5 +- .../test_redshift_batch_export_workflow.py | 4 +- 8 files changed, 365 insertions(+), 278 deletions(-) diff --git a/posthog/settings/temporal.py b/posthog/settings/temporal.py index 33daed600cebf..e168d12c46d84 100644 --- a/posthog/settings/temporal.py +++ b/posthog/settings/temporal.py @@ -25,6 +25,10 @@ BATCH_EXPORT_BIGQUERY_RECORD_BATCH_QUEUE_MAX_SIZE_BYTES: int = get_from_env( "BATCH_EXPORT_BIGQUERY_RECORD_BATCH_QUEUE_MAX_SIZE_BYTES", 0, type_cast=int ) +BATCH_EXPORT_REDSHIFT_UPLOAD_CHUNK_SIZE_BYTES: int = 1024 * 1024 * 8 # 8MB +BATCH_EXPORT_REDSHIFT_RECORD_BATCH_QUEUE_MAX_SIZE_BYTES: int = get_from_env( + "BATCH_EXPORT_REDSHIFT_RECORD_BATCH_QUEUE_MAX_SIZE_BYTES", 1024 * 1024 * 300, type_cast=int +) BATCH_EXPORT_HTTP_UPLOAD_CHUNK_SIZE_BYTES: int = 1024 * 1024 * 50 # 50MB BATCH_EXPORT_HTTP_BATCH_SIZE: int = 5000 BATCH_EXPORT_BUFFER_QUEUE_MAX_SIZE_BYTES: int = 1024 * 1024 * 300 # 300MB diff --git a/posthog/temporal/batch_exports/bigquery_batch_export.py b/posthog/temporal/batch_exports/bigquery_batch_export.py index 30c600d210802..5aa3965b5a8bd 100644 --- a/posthog/temporal/batch_exports/bigquery_batch_export.py +++ b/posthog/temporal/batch_exports/bigquery_batch_export.py @@ -40,7 +40,7 @@ Consumer, Producer, RecordBatchQueue, - run_consumer_loop, + run_consumer, wait_for_schema_or_producer, ) from posthog.temporal.batch_exports.temporary_file import ( @@ -519,12 +519,19 @@ def __init__( heartbeater: Heartbeater, heartbeat_details: BigQueryHeartbeatDetails, data_interval_start: dt.datetime | str | None, + data_interval_end: dt.datetime | str, writer_format: WriterFormat, bigquery_client: BigQueryClient, bigquery_table: bigquery.Table, - table_schema: list[BatchExportField], + table_schema: list[bigquery.SchemaField], ): - super().__init__(heartbeater, heartbeat_details, data_interval_start, writer_format) + super().__init__( + heartbeater=heartbeater, + heartbeat_details=heartbeat_details, + data_interval_start=data_interval_start, + data_interval_end=data_interval_end, + writer_format=writer_format, + ) self.bigquery_client = bigquery_client self.bigquery_table = bigquery_table self.table_schema = table_schema @@ -629,11 +636,10 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records include_events=inputs.include_events, extra_query_parameters=extra_query_parameters, ) - records_completed = 0 record_batch_schema = await wait_for_schema_or_producer(queue, producer_task) if record_batch_schema is None: - return records_completed + return 0 record_batch_schema = pa.schema( # NOTE: For some reason, some batches set non-nullable fields as non-nullable, whereas other @@ -700,21 +706,23 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records create=can_perform_merge, delete=can_perform_merge, ) as bigquery_stage_table: - records_completed = await run_consumer_loop( - queue=queue, - consumer_cls=BigQueryConsumer, - producer_task=producer_task, + consumer = BigQueryConsumer( heartbeater=heartbeater, heartbeat_details=details, data_interval_end=data_interval_end, data_interval_start=data_interval_start, - schema=record_batch_schema, writer_format=WriterFormat.PARQUET if can_perform_merge else WriterFormat.JSONL, - max_bytes=settings.BATCH_EXPORT_BIGQUERY_UPLOAD_CHUNK_SIZE_BYTES, - json_columns=() if can_perform_merge else json_columns, bigquery_client=bq_client, bigquery_table=bigquery_stage_table if can_perform_merge else bigquery_table, table_schema=stage_schema if can_perform_merge else schema, + ) + records_completed = await run_consumer( + consumer=consumer, + queue=queue, + producer_task=producer_task, + schema=record_batch_schema, + max_bytes=settings.BATCH_EXPORT_BIGQUERY_UPLOAD_CHUNK_SIZE_BYTES, + json_columns=() if can_perform_merge else json_columns, writer_file_kwargs={"compression": "zstd"} if can_perform_merge else {}, multiple_files=True, ) diff --git a/posthog/temporal/batch_exports/redshift_batch_export.py b/posthog/temporal/batch_exports/redshift_batch_export.py index 3b02efddb5a0b..827bfa684c143 100644 --- a/posthog/temporal/batch_exports/redshift_batch_export.py +++ b/posthog/temporal/batch_exports/redshift_batch_export.py @@ -1,5 +1,3 @@ -import asyncio -import collections.abc import contextlib import dataclasses import datetime as dt @@ -9,6 +7,7 @@ import psycopg import pyarrow as pa import structlog +from django.conf import settings from psycopg import sql from temporalio import activity, workflow from temporalio.common import RetryPolicy @@ -28,25 +27,29 @@ default_fields, execute_batch_export_insert_activity, get_data_interval, - raise_on_produce_task_failure, start_batch_export_run, - start_produce_batch_export_record_batches, ) from posthog.temporal.batch_exports.heartbeat import ( BatchExportRangeHeartbeatDetails, DateRange, should_resume_from_activity_heartbeat, ) -from posthog.temporal.batch_exports.metrics import get_rows_exported_metric from posthog.temporal.batch_exports.postgres_batch_export import ( Fields, PostgresInsertInputs, PostgreSQLClient, PostgreSQLField, ) +from posthog.temporal.batch_exports.spmc import ( + Consumer, + Producer, + RecordBatchQueue, + run_consumer, + wait_for_schema_or_producer, +) +from posthog.temporal.batch_exports.temporary_file import BatchExportTemporaryFile, WriterFormat from posthog.temporal.batch_exports.utils import ( JsonType, - apeek_first_and_rewind, set_status_to_running_task, ) from posthog.temporal.common.clickhouse import get_client @@ -54,43 +57,6 @@ from posthog.temporal.common.logger import configure_temporal_worker_logger -def remove_escaped_whitespace_recursive(value): - """Remove all escaped whitespace characters from given value. - - PostgreSQL supports constant escaped strings by appending an E' to each string that - contains whitespace in them (amongst other characters). See: - https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-STRINGS-ESCAPE - - However, Redshift does not support this syntax. So, to avoid any escaping by - underlying PostgreSQL library, we remove the whitespace ourselves as defined in the - translation table WHITESPACE_TRANSLATE. - - This function is recursive just to be extremely careful and catch any whitespace that - may be sneaked in a dictionary key or sequence. - """ - match value: - case str(s): - return " ".join(s.replace("\b", " ").split()) - - case bytes(b): - return remove_escaped_whitespace_recursive(b.decode("utf-8")) - - case [*sequence]: - # mypy could be bugged as it's raising a Statement unreachable error. - # But we are definitely reaching this statement in tests; hence the ignore comment. - # Maybe: https://github.com/python/mypy/issues/16272. - return type(value)(remove_escaped_whitespace_recursive(sequence_value) for sequence_value in sequence) # type: ignore - - case set(elements): - return {remove_escaped_whitespace_recursive(element) for element in elements} - - case {**mapping}: - return {k: remove_escaped_whitespace_recursive(v) for k, v in mapping.items()} - - case value: - return value - - class RedshiftClient(PostgreSQLClient): @contextlib.asynccontextmanager async def connect(self) -> typing.AsyncIterator[typing.Self]: @@ -283,116 +249,64 @@ class RedshiftHeartbeatDetails(BatchExportRangeHeartbeatDetails): pass -async def insert_records_to_redshift( - records: collections.abc.AsyncGenerator[tuple[dict[str, typing.Any], dt.datetime], None], - redshift_client: RedshiftClient, - schema: str | None, - table: str, - heartbeater: Heartbeater, - heartbeat_details: RedshiftHeartbeatDetails, - data_interval_start: dt.datetime | None, - data_interval_end: dt.datetime, - batch_size: int = 100, - use_super: bool = False, - known_super_columns: list[str] | None = None, -) -> int: - """Execute an INSERT query with given Redshift connection. - - The recommended way to insert multiple values into Redshift is using a COPY statement (see: - https://docs.aws.amazon.com/redshift/latest/dg/r_COPY.html). However, Redshift cannot COPY from local - files like Postgres, but only from files in S3 or executing commands in SSH hosts. Setting that up would - add complexity and require more configuration from the user compared to the old Redshift export plugin. - For this reasons, we are going with basic INSERT statements for now, and we can migrate to COPY from S3 - later if the need arises. - - Arguments: - record: A dictionary representing the record to insert. Each key should correspond to a column - in the destination table. - redshift_connection: A connection to Redshift setup by psycopg2. - schema: The schema that contains the table where to insert the record. - table: The name of the table where to insert the record. - batch_size: Number of records to insert in batch. Setting this too high could - make us go OOM or exceed Redshift's SQL statement size limit (16MB). Setting this too low - can significantly affect performance due to Redshift's poor handling of INSERTs. - """ - first_value, records_iterator = await apeek_first_and_rewind(records) - if first_value is None: - return 0 - - first_record_batch, _inserted_at = first_value - columns = first_record_batch.keys() - - if schema: - table_identifier = sql.Identifier(schema, table) - else: - table_identifier = sql.Identifier(table) - - pre_query = sql.SQL("INSERT INTO {table} ({fields}) VALUES").format( - table=table_identifier, - fields=sql.SQL(", ").join(map(sql.Identifier, columns)), - ) - placeholders: list[sql.Composable] = [] - for column in columns: - if use_super is True and known_super_columns is not None and column in known_super_columns: - placeholders.append(sql.SQL("JSON_PARSE({placeholder})").format(placeholder=sql.Placeholder(column))) - else: - placeholders.append(sql.Placeholder(column)) - - template = sql.SQL("({})").format(sql.SQL(", ").join(placeholders)) - rows_exported = get_rows_exported_metric() - - total_rows_exported = 0 - - async with redshift_client.connection.transaction(): - async with redshift_client.async_client_cursor() as cursor: - batch = [] - pre_query_str = pre_query.as_string(cursor).encode("utf-8") - - async def flush_to_redshift(batch): - nonlocal total_rows_exported - - values = b",".join(batch).replace(b" E'", b" '") - await cursor.execute(pre_query_str + values) - rows_exported.add(len(batch)) - total_rows_exported += len(batch) - # It would be nice to record BYTES_EXPORTED for Redshift, but it's not worth estimating - # the byte size of each batch the way things are currently written. We can revisit this - # in the future if we decide it's useful enough. - - batch_start_inserted_at = None - async for record, _inserted_at in records_iterator: - if batch_start_inserted_at is None: - batch_start_inserted_at = _inserted_at - - for column in columns: - if known_super_columns is not None and column in known_super_columns: - record[column] = json.dumps(record[column], ensure_ascii=False) - - batch.append(cursor.mogrify(template, record).encode("utf-8")) - if len(batch) < batch_size: - continue - - await flush_to_redshift(batch) - - last_date_range = (batch_start_inserted_at, _inserted_at) - heartbeat_details.track_done_range(last_date_range, data_interval_start) - heartbeater.set_from_heartbeat_details(heartbeat_details) - - batch_start_inserted_at = None - batch = [] - - if len(batch) > 0 and batch_start_inserted_at: - await flush_to_redshift(batch) +class RedshiftConsumer(Consumer): + def __init__( + self, + heartbeater: Heartbeater, + heartbeat_details: RedshiftHeartbeatDetails, + data_interval_start: dt.datetime | str | None, + data_interval_end: dt.datetime | str, + redshift_client: RedshiftClient, + redshift_table: str, + ): + """Implementation of a record batch consumer for Redshift batch export. + + This consumer will execute an INSERT query on every flush using provided + Redshift client. The recommended way to insert multiple values into Redshift + is using a COPY statement (see: + https://docs.aws.amazon.com/redshift/latest/dg/r_COPY.html). However, + Redshift cannot COPY from local files like Postgres, but only from files in + S3 or executing commands in SSH hosts. Setting that up would add complexity + and require more configuration from the user compared to the old Redshift + export plugin. For these reasons, we are going with basic INSERT statements + for now, but should eventually migrate to COPY from S3 for performance. + """ + super().__init__( + heartbeater, + heartbeat_details, + data_interval_start, + data_interval_end, + writer_format=WriterFormat.REDSHIFT_INSERT, + ) + self.redshift_client = redshift_client + self.redshift_table = redshift_table - last_date_range = (batch_start_inserted_at, _inserted_at) + async def flush( + self, + batch_export_file: BatchExportTemporaryFile, + records_since_last_flush: int, + bytes_since_last_flush: int, + flush_counter: int, + last_date_range: DateRange, + is_last: bool, + error: Exception | None, + ): + await self.logger.adebug( + "Loading %s records in query of size %s bytes to Redshift table '%s'", + records_since_last_flush, + bytes_since_last_flush, + self.redshift_table, + ) - heartbeat_details.track_done_range(last_date_range, data_interval_start) - heartbeater.set_from_heartbeat_details(heartbeat_details) + async with self.redshift_client.async_client_cursor() as cursor: + async with self.redshift_client.connection.transaction(): + await cursor.execute(batch_export_file.read()) - heartbeat_details.complete_done_ranges(data_interval_end) - heartbeater.set_from_heartbeat_details(heartbeat_details) + await self.logger.adebug("Loaded %s to Redshift table '%s'", records_since_last_flush, self.redshift_table) + self.rows_exported_counter.add(records_since_last_flush) + self.bytes_exported_counter.add(bytes_since_last_flush) - return total_rows_exported + self.heartbeat_details.track_done_range(last_date_range, self.data_interval_start) @dataclasses.dataclass @@ -438,7 +352,7 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records async with ( Heartbeater() as heartbeater, set_status_to_running_task(run_id=inputs.run_id, logger=logger), - get_client(team_id=inputs.team_id) as client, + get_client(team_id=inputs.team_id, max_block_size=10) as client, ): if not await client.is_alive(): raise ConnectionError("Cannot establish connection to ClickHouse") @@ -474,41 +388,28 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records data_interval_end = dt.datetime.fromisoformat(inputs.data_interval_end) full_range = (data_interval_start, data_interval_end) - queue, produce_task = start_produce_batch_export_record_batches( - client=client, + queue = RecordBatchQueue(max_size_bytes=settings.BATCH_EXPORT_REDSHIFT_RECORD_BATCH_QUEUE_MAX_SIZE_BYTES) + producer = Producer(clickhouse_client=client) + producer_task = producer.start( + queue=queue, model_name=model_name, is_backfill=inputs.is_backfill, team_id=inputs.team_id, full_range=full_range, done_ranges=done_ranges, - exclude_events=inputs.exclude_events, - include_events=inputs.include_events, fields=fields, destination_default_fields=redshift_default_fields(), + exclude_events=inputs.exclude_events, + include_events=inputs.include_events, extra_query_parameters=extra_query_parameters, ) - - get_schema_task = asyncio.create_task(queue.get_schema()) - await asyncio.wait( - [get_schema_task, produce_task], - return_when=asyncio.FIRST_COMPLETED, - ) - - # Finishing producing happens sequentially after putting to queue and setting the schema. - # So, either we finished producing and setting the schema tasks, or we finished without - # putting anything in the queue. - if get_schema_task.done(): - # In the first case, we'll land here. - # The schema is available, and the queue is not empty, so we can start the batch export. - record_batch_schema = get_schema_task.result() - else: - # In the second case, we'll land here: We finished producing without putting anything. - # Since we finished producing with an empty queue, there is nothing to batch export. - # We could have also failed, so we need to re-raise that exception to allow a retry if - # that's the case. - await raise_on_produce_task_failure(produce_task) + record_batch_schema = await wait_for_schema_or_producer(queue, producer_task) + if record_batch_schema is None: return 0 + record_batch_schema = pa.schema( + [field.with_nullable(True) for field in record_batch_schema if field.name != "_inserted_at"] + ) known_super_columns = ["properties", "set", "set_once", "person_properties"] if inputs.properties_data_type != "varchar": properties_type = "SUPER" @@ -564,52 +465,30 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records ): schema_columns = {field[0] for field in table_fields} - def map_to_record(row: dict) -> tuple[dict, dt.datetime]: - """Map row to a record to insert to Redshift.""" - record = {k: v for k, v in row.items() if k in schema_columns} - - for column in known_super_columns: - if record.get(column, None) is not None: - # TODO: We should be able to save a json.loads here. - record[column] = remove_escaped_whitespace_recursive(json.loads(record[column])) - - if isinstance(row["_inserted_at"], int): - inserted_at = dt.datetime.fromtimestamp(row["_inserted_at"]) - else: - inserted_at = row["_inserted_at"] - - return record, inserted_at - - async def record_generator() -> ( - collections.abc.AsyncGenerator[tuple[dict[str, typing.Any], dt.datetime], None] - ): - while not queue.empty() or not produce_task.done(): - try: - record_batch = queue.get_nowait() - except asyncio.QueueEmpty: - if produce_task.done(): - await logger.adebug( - "Empty queue with no more events being produced, closing consumer loop" - ) - return - else: - await asyncio.sleep(0.1) - continue - - for record in record_batch.to_pylist(): - yield map_to_record(record) - - records_completed = await insert_records_to_redshift( - record_generator(), - redshift_client, - inputs.schema, - redshift_stage_table if requires_merge else redshift_table, + consumer = RedshiftConsumer( heartbeater=heartbeater, - use_super=properties_type == "SUPER", - known_super_columns=known_super_columns, heartbeat_details=details, - data_interval_start=data_interval_start, data_interval_end=data_interval_end, + data_interval_start=data_interval_start, + redshift_client=redshift_client, + redshift_table=redshift_stage_table if requires_merge else redshift_table, + ) + records_completed = await run_consumer( + consumer=consumer, + queue=queue, + producer_task=producer_task, + schema=record_batch_schema, + max_bytes=settings.BATCH_EXPORT_REDSHIFT_UPLOAD_CHUNK_SIZE_BYTES, + json_columns=known_super_columns, + writer_file_kwargs={ + "redshift_table": redshift_stage_table if requires_merge else redshift_table, + "redshift_schema": inputs.schema, + "table_columns": schema_columns, + "known_json_columns": set(known_super_columns), + "use_super": properties_type == "SUPER", + "redshift_client": redshift_client, + }, + multiple_files=True, ) if requires_merge: diff --git a/posthog/temporal/batch_exports/s3_batch_export.py b/posthog/temporal/batch_exports/s3_batch_export.py index 927d6436d634f..55dacb59e60e4 100644 --- a/posthog/temporal/batch_exports/s3_batch_export.py +++ b/posthog/temporal/batch_exports/s3_batch_export.py @@ -44,7 +44,7 @@ Consumer, Producer, RecordBatchQueue, - run_consumer_loop, + run_consumer, wait_for_schema_or_producer, ) from posthog.temporal.batch_exports.temporary_file import ( @@ -469,10 +469,17 @@ def __init__( heartbeater: Heartbeater, heartbeat_details: S3HeartbeatDetails, data_interval_start: dt.datetime | str | None, + data_interval_end: dt.datetime | str, writer_format: WriterFormat, s3_upload: S3MultiPartUpload, ): - super().__init__(heartbeater, heartbeat_details, data_interval_start, writer_format) + super().__init__( + heartbeater=heartbeater, + heartbeat_details=heartbeat_details, + data_interval_start=data_interval_start, + data_interval_end=data_interval_end, + writer_format=writer_format, + ) self.heartbeat_details: S3HeartbeatDetails = heartbeat_details self.s3_upload = s3_upload @@ -703,18 +710,20 @@ async def insert_into_s3_activity(inputs: S3InsertInputs) -> RecordsCompleted: ) async with s3_upload as s3_upload: - records_completed = await run_consumer_loop( - queue=queue, - consumer_cls=S3Consumer, - producer_task=producer_task, + consumer = S3Consumer( heartbeater=heartbeater, heartbeat_details=details, data_interval_end=data_interval_end, data_interval_start=data_interval_start, - schema=record_batch_schema, writer_format=WriterFormat.from_str(inputs.file_format, "S3"), - max_bytes=settings.BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES, s3_upload=s3_upload, + ) + records_completed = await run_consumer( + consumer=consumer, + queue=queue, + producer_task=producer_task, + schema=record_batch_schema, + max_bytes=settings.BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES, include_inserted_at=True, writer_file_kwargs={"compression": inputs.compression}, ) diff --git a/posthog/temporal/batch_exports/spmc.py b/posthog/temporal/batch_exports/spmc.py index 53171543db480..f2388a4ed5fa1 100644 --- a/posthog/temporal/batch_exports/spmc.py +++ b/posthog/temporal/batch_exports/spmc.py @@ -178,12 +178,14 @@ def __init__( heartbeater: Heartbeater, heartbeat_details: BatchExportRangeHeartbeatDetails, data_interval_start: dt.datetime | str | None, + data_interval_end: dt.datetime | str, writer_format: WriterFormat, ): self.flush_start_event = asyncio.Event() self.heartbeater = heartbeater self.heartbeat_details = heartbeat_details self.data_interval_start = data_interval_start + self.data_interval_end = data_interval_end self.writer_format = writer_format self.logger = logger @@ -197,6 +199,35 @@ def bytes_exported_counter(self) -> temporalio.common.MetricCounter: """Access the bytes exported metric counter.""" return get_bytes_exported_metric() + def create_consumer_task( + self, + queue: RecordBatchQueue, + producer_task: asyncio.Task, + max_bytes: int, + schema: pa.Schema, + json_columns: collections.abc.Sequence[str], + multiple_files: bool = False, + include_inserted_at: bool = False, + task_name: str = "record_batch_consumer", + **kwargs, + ) -> asyncio.Task: + """Create a record batch consumer task.""" + consumer_task = asyncio.create_task( + self.start( + queue=queue, + producer_task=producer_task, + max_bytes=max_bytes, + schema=schema, + json_columns=json_columns, + multiple_files=multiple_files, + include_inserted_at=include_inserted_at, + **kwargs, + ), + name=task_name, + ) + + return consumer_task + @abc.abstractmethod async def flush( self, @@ -249,25 +280,29 @@ async def start( Returns: Total number of records in all consumed record batches. """ - await logger.adebug("Starting record batch consumer") - schema = cast_record_batch_schema_json_columns(schema, json_columns=json_columns) writer = get_batch_export_writer(self.writer_format, self.flush, schema=schema, max_bytes=max_bytes, **kwargs) record_batches_count = 0 + record_batches_count_total = 0 records_count = 0 - await self.logger.adebug("Starting record batch writing loop") + await self.logger.adebug("Consuming record batches from producer %s", producer_task.get_name()) writer._batch_export_file = await asyncio.to_thread(writer.create_temporary_file) async for record_batch in self.generate_record_batches_from_queue(queue, producer_task): record_batches_count += 1 + record_batches_count_total += 1 record_batch = cast_record_batch_json_columns(record_batch, json_columns=json_columns) await writer.write_record_batch(record_batch, flush=False, include_inserted_at=include_inserted_at) if writer.should_flush(): + await self.logger.adebug( + "Flushing %s records from %s record batches", writer.records_since_last_flush, record_batches_count + ) + records_count += writer.records_since_last_flush if multiple_files: @@ -281,9 +316,15 @@ async def start( record_batches_count = 0 records_count += writer.records_since_last_flush + + await self.logger.adebug( + "Finished consuming %s records from %s record batches, will flush any pending data", + records_count, + record_batches_count_total, + ) + await writer.close_temporary_file() - await self.logger.adebug("Consumed %s records", records_count) self.heartbeater.set_from_heartbeat_details(self.heartbeat_details) return records_count @@ -308,6 +349,11 @@ async def generate_record_batches_from_queue( yield record_batch + def complete_heartbeat(self): + """Complete this consumer's heartbeats.""" + self.heartbeat_details.complete_done_ranges(self.data_interval_end) + self.heartbeater.set_from_heartbeat_details(self.heartbeat_details) + class RecordBatchConsumerRetryableExceptionGroup(ExceptionGroup): """ExceptionGroup raised when at least one task fails with a retryable exception.""" @@ -323,24 +369,19 @@ def derive(self, excs): return RecordBatchConsumerNonRetryableExceptionGroup(self.message, excs) -async def run_consumer_loop( +async def run_consumer( queue: RecordBatchQueue, - consumer_cls: type[Consumer], + consumer: Consumer, producer_task: asyncio.Task, - heartbeater: Heartbeater, - heartbeat_details: BatchExportRangeHeartbeatDetails, - data_interval_end: dt.datetime | str, - data_interval_start: dt.datetime | str | None, - schema: pa.Schema, - writer_format: WriterFormat, max_bytes: int, + schema: pa.Schema, json_columns: collections.abc.Sequence[str] = ("properties", "person_properties", "set", "set_once"), - writer_file_kwargs: collections.abc.Mapping[str, typing.Any] | None = None, multiple_files: bool = False, + writer_file_kwargs: collections.abc.Mapping[str, typing.Any] | None = None, include_inserted_at: bool = False, **kwargs, ) -> int: - """Run record batch consumers in a loop. + """Run one record batch consumer. When a consumer starts flushing, a new consumer will be started, and so on in a loop. Once there is nothing left to consumer from the `RecordBatchQueue`, no @@ -362,7 +403,6 @@ async def run_consumer_loop( """ consumer_tasks_pending: set[asyncio.Task] = set() consumer_tasks_done = set() - consumer_number = 0 records_completed = 0 def consumer_done_callback(task: asyncio.Task): @@ -378,27 +418,22 @@ def consumer_done_callback(task: asyncio.Task): consumer_tasks_pending.remove(task) consumer_tasks_done.add(task) - await logger.adebug("Starting record batch consumer loop") - - consumer = consumer_cls(heartbeater, heartbeat_details, data_interval_start, writer_format, **kwargs) - consumer_task = asyncio.create_task( - consumer.start( - queue=queue, - producer_task=producer_task, - max_bytes=max_bytes, - schema=schema, - json_columns=json_columns, - multiple_files=multiple_files, - include_inserted_at=include_inserted_at, - **writer_file_kwargs or {}, - ), - name=f"record_batch_consumer_{consumer_number}", + await logger.adebug("Starting record batch consumer") + + consumer_task = consumer.create_consumer_task( + queue=queue, + producer_task=producer_task, + max_bytes=max_bytes, + schema=schema, + json_columns=json_columns, + multiple_files=multiple_files, + include_inserted_at=include_inserted_at, + **writer_file_kwargs or {}, ) consumer_tasks_pending.add(consumer_task) consumer_task.add_done_callback(consumer_done_callback) - consumer_number += 1 - await asyncio.wait([consumer_task]) + await asyncio.wait(consumer_tasks_pending) if consumer_task.done(): consumer_task_exception = consumer_task.exception() @@ -406,13 +441,10 @@ def consumer_done_callback(task: asyncio.Task): if consumer_task_exception is not None: raise consumer_task_exception - await logger.adebug("Finished consuming record batches") - await raise_on_task_failure(producer_task) - await logger.adebug("Successfully consumed all record batches") + await logger.adebug("Successfully finished record batch consumer") - heartbeat_details.complete_done_ranges(data_interval_end) - heartbeater.set_from_heartbeat_details(heartbeat_details) + consumer.complete_heartbeat() return records_completed diff --git a/posthog/temporal/batch_exports/temporary_file.py b/posthog/temporal/batch_exports/temporary_file.py index 9b23b6f5c9692..de04733eb80c1 100644 --- a/posthog/temporal/batch_exports/temporary_file.py +++ b/posthog/temporal/batch_exports/temporary_file.py @@ -14,9 +14,11 @@ import brotli import orjson +import psycopg import pyarrow as pa import pyarrow.parquet as pq import structlog +from psycopg import sql from posthog.temporal.batch_exports.heartbeat import DateRange @@ -480,6 +482,7 @@ class WriterFormat(enum.StrEnum): JSONL = enum.auto() PARQUET = enum.auto() CSV = enum.auto() + REDSHIFT_INSERT = enum.auto() @staticmethod def from_str(format_str: str, destination: str): @@ -490,6 +493,8 @@ def from_str(format_str: str, destination: str): return WriterFormat.PARQUET case "CSV": return WriterFormat.CSV + case "REDSHIFT_INSERT": + return WriterFormat.REDSHIFT_INSERT case _: raise UnsupportedFileFormatError(format_str, destination) @@ -517,6 +522,13 @@ def get_batch_export_writer(writer_format: WriterFormat, flush_callable: FlushCa **kwargs, ) + case WriterFormat.REDSHIFT_INSERT: + return RedshiftInsertBatchExportWriter( + max_bytes=max_bytes, + flush_callable=flush_callable, + **kwargs, + ) + class JSONLBatchExportWriter(BatchExportWriter): """A `BatchExportWriter` for JSONLines format. @@ -707,3 +719,145 @@ def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: """Write records to a temporary file as Parquet.""" self.parquet_writer.write_batch(record_batch.select(self.parquet_writer.schema.names)) + + +def remove_escaped_whitespace_recursive(value): + """Remove all escaped whitespace characters from given value. + + PostgreSQL supports constant escaped strings by appending an E' to each string that + contains whitespace in them (amongst other characters). See: + https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-STRINGS-ESCAPE + + However, Redshift does not support this syntax. So, to avoid any escaping by + underlying PostgreSQL library, we remove the whitespace ourselves as defined in the + translation table WHITESPACE_TRANSLATE. + + This function is recursive just to be extremely careful and catch any whitespace that + may be sneaked in a dictionary key or sequence. + """ + match value: + case str(s): + return " ".join(s.replace("\b", " ").split()) + + case bytes(b): + return remove_escaped_whitespace_recursive(b.decode("utf-8")) + + case [*sequence]: + # mypy could be bugged as it's raising a Statement unreachable error. + # But we are definitely reaching this statement in tests; hence the ignore comment. + # Maybe: https://github.com/python/mypy/issues/16272. + return type(value)(remove_escaped_whitespace_recursive(sequence_value) for sequence_value in sequence) # type: ignore + + case set(elements): + return {remove_escaped_whitespace_recursive(element) for element in elements} + + case {**mapping}: + return {k: remove_escaped_whitespace_recursive(v) for k, v in mapping.items()} + + case value: + return value + + +class RedshiftInsertBatchExportWriter(BatchExportWriter): + """A `BatchExportWriter` for Redshift INSERT queries. + + Arguments: + max_bytes: Redshift's SQL statement size limit is 16MB, so anything more than + that will result in an error. However, setthing `max_bytes` too low can + significantly affect performance due to Redshift's poor handling of INSERTs. + """ + + def __init__( + self, + max_bytes: int, + flush_callable: FlushCallable, + schema: pa.Schema, + redshift_table: str, + redshift_schema: str | None, + table_columns: collections.abc.Sequence[str], + known_json_columns: collections.abc.Sequence[str], + use_super: bool, + redshift_client, + ): + super().__init__( + max_bytes=max_bytes, + flush_callable=flush_callable, + file_kwargs={"compression": None}, + ) + self.schema = schema + self.redshift_table = redshift_table + self.redshift_schema = redshift_schema + self.table_columns = table_columns + self.known_json_columns = known_json_columns + self.use_super = use_super + self.redshift_client = redshift_client + self._cursor: psycopg.AsyncClientCursor | None = None + self.first = True + + placeholders: list[sql.Composable] = [] + for column in table_columns: + if column in known_json_columns and use_super is True: + placeholders.append(sql.SQL("JSON_PARSE({placeholder})").format(placeholder=sql.Placeholder(column))) + else: + placeholders.append(sql.Placeholder(column)) + + self.template = sql.SQL("({})").format(sql.SQL(", ").join(placeholders)) + + def create_temporary_file(self) -> BatchExportTemporaryFile: + """On creating a temporary file, write first the start of a query.""" + file = super().create_temporary_file() + + if self.redshift_schema: + table_identifier = sql.Identifier(self.redshift_schema, self.redshift_table) + else: + table_identifier = sql.Identifier(self.redshift_table) + + pre_query_encoded = asyncio.run(self.get_encoded_pre_query(table_identifier)) + file.write(pre_query_encoded) + + return file + + async def get_encoded_pre_query(self, table_identifier: sql.Identifier) -> bytes: + """Encode and format the start of an INSERT INTO query.""" + pre_query = sql.SQL("INSERT INTO {table} ({fields}) VALUES").format( + table=table_identifier, + fields=sql.SQL(", ").join(map(sql.Identifier, self.table_columns)), + ) + + async with self.redshift_client.async_client_cursor() as cursor: + return pre_query.as_string(cursor).encode("utf-8") + + def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: + """Write records to a temporary file as values in an INSERT query.""" + for record_dict in record_batch.to_pylist(): + if not record_dict: + continue + + record = {} + for key, value in record_dict.items(): + if key not in self.table_columns: + continue + + record[key] = value + + if value is not None and key in self.known_json_columns: + record[key] = json.dumps(remove_escaped_whitespace_recursive(record[key]), ensure_ascii=False) + + encoded = asyncio.run(self.mogrify_record(record)) + + if self.first: + self.first = False + else: + self.batch_export_file.write(",") + + self.batch_export_file.write(encoded) + + async def mogrify_record(self, record: dict[str, typing.Any]) -> bytes: + """Produce encoded bytes from a record.""" + async with self.redshift_client.async_client_cursor() as cursor: + return cursor.mogrify(self.template, record).encode("utf-8").replace(b" E'", b" '") + + async def close_temporary_file(self): + """Ensure we mark next query as first after closing a file.""" + await super().close_temporary_file() + self.first = True diff --git a/posthog/temporal/common/clickhouse.py b/posthog/temporal/common/clickhouse.py index a326e29d9fccd..ea33cda5d356c 100644 --- a/posthog/temporal/common/clickhouse.py +++ b/posthog/temporal/common/clickhouse.py @@ -489,11 +489,12 @@ async def get_client( timeout = aiohttp.ClientTimeout(total=None, connect=None, sock_connect=30, sock_read=None) if team_id is None: - max_block_size = settings.CLICKHOUSE_MAX_BLOCK_SIZE_DEFAULT + default_max_block_size = settings.CLICKHOUSE_MAX_BLOCK_SIZE_DEFAULT else: - max_block_size = settings.CLICKHOUSE_MAX_BLOCK_SIZE_OVERRIDES.get( + default_max_block_size = settings.CLICKHOUSE_MAX_BLOCK_SIZE_OVERRIDES.get( team_id, settings.CLICKHOUSE_MAX_BLOCK_SIZE_DEFAULT ) + max_block_size = kwargs.pop("max_block_size", None) or default_max_block_size if clickhouse_url is None: url = settings.CLICKHOUSE_OFFLINE_HTTP_URL diff --git a/posthog/temporal/tests/batch_exports/test_redshift_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_redshift_batch_export_workflow.py index aaf4469435508..20c38545490b9 100644 --- a/posthog/temporal/tests/batch_exports/test_redshift_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_redshift_batch_export_workflow.py @@ -2,8 +2,8 @@ import json import operator import os -import warnings import uuid +import warnings import psycopg import pytest @@ -29,8 +29,8 @@ RedshiftInsertInputs, insert_into_redshift_activity, redshift_default_fields, - remove_escaped_whitespace_recursive, ) +from posthog.temporal.batch_exports.temporary_file import remove_escaped_whitespace_recursive from posthog.temporal.common.clickhouse import ClickHouseClient from posthog.temporal.tests.batch_exports.utils import mocked_start_batch_export_run from posthog.temporal.tests.utils.events import generate_test_events_in_clickhouse