Add support for multiple datahub emitters
treff7es committed Jan 20, 2025
1 parent dd01c82 commit 72738da
Showing 5 changed files with 104 additions and 69 deletions.
File: datahub_airflow_plugin/_config.py
@@ -1,5 +1,5 @@
from enum import Enum
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, List, Optional

from airflow.configuration import conf
from pydantic.fields import Field
@@ -75,7 +75,7 @@ def make_emitter_hook(self) -> "DatahubGenericHook":
return DatahubGenericHook(self.datahub_conn_id)


def get_lineage_config() -> DatahubLineageConfig:
def get_lineage_configs() -> List[DatahubLineageConfig]:
"""Load the DataHub plugin config from airflow.cfg."""

enabled = conf.get("datahub", "enabled", fallback=True)
@@ -103,21 +103,44 @@ def get_lineage_config() -> DatahubLineageConfig:
dag_filter_pattern = AllowDenyPattern.parse_raw(
conf.get("datahub", "dag_filter_str", fallback='{"allow": [".*"]}')
)

return DatahubLineageConfig(
enabled=enabled,
datahub_conn_id=datahub_conn_id,
cluster=cluster,
capture_ownership_info=capture_ownership_info,
capture_ownership_as_group=capture_ownership_as_group,
capture_tags_info=capture_tags_info,
capture_executions=capture_executions,
materialize_iolets=materialize_iolets,
enable_extractors=enable_extractors,
log_level=log_level,
debug_emitter=debug_emitter,
disable_openlineage_plugin=disable_openlineage_plugin,
datajob_url_link=datajob_url_link,
render_templates=render_templates,
dag_filter_pattern=dag_filter_pattern,
)
if isinstance(datahub_conn_id, List):
connection_ids = []
for conn_id in datahub_conn_id:
config = DatahubLineageConfig(
enabled=enabled,
datahub_conn_id=conn_id,
cluster=cluster,
capture_ownership_info=capture_ownership_info,
capture_ownership_as_group=capture_ownership_as_group,
capture_tags_info=capture_tags_info,
capture_executions=capture_executions,
materialize_iolets=materialize_iolets,
enable_extractors=enable_extractors,
log_level=log_level,
debug_emitter=debug_emitter,
disable_openlineage_plugin=disable_openlineage_plugin,
datajob_url_link=datajob_url_link,
render_templates=render_templates,
dag_filter_pattern=dag_filter_pattern,
)
connection_ids.append(config)
return connection_ids
return [
DatahubLineageConfig(
enabled=enabled,
datahub_conn_id=datahub_conn_id,
cluster=cluster,
capture_ownership_info=capture_ownership_info,
capture_ownership_as_group=capture_ownership_as_group,
capture_tags_info=capture_tags_info,
capture_executions=capture_executions,
materialize_iolets=materialize_iolets,
enable_extractors=enable_extractors,
log_level=log_level,
debug_emitter=debug_emitter,
disable_openlineage_plugin=disable_openlineage_plugin,
datajob_url_link=datajob_url_link,
render_templates=render_templates,
dag_filter_pattern=dag_filter_pattern,
)
]
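
Airflow's conf.get returns a string, so the isinstance(datahub_conn_id, List) branch above only fires if the raw setting was already split into a list in the collapsed lines of this file. A minimal sketch of that step, assuming a comma-separated datahub_conn_id entry in airflow.cfg; parse_conn_ids is a hypothetical helper, not part of this commit:

from typing import List, Union

def parse_conn_ids(raw: str) -> Union[str, List[str]]:
    """Hypothetical: split 'a, b' into ['a', 'b'], pass single ids through."""
    ids = [part.strip() for part in raw.split(",") if part.strip()]
    if len(ids) > 1:
        return ids
    return ids[0] if ids else raw

assert parse_conn_ids("datahub_rest_default") == "datahub_rest_default"
assert parse_conn_ids("datahub_rest_default, datahub_rest_dr") == [
    "datahub_rest_default",
    "datahub_rest_dr",
]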
File: listener hookimpl shim (path not preserved in this view)
@@ -1,10 +1,10 @@
from datahub_airflow_plugin.datahub_listener import (
get_airflow_plugin_listener,
get_airflow_plugin_listeners,
hookimpl,
)

_listener = get_airflow_plugin_listener()
if _listener:
_listeners = get_airflow_plugin_listeners()
if _listeners:
# The run_in_thread decorator messes with pluggy's interface discovery,
# which causes the hooks to be called with no arguments and results in TypeErrors.
# This is only an issue with Pluggy <= 1.0.0.
@@ -13,22 +13,27 @@

@hookimpl
def on_task_instance_running(previous_state, task_instance, session):
assert _listener
_listener.on_task_instance_running(previous_state, task_instance, session)
assert _listeners
for listener in _listeners:
listener.on_task_instance_running(previous_state, task_instance, session)

@hookimpl
def on_task_instance_success(previous_state, task_instance, session):
assert _listener
_listener.on_task_instance_success(previous_state, task_instance, session)
assert _listeners
for listener in _listeners:
listener.on_task_instance_success(previous_state, task_instance, session)

@hookimpl
def on_task_instance_failed(previous_state, task_instance, session):
assert _listener
_listener.on_task_instance_failed(previous_state, task_instance, session)
assert _listeners
for listener in _listeners:
listener.on_task_instance_failed(previous_state, task_instance, session)

if hasattr(_listener, "on_dag_run_running"):
# We assume that all listeners have the same set of methods.
if hasattr(_listeners[0], "on_dag_run_running"):

@hookimpl
def on_dag_run_running(dag_run, msg):
assert _listener
_listener.on_dag_run_running(dag_run, msg)
assert _listeners
for listener in _listeners:
listener.on_dag_run_running(dag_run, msg)
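
Each hookimpl above now fans one Airflow event out to every listener instead of delegating to a single one. The dispatch pattern, distilled into a self-contained sketch (Listener is a stand-in for the plugin's DataHubListener):

class Listener:
    """Stand-in for DataHubListener; prints instead of emitting metadata."""

    def __init__(self, name: str) -> None:
        self.name = name

    def on_task_instance_running(self, previous_state, task_instance, session) -> None:
        print(f"{self.name}: {task_instance} running (was {previous_state})")

_listeners = [Listener("primary"), Listener("disaster-recovery")]

def on_task_instance_running(previous_state, task_instance, session):
    # One hook invocation becomes one call per registered listener.
    for listener in _listeners:
        listener.on_task_instance_running(previous_state, task_instance, session)

on_task_instance_running("queued", "my_dag.my_task", session=None)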
File: datahub_airflow_plugin/_extractors.py
@@ -63,9 +63,9 @@ def __init__(self):

self.task_to_extractor.extractors["AthenaOperator"] = AthenaOperatorExtractor

self.task_to_extractor.extractors["BigQueryInsertJobOperator"] = (
BigQueryInsertJobOperatorExtractor
)
self.task_to_extractor.extractors[
"BigQueryInsertJobOperator"
] = BigQueryInsertJobOperatorExtractor

self._graph: Optional["DataHubGraph"] = None

File: datahub_airflow_plugin/datahub_listener.py
@@ -41,7 +41,7 @@
get_task_inlets,
get_task_outlets,
)
from datahub_airflow_plugin._config import DatahubLineageConfig, get_lineage_config
from datahub_airflow_plugin._config import DatahubLineageConfig, get_lineage_configs
from datahub_airflow_plugin._datahub_ol_adapter import translate_ol_to_datahub_urn
from datahub_airflow_plugin._extractors import SQL_PARSING_RESULT_KEY, ExtractorManager
from datahub_airflow_plugin.client.airflow_generator import AirflowGenerator
@@ -82,33 +82,38 @@ def hookimpl(f: _F) -> _F: # type: ignore[misc] # noqa: F811
KILL_SWITCH_VARIABLE_NAME = "datahub_airflow_plugin_disable_listener"


def get_airflow_plugin_listener() -> Optional["DataHubListener"]:
def get_airflow_plugin_listeners() -> Optional[List["DataHubListener"]]:
# Using globals instead of functools.lru_cache to make testing easier.
global _airflow_listener_initialized
global _airflow_listener
global _airflow_listeners

if not _airflow_listener_initialized:
_airflow_listener_initialized = True

plugin_config = get_lineage_config()

if plugin_config.enabled:
_airflow_listener = DataHubListener(config=plugin_config)

telemetry.telemetry_instance.ping(
"airflow-plugin-init",
{
"airflow-version": airflow.__version__,
"datahub-airflow-plugin": "v2",
"datahub-airflow-plugin-dag-events": HAS_AIRFLOW_DAG_LISTENER_API,
"capture_executions": plugin_config.capture_executions,
"capture_tags": plugin_config.capture_tags_info,
"capture_ownership": plugin_config.capture_ownership_info,
"enable_extractors": plugin_config.enable_extractors,
"render_templates": plugin_config.render_templates,
"disable_openlineage_plugin": plugin_config.disable_openlineage_plugin,
},
)
plugin_configs = get_lineage_configs()
for plugin_config in plugin_configs:
if plugin_config.enabled:
telemetry_sent = False
conn_id = plugin_config.conn_id
Variable.get(conn_id)
_airflow_listeners.append(DataHubListener(config=plugin_config))

if not telemetry_sent:
telemetry.telemetry_instance.ping(
"airflow-plugin-init",
{
"airflow-version": airflow.__version__,
"datahub-airflow-plugin": "v2",
"datahub-airflow-plugin-dag-events": HAS_AIRFLOW_DAG_LISTENER_API,
"capture_executions": plugin_config.capture_executions,
"capture_tags": plugin_config.capture_tags_info,
"capture_ownership": plugin_config.capture_ownership_info,
"enable_extractors": plugin_config.enable_extractors,
"render_templates": plugin_config.render_templates,
"disable_openlineage_plugin": plugin_config.disable_openlineage_plugin,
},
)
telemetry_sent = True

if plugin_config.disable_openlineage_plugin:
# Deactivate the OpenLineagePlugin listener to avoid conflicts/errors.
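
The loop above builds one DataHubListener per enabled config and tries to ping telemetry only once. Note that telemetry_sent is initialized inside the for loop, so the flag resets on every iteration; a send-once guard is normally hoisted out of the loop, as in this sketch (Config and send_ping are stand-ins, not the committed code):

from dataclasses import dataclass

@dataclass
class Config:
    """Stand-in for DatahubLineageConfig."""
    conn_id: str
    enabled: bool = True

def send_ping(config: Config) -> None:
    """Stand-in for telemetry_instance.ping(...)."""
    print(f"airflow-plugin-init ping (flags taken from {config.conn_id})")

listeners = []
telemetry_sent = False  # hoisted: one flag for the whole loop
for config in [Config("primary"), Config("secondary")]:
    if not config.enabled:
        continue
    listeners.append(f"listener:{config.conn_id}")  # stand-in for DataHubListener(config=config)
    if not telemetry_sent:
        send_ping(config)  # fires once, for the first enabled config
        telemetry_sent = True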
@@ -286,9 +291,9 @@ def _extract_lineage(
if sql_parsing_result:
if error := sql_parsing_result.debug_info.error:
logger.info(f"SQL parsing error: {error}", exc_info=error)
datajob.properties["datahub_sql_parser_error"] = (
f"{type(error).__name__}: {error}"
)
datajob.properties[
"datahub_sql_parser_error"
] = f"{type(error).__name__}: {error}"
if not sql_parsing_result.debug_info.table_error:
input_urns.extend(sql_parsing_result.in_tables)
output_urns.extend(sql_parsing_result.out_tables)
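
The hunk above only reformats the line that records a SQL-parsing failure on the datajob, but the pattern is worth spelling out (properties and error below are illustrative values):

properties = {}
error = ValueError("unsupported dialect")
# In the plugin this is gated by `if error := sql_parsing_result.debug_info.error:`
if error is not None:
    properties["datahub_sql_parser_error"] = f"{type(error).__name__}: {error}"
assert properties["datahub_sql_parser_error"] == "ValueError: unsupported dialect"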
File: callback-wrapping plugin module (path not preserved in this view)
@@ -15,7 +15,7 @@
get_task_inlets,
get_task_outlets,
)
from datahub_airflow_plugin._config import get_lineage_config
from datahub_airflow_plugin._config import get_lineage_configs
from datahub_airflow_plugin.client.airflow_generator import AirflowGenerator
from datahub_airflow_plugin.entities import (
entities_to_datajob_urn_list,
@@ -44,9 +44,11 @@ def get_task_inlets_advanced(task: BaseOperator, context: Any) -> Iterable[Any]:

if task_inlets and isinstance(task_inlets, list):
inlets = []
task_ids = {o for o in task_inlets if isinstance(o, str)}.union(
op.task_id for op in task_inlets if isinstance(op, BaseOperator)
).intersection(task.get_flat_relative_ids(upstream=True))
task_ids = (
{o for o in task_inlets if isinstance(o, str)}
.union(op.task_id for op in task_inlets if isinstance(op, BaseOperator))
.intersection(task.get_flat_relative_ids(upstream=True))
)

from airflow.lineage import AUTO
from cattr import structure
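
The reshaped chain above collects inlet task ids, whether given as strings or as operators, and keeps only those that are genuinely upstream of the current task. The same set algebra on toy values (Op stands in for BaseOperator, and upstream_ids for task.get_flat_relative_ids(upstream=True)):

class Op:
    """Stand-in for BaseOperator."""
    def __init__(self, task_id: str) -> None:
        self.task_id = task_id

task_inlets = ["extract", Op("transform"), 42]  # non-id entries are ignored
upstream_ids = {"extract", "transform", "unrelated"}

task_ids = (
    {o for o in task_inlets if isinstance(o, str)}
    .union(op.task_id for op in task_inlets if isinstance(op, Op))
    .intersection(upstream_ids)
)
assert task_ids == {"extract", "transform"}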
@@ -217,7 +219,7 @@ def datahub_pre_execution(context):

def _wrap_pre_execution(pre_execution):
def custom_pre_execution(context):
config = get_lineage_config()
config = get_lineage_configs()
if config.enabled:
context["_datahub_config"] = config
datahub_pre_execution(context)
@@ -231,7 +233,7 @@ def custom_pre_execution(context):

def _wrap_on_failure_callback(on_failure_callback):
def custom_on_failure_callback(context):
config = get_lineage_config()
config = get_lineage_configs()
if config.enabled:
context["_datahub_config"] = config
try:
@@ -251,7 +253,7 @@ def custom_on_failure_callback(context):

def _wrap_on_success_callback(on_success_callback):
def custom_on_success_callback(context):
config = get_lineage_config()
config = get_lineage_configs()
if config.enabled:
context["_datahub_config"] = config
try:
@@ -271,7 +273,7 @@ def custom_on_success_callback(context):

def _wrap_on_retry_callback(on_retry_callback):
def custom_on_retry_callback(context):
config = get_lineage_config()
config = get_lineage_configs()
if config.enabled:
context["_datahub_config"] = config
try:
@@ -363,7 +365,7 @@ def _patch_datahub_policy():

_patch_policy(settings)

plugin_config = get_lineage_config()
plugin_config = get_lineage_configs()
telemetry.telemetry_instance.ping(
"airflow-plugin-init",
{
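
Note that the _wrap_* helpers and _patch_datahub_policy above now call get_lineage_configs(), which returns a list, yet still read config.enabled on the result. A sketch of one list-aware shape for such a wrapper, assuming lineage should fire when any config is enabled; every name below is a stand-in, not the committed behavior:

def _wrap_callback_sketch(get_lineage_configs, datahub_callback, user_callback=None):
    """Hypothetical list-aware variant of the _wrap_* helpers above."""

    def wrapped(context):
        configs = get_lineage_configs()
        enabled = [cfg for cfg in configs if cfg.enabled]
        if enabled:
            context["_datahub_config"] = enabled
            for cfg in enabled:
                datahub_callback(context, cfg)  # emit via each configured connection
        if user_callback:
            user_callback(context)

    return wrapped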
