Skip to content

Commit

Permalink
cleanup and adding tests
Browse files Browse the repository at this point in the history
  • Loading branch information
amoghrajesh committed Jan 9, 2025
1 parent 23db568 commit d9f315a
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 5 deletions.
10 changes: 5 additions & 5 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,8 @@ def xcom_pull(
if task_ids is None:
# default to the current task if not provided
task_ids = self.task_id
elif not isinstance(task_ids, str) and isinstance(task_ids, Iterable):
# Retain the ordering as per legacy
task_ids = list(task_ids)
elif isinstance(task_ids, str):
task_ids = [task_ids]
if map_indexes is None:
map_indexes = self.map_index
elif isinstance(map_indexes, Iterable):
Expand All @@ -240,8 +239,9 @@ def xcom_pull(
)

msg = SUPERVISOR_COMMS.get_message()
if TYPE_CHECKING:
assert isinstance(msg, XComResult)
if not isinstance(msg, XComResult):
raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}")

if msg.value is not None:
from airflow.models.xcom import XCom

Expand Down
34 changes: 34 additions & 0 deletions task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,40 @@ def execute(self, context):
)


@pytest.mark.parametrize(
"task_ids",
[
"push_task",
["push_task1", "push_task2"],
{"push_task1", "push_task2"},
],
)
def test_xcom_pull_behavior(mocked_parse, make_ti_context, mock_supervisor_comms, spy_agency, task_ids):
"""Test that a task pulls the expected XCom value if it exists."""

class CustomOperator(BaseOperator):
def execute(self, context):
value = context["ti"].xcom_pull(task_ids=task_ids, key="key")
print(f"Pulled XCom Value: {value}")

task = CustomOperator(task_id="pull_task")
ti = TaskInstance(
id=uuid7(), task_id=task.task_id, dag_id="xcom_pull_dag", run_id="test_run", try_number=1
)

what = StartupDetails(ti=ti, file="", requests_fd=0, ti_context=make_ti_context())
runtime_ti = mocked_parse(what, ti.dag_id, task)

mock_supervisor_comms.xcom_pull.return_value = "xcom_value"

spy_agency.spy_on(runtime_ti.xcom_pull, call_original=False)

run(runtime_ti, log=mock.MagicMock())

spy_agency.assert_spy_called(runtime_ti.xcom_pull)
spy_agency.assert_spy_called_with(runtime_ti.xcom_pull, task_ids=task_ids, key="key")


@pytest.mark.parametrize(
["dag_id", "task_id", "fail_with_exception"],
[
Expand Down

0 comments on commit d9f315a

Please sign in to comment.