Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

AIP-72: Support pulling multiple XCom values #45509

Merged
merged 2 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 32 additions & 24 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def render_templates(

def xcom_pull(
self,
task_ids: str | Iterable[str] | None = None, # TODO: Simplify to a single task_id? (breaking change)
task_ids: str | Iterable[str] | None = None,
dag_id: str | None = None,
key: str = "return_value", # TODO: Make this a constant (``XCOM_RETURN_KEY``)
include_prior_dates: bool = False, # TODO: Add support for this
Expand Down Expand Up @@ -213,40 +213,48 @@ def xcom_pull(
run_id = self.run_id

if task_ids is None:
# default to the current task if not provided
task_ids = self.task_id
elif not isinstance(task_ids, str) and isinstance(task_ids, Iterable):
# TODO: Handle multiple task_ids or remove support
raise NotImplementedError("Multiple task_ids are not supported yet")

elif isinstance(task_ids, str):
task_ids = [task_ids]
if map_indexes is None:
map_indexes = self.map_index
elif isinstance(map_indexes, Iterable):
# TODO: Handle multiple map_indexes or remove support
raise NotImplementedError("Multiple map_indexes are not supported yet")

log = structlog.get_logger(logger_name="task")
SUPERVISOR_COMMS.send_request(
log=log,
msg=GetXCom(
key=key,
dag_id=dag_id,
task_id=task_ids,
run_id=run_id,
map_index=map_indexes,
),
)

msg = SUPERVISOR_COMMS.get_message()
if TYPE_CHECKING:
assert isinstance(msg, XComResult)
xcoms = []
for t in task_ids:
kaxil marked this conversation as resolved.
Show resolved Hide resolved
SUPERVISOR_COMMS.send_request(
log=log,
msg=GetXCom(
key=key,
dag_id=dag_id,
task_id=t,
run_id=run_id,
map_index=map_indexes,
),
)

msg = SUPERVISOR_COMMS.get_message()
if not isinstance(msg, XComResult):
raise TypeError(f"Expected XComResult, received: {type(msg)} {msg}")

if msg.value is not None:
from airflow.models.xcom import XCom

if msg.value is not None:
from airflow.models.xcom import XCom
# TODO: Move XCom serialization & deserialization to Task SDK
# https://github.com/apache/airflow/issues/45231
xcom = XCom.deserialize_value(msg) # type: ignore[arg-type]
xcoms.append(xcom)
else:
xcoms.append(default)

# TODO: Move XCom serialization & deserialization to Task SDK
# https://github.com/apache/airflow/issues/45231
return XCom.deserialize_value(msg) # type: ignore[arg-type]
return default
if len(xcoms) == 1:
return xcoms[0]
return xcoms

def xcom_push(self, key: str, value: Any):
"""
Expand Down
38 changes: 24 additions & 14 deletions task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,14 +735,20 @@ def test_get_variable_from_context(

assert var_from_context == Variable(key="test_key", value=expected_value)

def test_xcom_pull(self, create_runtime_ti, mock_supervisor_comms, spy_agency):
@pytest.mark.parametrize(
"task_ids",
[
"push_task",
["push_task1", "push_task2"],
{"push_task1", "push_task2"},
],
)
def test_xcom_pull(self, create_runtime_ti, mock_supervisor_comms, spy_agency, task_ids):
"""Test that a task pulls the expected XCom value if it exists."""

task_id = "push_task"

class CustomOperator(BaseOperator):
def execute(self, context):
value = context["ti"].xcom_pull(task_ids=task_id, key="key")
value = context["ti"].xcom_pull(task_ids=task_ids, key="key")
print(f"Pulled XCom Value: {value}")

task = CustomOperator(task_id="pull_task")
Expand All @@ -755,16 +761,20 @@ def execute(self, context):

run(runtime_ti, log=mock.MagicMock())

mock_supervisor_comms.send_request.assert_any_call(
log=mock.ANY,
msg=GetXCom(
key="key",
dag_id="test_dag",
run_id="test_run",
task_id=task_id,
map_index=None,
),
)
if isinstance(task_ids, str):
task_ids = [task_ids]

for task_id in task_ids:
mock_supervisor_comms.send_request.assert_any_call(
log=mock.ANY,
msg=GetXCom(
key="key",
dag_id="test_dag",
run_id="test_run",
task_id=task_id,
map_index=None,
),
)


class TestXComAfterTaskExecution:
Expand Down