Skip to content

Commit

Permalink
feat(ingest/mongodb): Ingest databases as containers (#11178)
Browse files Browse the repository at this point in the history
Co-authored-by: david-leifker <[email protected]>
  • Loading branch information
asikowitz and david-leifker authored Aug 21, 2024
1 parent 4e35016 commit 30c4fa9
Show file tree
Hide file tree
Showing 3 changed files with 618 additions and 78 deletions.
180 changes: 102 additions & 78 deletions metadata-ingestion/src/datahub/ingestion/source/mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,13 @@
from datahub.emitter.mce_builder import (
make_data_platform_urn,
make_dataplatform_instance_urn,
make_dataset_urn_with_platform_instance,
)
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.mcp_builder import (
DatabaseKey,
add_dataset_to_container,
gen_containers,
)
from datahub.ingestion.api.common import PipelineContext
from datahub.ingestion.api.decorators import (
SourceCapability,
Expand All @@ -32,6 +36,7 @@
)
from datahub.ingestion.api.source import MetadataWorkUnitProcessor
from datahub.ingestion.api.workunit import MetadataWorkUnit
from datahub.ingestion.source.common.subtypes import DatasetContainerSubTypes
from datahub.ingestion.source.schema_inference.object import (
SchemaDescription,
construct_schema,
Expand Down Expand Up @@ -64,6 +69,7 @@
DataPlatformInstanceClass,
DatasetPropertiesClass,
)
from datahub.metadata.urns import DatasetUrn

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -263,6 +269,7 @@ class MongoDBSource(StatefulIngestionSourceBase):
config: MongoDBConfig
report: MongoDBSourceReport
mongo_client: MongoClient
platform: str = "mongodb"

def __init__(self, ctx: PipelineContext, config: MongoDBConfig):
super().__init__(config, ctx)
Expand All @@ -282,7 +289,9 @@ def __init__(self, ctx: PipelineContext, config: MongoDBConfig):
}

# See https://pymongo.readthedocs.io/en/stable/examples/datetimes.html#handling-out-of-range-datetimes
self.mongo_client = MongoClient(self.config.connect_uri, datetime_conversion="DATETIME_AUTO", **options) # type: ignore
self.mongo_client = MongoClient(
self.config.connect_uri, datetime_conversion="DATETIME_AUTO", **options
) # type: ignore

