diff --git a/metadata-ingestion/src/datahub/ingestion/source/kafka/kafka.py b/metadata-ingestion/src/datahub/ingestion/source/kafka/kafka.py
index 36ff3dd2f99a2c..b205bc3cc6ba52 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/kafka/kafka.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/kafka/kafka.py
@@ -4,9 +4,10 @@
 import json
 import logging
 import random
+import statistics
 from dataclasses import dataclass, field
 from datetime import datetime, timedelta
-from typing import Any, Dict, Iterable, List, Optional, Set, Type, cast
+from typing import Any, Dict, Iterable, List, Optional, Type, cast
 
 import avro.io
 import avro.schema
@@ -55,6 +56,7 @@
 )
 from datahub.ingestion.api.workunit import MetadataWorkUnit
 from datahub.ingestion.source.common.subtypes import DatasetSubTypes
+from datahub.ingestion.source.ge_profiling_config import GEProfilingBaseConfig
 from datahub.ingestion.source.kafka.kafka_schema_registry_base import (
     KafkaSchemaRegistryBase,
 )
@@ -67,19 +69,25 @@
     StatefulIngestionConfigBase,
     StatefulIngestionSourceBase,
 )
-from datahub.metadata.com.linkedin.pegasus2avro.common import Status
-from datahub.metadata.com.linkedin.pegasus2avro.metadata.snapshot import DatasetSnapshot
 from datahub.metadata.com.linkedin.pegasus2avro.mxe import MetadataChangeEvent
 from datahub.metadata.schema_classes import (
     BrowsePathsClass,
+    CalendarIntervalClass,
     DataPlatformInstanceClass,
     DatasetFieldProfileClass,
     DatasetProfileClass,
     DatasetPropertiesClass,
+    DatasetSnapshotClass,
     KafkaSchemaClass,
     OwnershipSourceTypeClass,
+    PartitionSpecClass,
+    PartitionTypeClass,
+    QuantileClass,
     SchemaMetadataClass,
+    StatusClass,
     SubTypesClass,
+    TimeWindowClass,
+    TimeWindowSizeClass,
 )
 from datahub.utilities.mapping import Constants, OperationProcessor
 from datahub.utilities.registries.domain_registry import DomainRegistry
@@ -97,6 +105,17 @@ class KafkaTopicConfigKeys(StrEnum):
     UNCLEAN_LEADER_ELECTION_CONFIG = "unclean.leader.election.enable"
 
 
+class ProfilerConfig(GEProfilingBaseConfig):
+    sample_size: pydantic.PositiveInt = pydantic.Field(
+        default=100,
+        description="Number of messages to sample for profiling",
+    )
+    max_sample_time_seconds: pydantic.PositiveInt = pydantic.Field(
+        default=60,
+        description="Maximum time to spend sampling messages in seconds",
+    )
+
+
 class KafkaSourceConfig(
     StatefulIngestionConfigBase,
     DatasetSourceConfigMixin,
@@ -153,14 +172,9 @@ class KafkaSourceConfig(
         default=False,
         description="Enables ingesting schemas from schema registry as separate entities, in addition to the topics",
     )
-    enable_sample_data: bool = pydantic.Field(
-        default=False, description="Whether to collect sample messages from topics"
-    )
-    sample_size: int = pydantic.Field(
-        default=100, description="Number of sample messages to collect per topic"
-    )
-    sample_timeout_seconds: int = pydantic.Field(
-        default=5, description="Timeout in seconds when collecting sample messages"
+    profiling: ProfilerConfig = pydantic.Field(
+        default=ProfilerConfig(),
+        description="Settings for message sampling and profiling",
     )
 
 
@@ -442,16 +456,184 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
                     "subject", f"Exception while extracting topic {subject}: {e}"
                 )
 
-    def _process_message_part(
-        self, data: Any, prefix: str, topic: str, is_key: bool = False
-    ) -> Optional[Any]:
-        """Process either key or value part of a message using schema registry for decoding."""
-        if data is None:
+    def _process_sample_data(
+        self,
+        samples: List[Dict[str, Any]],
+        schema_metadata: Optional[SchemaMetadataClass] = None,
+    ) -> Dict[str, Any]:
+        """Process sample data to extract field information."""
+        field_sample_map: Dict[str, List[str]] = {}
+
+        # Initialize from schema if available
+        if schema_metadata and schema_metadata.fields:
+            for schema_field in schema_metadata.fields:
+                field_sample_map[schema_field.fieldPath] = []
+
+        # Process each sample
+        for sample in samples:
+            for field_name, value in sample.items():
+                if field_name not in field_sample_map:
+                    field_sample_map[field_name] = []
+                if value is not None:
+                    field_sample_map[field_name].append(str(value))
+
+        return field_sample_map
+
+    def _create_profile_class(
+        self,
+        field_sample_map: Dict[str, List[str]],
+        row_count: int,
+    ) -> DatasetProfileClass:
+        """Create profile class from processed samples with complete statistical metrics."""
+        field_profiles = []
+
+        for field_path, samples in field_sample_map.items():
+            if not samples:
+                continue
+
+            # Convert string samples to numeric where possible
+            numeric_samples = []
+            for sample in samples:
+                try:
+                    numeric_samples.append(float(sample))
+                except (ValueError, TypeError):
+                    pass
+
+            sorted_samples = sorted(numeric_samples) if numeric_samples else []
+            q1_idx = len(sorted_samples) // 4 if len(sorted_samples) >= 4 else 0
+            q3_idx = (3 * len(sorted_samples)) // 4 if len(sorted_samples) >= 4 else 0
+
+            # Create the profile with mandatory and optional fields
+            profile = DatasetFieldProfileClass(
+                fieldPath=field_path,
+                # Required fields with fallbacks
+                sampleValues=random.sample(samples, min(3, len(samples)))
+                if samples
+                else None,
+                nullCount=row_count - len(samples),
+                uniqueCount=len(set(samples)),
+                # Numeric statistics
+                min=str(min(numeric_samples)) if numeric_samples else None,
+                max=str(max(numeric_samples)) if numeric_samples else None,
+                mean=str(statistics.mean(numeric_samples)) if numeric_samples else None,
+                median=str(statistics.median(numeric_samples))
+                if numeric_samples
+                else None,
+                stdev=str(statistics.stdev(numeric_samples))
+                if len(numeric_samples) > 1
+                else None,
+                # Quartile statistics
+                quantiles=[
+                    QuantileClass(
+                        quantile=str(0.25), value=str(sorted_samples[q1_idx])
+                    ),
+                    QuantileClass(
+                        quantile=str(0.75), value=str(sorted_samples[q3_idx])
+                    ),
+                ]
+                if len(sorted_samples) >= 4
+                else None,
+                # Set unused fields to None
+                histogram=None,
+                distinctValueFrequencies=None,
+            )
+            field_profiles.append(profile)
+
+        # Create the dataset profile with timing information
+        timestamp_millis = int(datetime.now().timestamp() * 1000)
+        return DatasetProfileClass(
+            timestampMillis=timestamp_millis,
+            rowCount=row_count,
+            columnCount=len(field_profiles),
+            fieldProfiles=field_profiles,
+            # Add time window information
+            eventGranularity=TimeWindowSizeClass(
+                unit=CalendarIntervalClass.SECOND,
+                multiple=self.source_config.profiling.max_sample_time_seconds,
+            ),
+            # Add partition specification
+            partitionSpec=PartitionSpecClass(
+                partition="SAMPLE",
+                type=PartitionTypeClass.QUERY,
+                timePartition=TimeWindowClass(
+                    startTimeMillis=timestamp_millis,
+                    length=TimeWindowSizeClass(
+                        unit=CalendarIntervalClass.SECOND,
+                        multiple=self.source_config.profiling.max_sample_time_seconds,
+                    ),
+                ),
+            ),
+        )
+
+    def _get_sample_data(
+        self,
+        topic: str,
+        timeout_seconds: int,
+        max_samples: int,
+    ) -> List[Dict[str, Any]]:
+        """Get sample messages from Kafka topic."""
+        samples: List[Dict[str, Any]] = []
+        self.consumer.subscribe([topic])
+
+        end_time = datetime.now() + timedelta(seconds=timeout_seconds)
+
+        try:
+            while datetime.now() < end_time and len(samples) < max_samples:
+                msg = self.consumer.poll(timeout=1.0)
+                if msg is None or msg.error():
+                    continue
+
+                sample_data = self._process_message(msg, topic)
+                if sample_data:
+                    samples.append(sample_data)
+
+            return samples
+
+        finally:
+            self.consumer.unsubscribe()
+
+    def _process_message(
+        self, msg: confluent_kafka.Message, topic: str
+    ) -> Optional[Dict[str, Any]]:
+        """Process a single Kafka message into a sample."""
+        try:
+            key = msg.key() if callable(msg.key) else msg.key
+            value = msg.value() if callable(msg.value) else msg.value
+
+            sample = {
+                "offset": msg.offset(),
+                "timestamp": datetime.fromtimestamp(msg.timestamp()[1] / 1000.0).isoformat(),  # Kafka timestamps are epoch millis
+            }
+
+            # Process key if present
+            if key:
+                processed_key = self._process_payload(key, topic, is_key=True)
+                if isinstance(processed_key, dict):
+                    sample.update({f"key_{k}": v for k, v in processed_key.items()})
+                else:
+                    sample["key"] = processed_key
+
+            # Process value
+            if value:
+                processed_value = self._process_payload(value, topic, is_key=False)
+                if isinstance(processed_value, dict):
+                    sample.update(processed_value)
+                else:
+                    sample["value"] = processed_value
+
+            return sample
+        except Exception as e:
+            logger.warning(f"Error processing message: {e}")
+            return None
+
+    def _process_payload(self, payload: Any, topic: str, is_key: bool = False) -> Any:
+        """Process a message payload (key or value) with Avro support."""
+        if payload is None:
             return None
 
-        if isinstance(data, bytes):
+        if isinstance(payload, bytes):
             try:
-                # Get schema metadata
+                # First try Avro decoding with schema registry
                 schema_metadata = self.schema_registry_client.get_schema_metadata(
                     topic, make_data_platform_urn(self.platform), False
                 )
@@ -470,7 +652,7 @@ def _process_message_part(
                             # Parse schema and create reader
                             schema = avro.schema.parse(schema_str)
                             # Decode Avro data - first 5 bytes are magic byte and schema ID
-                            decoder = avro.io.BinaryDecoder(io.BytesIO(data[5:]))
+                            decoder = avro.io.BinaryDecoder(io.BytesIO(payload[5:]))
                             reader = avro.io.DatumReader(schema)
                             decoded_value = reader.read(decoder)
 
@@ -481,11 +663,11 @@ def _process_message_part(
                                 return flatten_json(decoded_value)
                             return decoded_value
                         except Exception as e:
-                            logger.warning(f"Failed to decode Avro message: {e}")
+                            logger.debug(f"Failed to decode Avro message: {e}")
 
                 # Fallback to JSON decode if no schema or Avro decode fails
                 try:
-                    decoded = json.loads(data.decode("utf-8"))
+                    decoded = json.loads(payload.decode("utf-8"))
                     if isinstance(decoded, (dict, list)):
                         if isinstance(decoded, list):
                             decoded = {"item": decoded}
@@ -493,213 +675,47 @@ def _process_message_part(
                     return decoded
                 except Exception:
                     # If JSON fails, use base64 as last resort
-                    return base64.b64encode(data).decode("utf-8")
+                    return base64.b64encode(payload).decode("utf-8")
 
             except Exception as e:
-                logger.warning(f"Failed to process message part: {e}")
-                return base64.b64encode(data).decode("utf-8")
+                logger.debug(f"Failed to process message part: {e}")
+                return base64.b64encode(payload).decode("utf-8")
 
-        return data
+        return payload
 
-    def get_sample_messages(self, topic: str) -> List[Dict[str, Any]]:
-        """
-        Collects sample messages from a Kafka topic, handling both key and value fields.
- """ - samples: List[Dict[str, Any]] = [] - try: - self.consumer.subscribe([topic]) - - # Poll for messages until timeout or we get desired number of samples - end_time = datetime.now() + timedelta( - seconds=self.source_config.sample_timeout_seconds - ) - - while datetime.now() < end_time: - msg = self.consumer.poll(timeout=1.0) - - if msg is None: - continue - - if msg.error(): - logger.warning(f"Error while consuming from {topic}: {msg.error()}") - break - - try: - # Process both key and value - key = msg.key() if callable(msg.key) else msg.key - value = msg.value() if callable(msg.value) else msg.value - processed_key = self._process_message_part( - key, "key", topic, is_key=True - ) - processed_value = self._process_message_part( - value, "value", topic, is_key=False - ) - - msg_timestamp = msg.timestamp()[1] - timestamp_dt = datetime.fromtimestamp( - msg_timestamp / 1000.0 - if msg_timestamp > 1e10 - else msg_timestamp - ) - - sample = { - "offset": msg.offset(), - "timestamp": timestamp_dt.isoformat(), - } - - # Add key and value data with proper prefixing - if processed_key is not None: - if isinstance(processed_key, dict): - # Don't prefix with 'key.' - sample.update(processed_key) - else: - sample["key"] = processed_key - - if processed_value is not None: - if isinstance(processed_value, dict): - # Add value fields without prefix - sample.update(processed_value) - else: - sample["value"] = processed_value - - samples.append(sample) - - if len(samples) >= self.source_config.sample_size: - break - - except Exception as e: - logger.warning(f"Failed to decode message from {topic}: {e}") - - except Exception as e: - logger.warning(f"Failed to collect samples from {topic}: {e}") - finally: - self.consumer.unsubscribe() - - return samples - - def _process_sample_data( + def get_profiling_workunit( self, - samples: List[Dict[str, Any]], + dataset_urn: str, + topic: str, schema_metadata: Optional[SchemaMetadataClass] = None, - ) -> Dict[str, Any]: - """Process sample data to extract field information from both key and value schemas.""" - all_keys: Set[str] = set() - field_sample_map: Dict[str, List[str]] = {} - key_field_path: Optional[str] = None + ) -> Iterable[MetadataWorkUnit]: + """Generate profiling workunit for a topic.""" - # Initialize from schema if available - if schema_metadata is not None and isinstance( - schema_metadata.platformSchema, KafkaSchemaClass - ): - # Find the key field path from schema metadata fields - key_field = next( - ( - schema_field - for schema_field in (schema_metadata.fields or []) - if schema_field.fieldPath.endswith("[key=True]") - ), - None, + try: + samples = self._get_sample_data( + topic, + self.source_config.profiling.max_sample_time_seconds, + self.source_config.profiling.sample_size, ) - if key_field: - key_field_path = key_field.fieldPath - all_keys.add(key_field_path) - field_sample_map[key_field_path] = [] - - # Handle all schema fields (both key and value) - for schema_field in schema_metadata.fields or []: - field_path = schema_field.fieldPath - if field_path not in field_sample_map: - field_sample_map[field_path] = [] - all_keys.add(field_path) - - # Process samples - for sample in samples: - # Process each field in the sample - for field_name, value in sample.items(): - if field_name not in ["offset", "timestamp"]: - # For sample data, we need to map the simplified field names back to full paths - matching_schema_field = None - if schema_metadata and schema_metadata.fields: - clean_field = clean_field_path(field_name, 
-
-                        # Special handling for key field
-                        if field_name == "key" and key_field_path:
-                            matching_schema_field = next(
-                                schema_field
-                                for schema_field in schema_metadata.fields
-                                if schema_field.fieldPath == key_field_path
-                            )
-                        else:
-                            # Find matching schema field by comparing the end of the path
-                            for schema_field in schema_metadata.fields:
-                                if (
-                                    clean_field_path(
-                                        schema_field.fieldPath, preserve_types=False
-                                    )
-                                    == clean_field
-                                ):
-                                    matching_schema_field = schema_field
-                                    break
-
-                    # Use the full path from schema if found, otherwise use original field name
-                    field_path = (
-                        matching_schema_field.fieldPath
-                        if matching_schema_field
-                        else field_name
-                    )
-                    if field_path not in field_sample_map:
-                        field_sample_map[field_path] = []
-                        all_keys.add(field_path)
-                    field_sample_map[field_path].append(str(value))
+            if samples:
+                field_sample_map = self._process_sample_data(samples, schema_metadata)
+                profile = self._create_profile_class(field_sample_map, len(samples))
 
-        return {"all_keys": all_keys, "field_sample_map": field_sample_map}
+                yield MetadataChangeProposalWrapper(
+                    entityUrn=dataset_urn, aspect=profile
+                ).as_workunit()
 
-    def _create_profile_data(
-        self, all_keys: set, field_sample_map: Dict[str, List[str]], sample_count: int
-    ) -> DatasetProfileClass:
-        """Create profile data from processed samples."""
-        timestamp_millis = int(datetime.now().timestamp() * 1000)
-        return DatasetProfileClass(
-            timestampMillis=timestamp_millis,
-            columnCount=len(all_keys),
-            fieldProfiles=[
-                DatasetFieldProfileClass(
-                    fieldPath=field_name,
-                    sampleValues=random.sample(
-                        field_samples, min(3, len(field_samples))
-                    )
-                )
-                for field_name, field_samples in field_sample_map.items()
-            ],
-        )
-
-    def create_samples_wu(
-        self,
-        entity_urn: str,
-        topic: str,
-        schema_metadata: Optional[SchemaMetadataClass] = None,
-    ) -> Iterable[MetadataWorkUnit]:
-        """Create samples work unit incorporating both schema fields and sample values."""
-        samples = self.get_sample_messages(topic)
-        if samples:
-            processed_data = self._process_sample_data(samples, schema_metadata)
-            profile_data = self._create_profile_data(
-                processed_data["all_keys"],
-                processed_data["field_sample_map"],
-                len(samples),
-            )
-            yield MetadataChangeProposalWrapper(
-                entityUrn=entity_urn, aspect=profile_data
-            ).as_workunit()
+        except Exception as e:
+            logger.warning(f"Error generating profile for topic {topic}: {e}")
 
     def get_dataset_description(
         self,
         dataset_name: str,
-        dataset_snapshot: DatasetSnapshot,
+        dataset_snapshot: DatasetSnapshotClass,
         custom_props: Dict[str, str],
         schema_metadata: Optional[SchemaMetadataClass],
-    ) -> DatasetSnapshot:
+    ) -> DatasetSnapshotClass:
         AVRO = "AVRO"
         description: Optional[str] = None
         if (
@@ -794,9 +810,9 @@ def _extract_record(
             platform_instance=self.source_config.platform_instance,
             env=self.source_config.env,
         )
-        dataset_snapshot = DatasetSnapshot(
+        dataset_snapshot = DatasetSnapshotClass(
             urn=dataset_urn,
-            aspects=[Status(removed=False)],  # we append to this list later on
+            aspects=[StatusClass(removed=False)],
         )
 
         if schema_metadata is not None:
@@ -869,9 +885,9 @@ def _extract_record(
         )
 
         # 9. Emit sample values
-        if not is_subject and self.source_config.enable_sample_data:
-            yield from self.create_samples_wu(
-                entity_urn=dataset_urn,
+        if not is_subject and self.source_config.profiling.enabled:
+            yield from self.get_profiling_workunit(
+                dataset_urn=dataset_urn,
                 topic=topic,
                 schema_metadata=schema_metadata,
            )
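
Usage sketch (illustrative only, not part of the patch): with this change the flat enable_sample_data / sample_size / sample_timeout_seconds options are gone and sampling is driven by the nested profiling block. A minimal configuration, assuming the usual connection.bootstrap field and pydantic v1 parse_obj, would look roughly like:

# Hypothetical example; the profiling field names come from ProfilerConfig /
# GEProfilingBaseConfig above, while the bootstrap address and chosen values
# are assumptions for illustration.
from datahub.ingestion.source.kafka.kafka import KafkaSourceConfig

config = KafkaSourceConfig.parse_obj(
    {
        "connection": {"bootstrap": "localhost:9092"},
        "profiling": {
            "enabled": True,                # replaces enable_sample_data
            "sample_size": 50,              # messages sampled per topic
            "max_sample_time_seconds": 30,  # replaces sample_timeout_seconds
        },
    }
)
assert config.profiling.enabled and config.profiling.sample_size == 50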