feat(ingest/iceberg): Iceberg performance improvement (multi-threading)
skrydal authored Nov 18, 2024
1 parent 435792c commit 2527f54
Showing 7 changed files with 821 additions and 152 deletions.
178 changes: 130 additions & 48 deletions metadata-ingestion/src/datahub/ingestion/source/iceberg/iceberg.py
@@ -1,10 +1,15 @@
import json
import logging
import threading
import uuid
from typing import Any, Dict, Iterable, List, Optional

from pyiceberg.catalog import Catalog
from pyiceberg.exceptions import NoSuchIcebergTableError
from pyiceberg.exceptions import (
NoSuchIcebergTableError,
NoSuchNamespaceError,
NoSuchPropertyException,
)
from pyiceberg.schema import Schema, SchemaVisitorPerPrimitiveType, visit
from pyiceberg.table import Table
from pyiceberg.typedef import Identifier
@@ -75,6 +80,8 @@
OwnershipClass,
OwnershipTypeClass,
)
from datahub.utilities.perf_timer import PerfTimer
from datahub.utilities.threaded_iterator_executor import ThreadedIteratorExecutor

LOGGER = logging.getLogger(__name__)
logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(
@@ -130,74 +137,149 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]:
]

def _get_datasets(self, catalog: Catalog) -> Iterable[Identifier]:
for namespace in catalog.list_namespaces():
yield from catalog.list_tables(namespace)
namespaces = catalog.list_namespaces()
LOGGER.debug(
f"Retrieved {len(namespaces)} namespaces, first 10: {namespaces[:10]}"
)
self.report.report_no_listed_namespaces(len(namespaces))
tables_count = 0
for namespace in namespaces:
try:
tables = catalog.list_tables(namespace)
tables_count += len(tables)
LOGGER.debug(
f"Retrieved {len(tables)} tables for namespace: {namespace}, in total retrieved {tables_count}, first 10: {tables[:10]}"
)
self.report.report_listed_tables_for_namespace(
".".join(namespace), len(tables)
)
yield from tables
except NoSuchNamespaceError:
self.report.report_warning(
"no-such-namespace",
f"Couldn't list tables for namespace {namespace} due to NoSuchNamespaceError exception",
)
LOGGER.warning(
f"NoSuchNamespaceError exception while trying to get list of tables from namespace {namespace}, skipping it",
)
except Exception as e:
self.report.report_failure(
"listing-tables-exception",
f"Couldn't list tables for namespace {namespace} due to {e}",
)
LOGGER.exception(
f"Unexpected exception while trying to get list of tables for namespace {namespace}, skipping it"
)

def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
try:
catalog = self.config.get_catalog()
except Exception as e:
LOGGER.error("Failed to get catalog", exc_info=True)
self.report.report_failure("get-catalog", f"Failed to get catalog: {e}")
return
thread_local = threading.local()

for dataset_path in self._get_datasets(catalog):
def _process_dataset(dataset_path: Identifier) -> Iterable[MetadataWorkUnit]:
LOGGER.debug(f"Processing dataset for path {dataset_path}")
dataset_name = ".".join(dataset_path)
if not self.config.table_pattern.allowed(dataset_name):
# Dataset name is rejected by pattern, report as dropped.
self.report.report_dropped(dataset_name)
continue

return
try:
# Try to load an Iceberg table. Might not contain one, this will be caught by NoSuchIcebergTableError.
table = catalog.load_table(dataset_path)
if not hasattr(thread_local, "local_catalog"):
LOGGER.debug(
f"Didn't find local_catalog in thread_local ({thread_local}), initializing new catalog"
)
thread_local.local_catalog = self.config.get_catalog()

with PerfTimer() as timer:
table = thread_local.local_catalog.load_table(dataset_path)
time_taken = timer.elapsed_seconds()
self.report.report_table_load_time(time_taken)
LOGGER.debug(
f"Loaded table: {table.identifier}, time taken: {time_taken}"
)
yield from self._create_iceberg_workunit(dataset_name, table)
except NoSuchPropertyException as e:
self.report.report_warning(
"table-property-missing",
f"Failed to create workunit for {dataset_name}. {e}",
)
LOGGER.warning(
f"NoSuchPropertyException while processing table {dataset_path}, skipping it.",
)
except NoSuchIcebergTableError as e:
self.report.report_warning(
"no-iceberg-table",
f"Failed to create workunit for {dataset_name}. {e}",
)
LOGGER.warning(
f"NoSuchIcebergTableError while processing table {dataset_path}, skipping it.",
)
except Exception as e:
self.report.report_failure("general", f"Failed to create workunit: {e}")
LOGGER.exception(
f"Exception while processing table {dataset_path}, skipping it.",
)

try:
catalog = self.config.get_catalog()
except Exception as e:
self.report.report_failure("get-catalog", f"Failed to get catalog: {e}")
return

for wu in ThreadedIteratorExecutor.process(
worker_func=_process_dataset,
args_list=[(dataset_path,) for dataset_path in self._get_datasets(catalog)],
max_workers=self.config.processing_threads,
):
yield wu

def _create_iceberg_workunit(
self, dataset_name: str, table: Table
) -> Iterable[MetadataWorkUnit]:
self.report.report_table_scanned(dataset_name)
dataset_urn: str = make_dataset_urn_with_platform_instance(
self.platform,
dataset_name,
self.config.platform_instance,
self.config.env,
)
dataset_snapshot = DatasetSnapshot(
urn=dataset_urn,
aspects=[Status(removed=False)],
)

# Dataset properties aspect.
custom_properties = table.metadata.properties.copy()
custom_properties["location"] = table.metadata.location
custom_properties["format-version"] = str(table.metadata.format_version)
custom_properties["partition-spec"] = str(self._get_partition_aspect(table))
if table.current_snapshot():
custom_properties["snapshot-id"] = str(table.current_snapshot().snapshot_id)
custom_properties["manifest-list"] = table.current_snapshot().manifest_list
dataset_properties = DatasetPropertiesClass(
name=table.name()[-1],
tags=[],
description=table.metadata.properties.get("comment", None),
customProperties=custom_properties,
)
dataset_snapshot.aspects.append(dataset_properties)
with PerfTimer() as timer:
self.report.report_table_scanned(dataset_name)
LOGGER.debug(f"Processing table {dataset_name}")
dataset_urn: str = make_dataset_urn_with_platform_instance(
self.platform,
dataset_name,
self.config.platform_instance,
self.config.env,
)
dataset_snapshot = DatasetSnapshot(
urn=dataset_urn,
aspects=[Status(removed=False)],
)

# Dataset ownership aspect.
dataset_ownership = self._get_ownership_aspect(table)
if dataset_ownership:
dataset_snapshot.aspects.append(dataset_ownership)
# Dataset properties aspect.
custom_properties = table.metadata.properties.copy()
custom_properties["location"] = table.metadata.location
custom_properties["format-version"] = str(table.metadata.format_version)
custom_properties["partition-spec"] = str(self._get_partition_aspect(table))
if table.current_snapshot():
custom_properties["snapshot-id"] = str(
table.current_snapshot().snapshot_id
)
custom_properties[
"manifest-list"
] = table.current_snapshot().manifest_list
dataset_properties = DatasetPropertiesClass(
name=table.name()[-1],
tags=[],
description=table.metadata.properties.get("comment", None),
customProperties=custom_properties,
)
dataset_snapshot.aspects.append(dataset_properties)
# Dataset ownership aspect.
dataset_ownership = self._get_ownership_aspect(table)
if dataset_ownership:
LOGGER.debug(
f"Adding ownership: {dataset_ownership} to the dataset {dataset_name}"
)
dataset_snapshot.aspects.append(dataset_ownership)

schema_metadata = self._create_schema_metadata(dataset_name, table)
dataset_snapshot.aspects.append(schema_metadata)
schema_metadata = self._create_schema_metadata(dataset_name, table)
dataset_snapshot.aspects.append(schema_metadata)

mce = MetadataChangeEvent(proposedSnapshot=dataset_snapshot)
mce = MetadataChangeEvent(proposedSnapshot=dataset_snapshot)
self.report.report_table_processing_time(timer.elapsed_seconds())
yield MetadataWorkUnit(id=dataset_name, mce=mce)

dpi_aspect = self._get_dataplatform_instance_aspect(dataset_urn=dataset_urn)
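The heart of the change is the fan-out in get_workunits_internal: every dataset path becomes a (dataset_path,) argument tuple, _process_dataset runs as a generator inside a worker thread and keeps its own catalog client in threading.local() so a pyiceberg connection is never shared across threads, and ThreadedIteratorExecutor.process merges the resulting work units back into a single iterator, with max_workers driven by the new processing_threads setting. The executor's internals are not part of this commit; below is a minimal, standard-library sketch of that pattern, where threaded_iterator and worker are hypothetical stand-ins rather than the real ThreadedIteratorExecutor API.

import threading
from concurrent.futures import ThreadPoolExecutor
from queue import Empty, Queue
from typing import Any, Callable, Iterable, Iterator, Tuple


def threaded_iterator(
    worker_func: Callable[..., Iterable[Any]],
    args_list: Iterable[Tuple[Any, ...]],
    max_workers: int,
) -> Iterator[Any]:
    """Run each worker generator on a thread pool and yield results as they arrive."""
    out: "Queue[Any]" = Queue()

    def drain(args: Tuple[Any, ...]) -> None:
        # Consume the generator inside the worker thread, pushing every item onto the queue.
        for item in worker_func(*args):
            out.put(item)

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(drain, args) for args in args_list]
        while not all(f.done() for f in futures) or not out.empty():
            try:
                yield out.get(timeout=0.1)
            except Empty:
                continue
        for f in futures:
            f.result()  # surface any exception raised in a worker


thread_local = threading.local()


def worker(dataset_path: str) -> Iterable[str]:
    # In the commit this role is played by _process_dataset: each thread lazily creates and
    # caches its own catalog client in threading.local() (here, just the thread name).
    if not hasattr(thread_local, "name"):
        thread_local.name = threading.current_thread().name
    yield f"{thread_local.name} processed {dataset_path}"


if __name__ == "__main__":
    paths = [f"db.table_{i}" for i in range(8)]
    for result in threaded_iterator(worker, [(p,) for p in paths], max_workers=3):
        print(result)

Because each worker is itself a generator, work units stream back to the caller as soon as they are produced instead of waiting for a whole namespace to finish. The second changed file, shown below, wires in the processing_threads setting and the timing counters that the report methods above populate.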
@@ -2,6 +2,7 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from humanfriendly import format_timespan
from pydantic import Field, validator
from pyiceberg.catalog import Catalog, load_catalog

@@ -18,6 +19,7 @@
OperationConfig,
is_profiling_enabled,
)
from datahub.utilities.stats_collections import TopKDict, int_top_k_dict

logger = logging.getLogger(__name__)

@@ -75,6 +77,9 @@ class IcebergSourceConfig(StatefulIngestionConfigBase, DatasetSourceConfigMixin)
description="Iceberg table property to look for a `CorpGroup` owner. Can only hold a single group value. If property has no value, no owner information will be emitted.",
)
profiling: IcebergProfilingConfig = IcebergProfilingConfig()
processing_threads: int = Field(
default=1, description="How many threads will be processing tables"
)

@validator("catalog", pre=True, always=True)
def handle_deprecated_catalog_format(cls, value):
@@ -131,17 +136,72 @@ def get_catalog(self) -> Catalog:

# Retrieve the dict associated with the one catalog entry
catalog_name, catalog_config = next(iter(self.catalog.items()))
logger.debug(
"Initializing the catalog %s with config: %s", catalog_name, catalog_config
)
return load_catalog(name=catalog_name, **catalog_config)


class TimingClass:
times: List[int]

def __init__(self):
self.times = []

def add_timing(self, t):
self.times.append(t)

def __str__(self):
if len(self.times) == 0:
return "no timings reported"
self.times.sort()
total = sum(self.times)
avg = total / len(self.times)
return str(
{
"average_time": format_timespan(avg, detailed=True, max_units=3),
"min_time": format_timespan(self.times[0], detailed=True, max_units=3),
"max_time": format_timespan(self.times[-1], detailed=True, max_units=3),
# total_time does not provide correct information in case we run in more than 1 thread
"total_time": format_timespan(total, detailed=True, max_units=3),
}
)


@dataclass
class IcebergSourceReport(StaleEntityRemovalSourceReport):
tables_scanned: int = 0
entities_profiled: int = 0
filtered: List[str] = field(default_factory=list)
load_table_timings: TimingClass = field(default_factory=TimingClass)
processing_table_timings: TimingClass = field(default_factory=TimingClass)
profiling_table_timings: TimingClass = field(default_factory=TimingClass)
listed_namespaces: int = 0
total_listed_tables: int = 0
tables_listed_per_namespace: TopKDict[str, int] = field(
default_factory=int_top_k_dict
)

def report_listed_tables_for_namespace(
self, namespace: str, no_tables: int
) -> None:
self.tables_listed_per_namespace[namespace] = no_tables
self.total_listed_tables += no_tables

def report_no_listed_namespaces(self, amount: int) -> None:
self.listed_namespaces = amount

def report_table_scanned(self, name: str) -> None:
self.tables_scanned += 1

def report_dropped(self, ent_name: str) -> None:
self.filtered.append(ent_name)

def report_table_load_time(self, t: float) -> None:
self.load_table_timings.add_timing(t)

def report_table_processing_time(self, t: float) -> None:
self.processing_table_timings.add_timing(t)

def report_table_profiling_time(self, t: float) -> None:
self.profiling_table_timings.add_timing(t)
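
The three TimingClass fields aggregate the per-table measurements reported through report_table_load_time, report_table_processing_time and report_table_profiling_time, and render a summary when the report is printed. A small usage sketch, assuming the TimingClass defined above is in scope; the printed values are illustrative only:

timings = TimingClass()
for seconds in (0.8, 1.2, 2.5):  # e.g. per-table load_table durations captured with PerfTimer
    timings.add_timing(seconds)
print(timings)
# Prints a dict-style summary with average_time, min_time, max_time and total_time,
# each formatted by humanfriendly.format_timespan.

As the in-code comment notes, total_time is summed across threads, so with processing_threads > 1 it can exceed wall-clock time.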
(The remaining changed files are not shown here.)
