diff --git a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_common.py b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_common.py
index de14ed3cfbbb2d..232ba1347f90d8 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_common.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_common.py
@@ -15,6 +15,7 @@
     AllowDenyPattern,
     ConfigEnum,
     ConfigModel,
+    ConfigurationError,
     LineageConfig,
 )
 from datahub.configuration.source_common import DatasetSourceConfigMixin
@@ -480,7 +481,6 @@ def merge_schema_fields(self, schema_fields: List[SchemaField]) -> None:
 
         if self.columns:
             # If we already have columns, don't overwrite them.
-            # TODO maybe we should augment them instead?
             return
 
         self.columns = [
@@ -734,7 +734,10 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
             if value is not None
         }
 
-        inferred_schemas = self._infer_schemas_and_update_cll(all_nodes_map)
+        # We need to run this before filtering nodes, because the info generated
+        # for a filtered node may be used by an unfiltered node.
+        # NOTE: This method mutates the DBTNode objects directly.
+        self._infer_schemas_and_update_cll(all_nodes_map)
 
         nodes = self._filter_nodes(all_nodes)
         non_test_nodes = [
@@ -776,16 +779,16 @@ def _filter_nodes(self, all_nodes: List[DBTNode]) -> List[DBTNode]:
 
         return nodes
 
-    def _infer_schemas_and_update_cll(
-        self, all_nodes_map: Dict[str, DBTNode]
-    ) -> Dict[str, SchemaMetadata]:
+    def _infer_schemas_and_update_cll(self, all_nodes_map: Dict[str, DBTNode]) -> None:
         if not self.config.infer_dbt_schemas:
-            return {}
+            if self.config.include_column_lineage:
+                raise ConfigurationError(
+                    "`infer_dbt_schemas` must be enabled to use `include_column_lineage`"
+                )
+            return
 
         graph = self.ctx.require_graph("Using dbt with infer_dbt_schemas")
 
-        # TODO: maybe write these to the nodes directly?
-        schemas: Dict[str, SchemaMetadata] = {}
         schema_resolver = SchemaResolver(
             platform=self.config.target_platform,
             platform_instance=self.config.target_platform_instance,
@@ -794,8 +797,6 @@ def _infer_schemas_and_update_cll(
 
         target_platform_urn_to_dbt_name: Dict[str, str] = {}
 
-        # TODO: only process non-test nodes here
-
         # Iterate over the dbt nodes in topological order.
         # This ensures that we process upstream nodes before downstream nodes.
         for dbt_name in topological_sort(
@@ -809,9 +810,6 @@ def _infer_schemas_and_update_cll(
         ):
             node = all_nodes_map[dbt_name]
 
-            dbt_node_urn = node.get_urn(
-                DBT_PLATFORM, self.config.env, self.config.platform_instance
-            )
             if node.exists_in_target_platform:
                 target_node_urn = node.get_urn(
                     self.config.target_platform,
@@ -833,12 +831,14 @@ def _infer_schemas_and_update_cll(
             # Run sql parser to infer the schema + generate column lineage.
             sql_result = None
             if node.compiled_code:
+                # TODO: Add CTE stops based on the upstreams list. The code as currently
+                # written will generate lineage to the upstreams of ephemeral nodes instead
+                # of the ephemeral node itself.
                 sql_result = sqlglot_lineage(
                     node.compiled_code, schema_resolver=schema_resolver
                 )
 
             # Save the column lineage.
-            # TODO add cte stops based on the upstreams
             if (
                 self.config.include_column_lineage
                 and sql_result
@@ -884,8 +884,6 @@ def _infer_schemas_and_update_cll(
                 },
             )
 
-        return schemas
-
     def create_platform_mces(
         self,
         dbt_nodes: List[DBTNode],
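
Note (not part of the patch): a minimal sketch of why `_infer_schemas_and_update_cll` must run over `all_nodes_map` before `_filter_nodes`. The node names and the `schemas` registry below are made up; the real code registers inferred schemas with a shared SchemaResolver and mutates DBTNode objects in place, so a kept node's compiled SQL can still resolve columns that only exist on a filtered-out upstream node.

    from typing import Dict, List

    # Stand-in for the shared SchemaResolver: table name -> known columns.
    schemas: Dict[str, List[str]] = {"raw.users": ["id", "email"]}

    class Node:
        def __init__(self, name: str, upstreams: List[str]) -> None:
            self.name = name
            self.upstreams = upstreams
            self.columns: List[str] = []  # mutated in place, like DBTNode

    def infer_schema(node: Node) -> None:
        # Fake "parsing": pull columns from upstreams already registered.
        node.columns = [c for up in node.upstreams for c in schemas.get(up, [])]
        schemas[node.name] = node.columns  # register for downstream nodes

    staging = Node("staging.users", upstreams=["raw.users"])
    mart = Node("mart.users", upstreams=["staging.users"])

    # Topological order (upstream first), over ALL nodes -- including any
    # that a node filter would later drop from the emitted metadata.
    for node in (staging, mart):
        infer_schema(node)

    # Even if staging.users is filtered out afterwards, mart.users has
    # already resolved its columns through it.
    print(mart.columns)  # ['id', 'email']

Filtering first would leave `mart.users` with no way to resolve the columns of `staging.users`, which is exactly the ordering the new comment in `get_workunits_internal` documents.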
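
Note (not part of the patch): context for the new CTE-stops TODO. dbt compiles ephemeral models into inline CTEs (named with a `__dbt__cte__` prefix) inside each downstream model's compiled SQL, so a parser that expands the whole statement attributes lineage to the ephemeral model's upstreams. A hypothetical compiled example:

    # Hypothetical compiled SQL for a model downstream of an ephemeral
    # model stg_users. Parsing this whole statement yields lineage from
    # raw.users (the ephemeral model's upstream), not from stg_users.
    compiled_code = """
    with __dbt__cte__stg_users as (
        select id, email from raw.users
    )
    select id from __dbt__cte__stg_users
    """
    # A "CTE stop" would tell the parser to treat __dbt__cte__stg_users as
    # an opaque table (mapped back to the stg_users node) rather than
    # expanding it down to its sources.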