Skip to content

Commit

Permalink
fix(ingest): reduce asyncio in check_upgrade (datahub-project#11734)
Browse files Browse the repository at this point in the history
  • Loading branch information
hsheth2 authored Oct 30, 2024
1 parent 93f76de commit 799c452
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 58 deletions.
55 changes: 18 additions & 37 deletions metadata-ingestion/src/datahub/cli/ingest_cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import csv
import json
import logging
Expand All @@ -24,6 +23,7 @@
from datahub.ingestion.run.pipeline import Pipeline
from datahub.telemetry import telemetry
from datahub.upgrade import upgrade
from datahub.utilities.perf_timer import PerfTimer

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -126,7 +126,7 @@ def run(
) -> None:
"""Ingest metadata into DataHub."""

async def run_pipeline_to_completion(pipeline: Pipeline) -> int:
def run_pipeline_to_completion(pipeline: Pipeline) -> int:
logger.info("Starting metadata ingestion")
with click_spinner.spinner(disable=no_spinner or no_progress):
try:
Expand Down Expand Up @@ -166,44 +166,25 @@ async def run_pipeline_to_completion(pipeline: Pipeline) -> int:
# The default is "datahub" reporting. The extra flag will disable it.
report_to = None

async def run_ingestion_and_check_upgrade() -> int:
# TRICKY: We want to make sure that the Pipeline.create() call happens on the
# same thread as the rest of the ingestion. As such, we must initialize the
# pipeline inside the async function so that it happens on the same event
# loop, and hence the same thread.

# logger.debug(f"Using config: {pipeline_config}")
pipeline = Pipeline.create(
pipeline_config,
dry_run=dry_run,
preview_mode=preview,
preview_workunits=preview_workunits,
report_to=report_to,
no_progress=no_progress,
raw_config=raw_pipeline_config,
)
# logger.debug(f"Using config: {pipeline_config}")
pipeline = Pipeline.create(
pipeline_config,
dry_run=dry_run,
preview_mode=preview,
preview_workunits=preview_workunits,
report_to=report_to,
no_progress=no_progress,
raw_config=raw_pipeline_config,
)
with PerfTimer() as timer:
ret = run_pipeline_to_completion(pipeline)

version_stats_future = asyncio.ensure_future(
upgrade.retrieve_version_stats(pipeline.ctx.graph)
# The main ingestion has completed. If it was successful, potentially show an upgrade nudge message.
if ret == 0:
upgrade.check_upgrade_post(
main_method_runtime=timer.elapsed_seconds(), graph=pipeline.ctx.graph
)
ingestion_future = asyncio.ensure_future(run_pipeline_to_completion(pipeline))
ret = await ingestion_future

# The main ingestion has completed. If it was successful, potentially show an upgrade nudge message.
if ret == 0:
try:
# we check the other futures quickly on success
version_stats = await asyncio.wait_for(version_stats_future, 0.5)
upgrade.maybe_print_upgrade_message(version_stats=version_stats)
except Exception as e:
logger.debug(
f"timed out with {e} waiting for version stats to be computed... skipping ahead."
)

return ret

loop = asyncio.get_event_loop()
ret = loop.run_until_complete(run_ingestion_and_check_upgrade())
if ret:
sys.exit(ret)
# don't raise SystemExit if there's no error
Expand Down
70 changes: 49 additions & 21 deletions metadata-ingestion/src/datahub/upgrade/upgrade.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from datahub import __version__
from datahub.cli.config_utils import load_client_config
from datahub.ingestion.graph.client import DataHubGraph
from datahub.utilities.perf_timer import PerfTimer

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -113,7 +114,7 @@ async def get_server_config(gms_url: str, token: Optional[str]) -> dict:

async with aiohttp.ClientSession() as session:
config_endpoint = f"{gms_url}/config"
async with session.get(config_endpoint) as dh_response:
async with session.get(config_endpoint, headers=headers) as dh_response:
dh_response_json = await dh_response.json()
return dh_response_json

Expand Down Expand Up @@ -167,7 +168,28 @@ async def get_server_version_stats(
return (server_type, server_version, current_server_release_date)


async def retrieve_version_stats(
def retrieve_version_stats(
    timeout: float, graph: Optional[DataHubGraph] = None
) -> Optional[DataHubVersionStats]:
    """Fetch client/server version stats, giving up after `timeout` seconds.

    Args:
        timeout: Maximum number of seconds to wait for the stats fetch.
        graph: Optional graph client to query; the underlying fetch decides
            how to use it (see `_retrieve_version_stats`).

    Returns:
        The version stats, or None if the fetch timed out (or, per the
        underlying implementation, could not be computed).
    """

    async def _get_version_with_timeout() -> Optional[DataHubVersionStats]:
        # TODO: Once we're on Python 3.11+, replace with asyncio.timeout.
        try:
            return await asyncio.wait_for(
                _retrieve_version_stats(graph), timeout=timeout
            )
        except asyncio.TimeoutError:
            log.debug("Timed out while fetching version stats")
            return None

    # asyncio.run() creates and tears down a fresh event loop, replacing the
    # asyncio.get_event_loop()/run_until_complete pattern that is deprecated
    # (emits a DeprecationWarning when no loop is running) on Python 3.10+.
    # It also lets us return the result directly instead of smuggling it out
    # via a nonlocal variable.
    return asyncio.run(_get_version_with_timeout())


async def _retrieve_version_stats(
server: Optional[DataHubGraph] = None,
) -> Optional[DataHubVersionStats]:
try:
Expand Down Expand Up @@ -263,7 +285,7 @@ def is_client_server_compatible(client: VersionStats, server: VersionStats) -> i
return server.version.micro - client.version.micro


def maybe_print_upgrade_message( # noqa: C901
def _maybe_print_upgrade_message( # noqa: C901
version_stats: Optional[DataHubVersionStats],
) -> None: # noqa: C901
days_before_cli_stale = 7
Expand Down Expand Up @@ -378,28 +400,34 @@ def maybe_print_upgrade_message( # noqa: C901
pass


def clip(val: float, min_val: float, max_val: float) -> float:
    """Clamp `val` into the inclusive range [min_val, max_val]."""
    # First raise val up to the floor, then cap it at the ceiling.
    return min(max_val, max(val, min_val))


def check_upgrade_post(
    main_method_runtime: float,
    graph: Optional[DataHubGraph] = None,
) -> None:
    """Best-effort check for a newer CLI/server version after a command ran.

    Guarantees: this method will not throw, and will not block for more
    than 3 seconds.

    Args:
        main_method_runtime: How long the main command took, in seconds;
            used to scale the time budget for the version lookup.
        graph: Optional graph client to query for server-side version info.
    """
    # Spend ~10% of the main run's duration on the version check,
    # bounded to the range [0.7s, 3.0s].
    budget = clip(main_method_runtime / 10, 0.7, 3.0)

    try:
        _maybe_print_upgrade_message(
            version_stats=retrieve_version_stats(timeout=budget, graph=graph)
        )
    except Exception as e:
        # Deliberately swallow everything: an upgrade nudge must never
        # break the command that just succeeded.
        log.debug(f"Failed to check for upgrades due to {e}")


def check_upgrade(func: Callable[..., T]) -> Callable[..., T]:
    """Decorator that may print an upgrade nudge after `func` completes.

    The wrapped function runs synchronously; if it raises, the exception
    propagates untouched and no upgrade check happens. On normal return,
    a best-effort, time-bounded upgrade check runs (see check_upgrade_post).
    """

    # NOTE: renamed from `async_wrapper` — the asyncio plumbing was removed,
    # so the wrapper is now plain synchronous code and the old name was
    # misleading.
    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        with PerfTimer() as timer:
            ret = func(*args, **kwargs)

        # Only reached when func returned normally; the check scales its
        # time budget to how long the main method took.
        check_upgrade_post(main_method_runtime=timer.elapsed_seconds())

        return ret

    return wrapper

0 comments on commit 799c452

Please sign in to comment.