Skip to content

Commit

Permalink
code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
hsheth2 committed Oct 11, 2023
1 parent 3e5e68b commit 4d8ff1a
Showing 1 changed file with 14 additions and 16 deletions.
30 changes: 14 additions & 16 deletions metadata-ingestion/src/datahub/ingestion/source/dbt/dbt_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
AllowDenyPattern,
ConfigEnum,
ConfigModel,
ConfigurationError,
LineageConfig,
)
from datahub.configuration.source_common import DatasetSourceConfigMixin
Expand Down Expand Up @@ -480,7 +481,6 @@ def merge_schema_fields(self, schema_fields: List[SchemaField]) -> None:

if self.columns:
# If we already have columns, don't overwrite them.
# TODO maybe we should augment them instead?
return

self.columns = [
Expand Down Expand Up @@ -734,7 +734,10 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
if value is not None
}

inferred_schemas = self._infer_schemas_and_update_cll(all_nodes_map)
# We need to run this before filtering nodes, because the info generated
# for a filtered node may be used by an unfiltered node.
# NOTE: This method mutates the DBTNode objects directly.
self._infer_schemas_and_update_cll(all_nodes_map)

nodes = self._filter_nodes(all_nodes)
non_test_nodes = [
Expand Down Expand Up @@ -776,16 +779,16 @@ def _filter_nodes(self, all_nodes: List[DBTNode]) -> List[DBTNode]:

return nodes

def _infer_schemas_and_update_cll(
self, all_nodes_map: Dict[str, DBTNode]
) -> Dict[str, SchemaMetadata]:
def _infer_schemas_and_update_cll(self, all_nodes_map: Dict[str, DBTNode]) -> None:
if not self.config.infer_dbt_schemas:
return {}
if self.config.include_column_lineage:
raise ConfigurationError(
"`infer_dbt_schemas` must be enabled to use `include_column_lineage`"
)
return

graph = self.ctx.require_graph("Using dbt with infer_dbt_schemas")

# TODO: maybe write these to the nodes directly?
schemas: Dict[str, SchemaMetadata] = {}
schema_resolver = SchemaResolver(
platform=self.config.target_platform,
platform_instance=self.config.target_platform_instance,
Expand All @@ -794,8 +797,6 @@ def _infer_schemas_and_update_cll(

target_platform_urn_to_dbt_name: Dict[str, str] = {}

# TODO: only process non-test nodes here

# Iterate over the dbt nodes in topological order.
# This ensures that we process upstream nodes before downstream nodes.
for dbt_name in topological_sort(
Expand All @@ -809,9 +810,6 @@ def _infer_schemas_and_update_cll(
):
node = all_nodes_map[dbt_name]

dbt_node_urn = node.get_urn(
DBT_PLATFORM, self.config.env, self.config.platform_instance
)
if node.exists_in_target_platform:
target_node_urn = node.get_urn(
self.config.target_platform,
Expand All @@ -833,12 +831,14 @@ def _infer_schemas_and_update_cll(
# Run sql parser to infer the schema + generate column lineage.
sql_result = None
if node.compiled_code:
# TODO: Add CTE stops based on the upstreams list. The code as currently
# written will generate lineage to the upstreams of ephemeral nodes instead
# of the ephemeral node itself.
sql_result = sqlglot_lineage(
node.compiled_code, schema_resolver=schema_resolver
)

# Save the column lineage.
# TODO add cte stops based on the upstreams
if (
self.config.include_column_lineage
and sql_result
Expand Down Expand Up @@ -884,8 +884,6 @@ def _infer_schemas_and_update_cll(
},
)

return schemas

def create_platform_mces(
self,
dbt_nodes: List[DBTNode],
Expand Down

0 comments on commit 4d8ff1a

Please sign in to comment.