# This cheaply tests the connection. For details, see
# https://pymongo.readthedocs.io/en/stable/api/pymongo/mongo_client.html#pymongo.mongo_client.MongoClient
Expand Down Expand Up @@ -351,8 +360,6 @@ def get_field_type(
return SchemaFieldDataType(type=TypeClass())

def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
platform = "mongodb"

database_names: List[str] = self.mongo_client.list_database_names()

# traverse databases in sorted order so output is consistent
Expand All @@ -364,8 +371,19 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
continue

database = self.mongo_client[database_name]
collection_names: List[str] = database.list_collection_names()
database_key = DatabaseKey(
database=database_name,
platform=self.platform,
instance=self.config.platform_instance,
env=self.config.env,
)
yield from gen_containers(
container_key=database_key,
name=database_name,
sub_types=[DatasetContainerSubTypes.DATABASE],
)

collection_names: List[str] = database.list_collection_names()
# traverse collections in sorted order so output is consistent
for collection_name in sorted(collection_names):
dataset_name = f"{database_name}.{collection_name}"
Expand All @@ -374,9 +392,9 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
self.report.report_dropped(dataset_name)
continue

dataset_urn = make_dataset_urn_with_platform_instance(
platform=platform,
name=dataset_name,
dataset_urn = DatasetUrn.create_from_ids(
platform_id=self.platform,
table_name=dataset_name,
env=self.config.env,
platform_instance=self.config.platform_instance,
)
Expand All @@ -385,9 +403,9 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
data_platform_instance = None
if self.config.platform_instance:
data_platform_instance = DataPlatformInstanceClass(
platform=make_data_platform_urn(platform),
platform=make_data_platform_urn(self.platform),
instance=make_dataplatform_instance_urn(
platform, self.config.platform_instance
self.platform, self.config.platform_instance
),
)

Expand All @@ -397,83 +415,21 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
)

schema_metadata: Optional[SchemaMetadata] = None

if self.config.enableSchemaInference:
assert self.config.maxDocumentSize is not None
collection_schema = construct_schema_pymongo(
database[collection_name],
delimiter=".",
use_random_sampling=self.config.useRandomSampling,
max_document_size=self.config.maxDocumentSize,
should_add_document_size_filter=self.should_add_document_size_filter(),
sample_size=self.config.schemaSamplingSize,
)

# initialize the schema for the collection
canonical_schema: List[SchemaField] = []
max_schema_size = self.config.maxSchemaSize
collection_schema_size = len(collection_schema.values())
collection_fields: Union[
List[SchemaDescription], ValuesView[SchemaDescription]
] = collection_schema.values()
assert max_schema_size is not None
if collection_schema_size > max_schema_size:
# downsample the schema, using frequency as the sort key
self.report.report_warning(
title="Too many schema fields",
message=f"Downsampling the collection schema because it has too many schema fields. Configured threshold is {max_schema_size}",
context=f"Schema Size: {collection_schema_size}, Collection: {dataset_urn}",
)
# Add this information to the custom properties so user can know they are looking at downsampled schema
dataset_properties.customProperties[
"schema.downsampled"
] = "True"
dataset_properties.customProperties[
"schema.totalFields"
] = f"{collection_schema_size}"

logger.debug(
f"Size of collection fields = {len(collection_fields)}"
)
# append each schema field (sort so output is consistent)
for schema_field in sorted(
collection_fields,
key=lambda x: (
-x["count"],
x["delimited_name"],
), # Negate `count` for descending order, `delimited_name` stays the same for ascending
)[0:max_schema_size]:
field = SchemaField(
fieldPath=schema_field["delimited_name"],
nativeDataType=self.get_pymongo_type_string(
schema_field["type"], dataset_name
),
type=self.get_field_type(
schema_field["type"], dataset_name
),
description=None,
nullable=schema_field["nullable"],
recursive=False,
)
canonical_schema.append(field)

# create schema metadata object for collection
schema_metadata = SchemaMetadata(
schemaName=collection_name,
platform=f"urn:li:dataPlatform:{platform}",
version=0,
hash="",
platformSchema=SchemalessClass(),
fields=canonical_schema,
schema_metadata = self._infer_schema_metadata(
collection=database[collection_name],
dataset_urn=dataset_urn,
dataset_properties=dataset_properties,
)

# TODO: use list_indexes() or index_information() to get index information
# See https://pymongo.readthedocs.io/en/stable/api/pymongo/collection.html#pymongo.collection.Collection.list_indexes.

yield from add_dataset_to_container(database_key, dataset_urn.urn())
yield from [
mcp.as_workunit()
for mcp in MetadataChangeProposalWrapper.construct_many(
entityUrn=dataset_urn,
entityUrn=dataset_urn.urn(),
aspects=[
schema_metadata,
dataset_properties,
Expand All @@ -482,6 +438,74 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
)
]

def _infer_schema_metadata(
    self,
    collection: pymongo.collection.Collection,
    dataset_urn: DatasetUrn,
    dataset_properties: DatasetPropertiesClass,
) -> SchemaMetadata:
    """Infer a schema for ``collection`` and wrap it in a ``SchemaMetadata``.

    The collection is passed to ``construct_schema_pymongo`` using the
    sampling and document-size settings from ``self.config``. The resulting
    field descriptions are ordered by descending ``count`` and then by
    ascending ``delimited_name``, and at most ``self.config.maxSchemaSize``
    of them are kept.

    Args:
        collection: The collection whose schema is inferred; its ``name``
            becomes the ``schemaName`` of the result.
        dataset_urn: URN of the dataset. Used in the downsampling warning
            and its ``name`` is forwarded to the type-mapping helpers.
        dataset_properties: Mutated in place when the schema is truncated:
            ``customProperties`` receives ``schema.downsampled`` and
            ``schema.totalFields``.

    Returns:
        A ``SchemaMetadata`` holding the retained fields.
    """
    assert self.config.maxDocumentSize is not None
    inferred = construct_schema_pymongo(
        collection,
        delimiter=".",
        use_random_sampling=self.config.useRandomSampling,
        max_document_size=self.config.maxDocumentSize,
        should_add_document_size_filter=self.should_add_document_size_filter(),
        sample_size=self.config.schemaSamplingSize,
    )

    described_fields = inferred.values()
    total_fields = len(described_fields)
    field_limit = self.config.maxSchemaSize
    assert field_limit is not None

    if total_fields > field_limit:
        # More fields than allowed: warn, then flag the truncation on the
        # dataset properties so consumers know the schema is partial.
        self.report.report_warning(
            title="Too many schema fields",
            message=f"Downsampling the collection schema because it has too many schema fields. Configured threshold is {field_limit}",
            context=f"Schema Size: {total_fields}, Collection: {dataset_urn}",
        )
        custom_properties = dataset_properties.customProperties
        custom_properties["schema.downsampled"] = "True"
        custom_properties["schema.totalFields"] = f"{total_fields}"

    logger.debug(f"Size of collection fields = {len(described_fields)}")

    def _most_frequent_first(description):
        # Negated count gives descending frequency; the name breaks ties in
        # ascending order so the output is deterministic.
        return (-description["count"], description["delimited_name"])

    ranked = sorted(described_fields, key=_most_frequent_first)
    retained_fields: List[SchemaField] = [
        SchemaField(
            fieldPath=description["delimited_name"],
            nativeDataType=self.get_pymongo_type_string(
                description["type"], dataset_urn.name
            ),
            type=self.get_field_type(description["type"], dataset_urn.name),
            description=None,
            nullable=description["nullable"],
            recursive=False,
        )
        for description in ranked[0:field_limit]
    ]

    return SchemaMetadata(
        schemaName=collection.name,
        platform=f"urn:li:dataPlatform:{self.platform}",
        version=0,
        hash="",
        platformSchema=SchemalessClass(),
        fields=retained_fields,
    )

def is_server_version_gte_4_4(self) -> bool:
try:
server_version = self.mongo_client.server_info().get("versionArray")
Expand Down
Loading

0 comments on commit 30c4fa9

Please sign in to comment.