Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix short circuit in mapped tasks #44925

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,8 @@ def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) ->
op.is_setup = is_setup
op.is_teardown = is_teardown
op.on_failure_fail_dagrun = on_failure_fail_dagrun
op.downstream_task_ids = self.downstream_task_ids
op.upstream_task_ids = self.upstream_task_ids
Comment on lines +815 to +816
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@uranusjr following up #43883 (comment) - it seems that these attributes were missing here, so that's why PythonOperator printed misleading logs

return op

# After a mapped operator is serialized, there's no real way to actually
Expand Down
7 changes: 5 additions & 2 deletions airflow/models/skipmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,11 @@ def skip(
raise ValueError("dag_run is required")

task_ids_list = [d.task_id for d in task_list]
SkipMixin._set_state_to_skipped(dag_run, task_ids_list, session)
session.commit()

# The following could be applied only for non-mapped tasks
if map_index == -1:
SkipMixin._set_state_to_skipped(dag_run, task_ids_list, session)
session.commit()
Comment on lines +125 to +128
Copy link
Contributor Author

@shahar1 shahar1 Dec 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This block could theoretically be removed, and then setting the skipped states will be done exclusively by NotPreviouslySkippedDep for both mapped and non-mapped.
Not sure regarding effects on performance, so I left it as-is for now (if we decide to remove it - some other tests will have to be adjusted).


if task_id is not None:
from airflow.models.xcom import XCom
Expand Down
79 changes: 40 additions & 39 deletions airflow/ti_deps/deps/not_previously_skipped_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from airflow.models.taskinstance import PAST_DEPENDS_MET
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.utils.db import LazySelectSequence


class NotPreviouslySkippedDep(BaseTIDep):
Expand All @@ -38,7 +39,6 @@ def _get_dep_statuses(self, ti, session, dep_context):
XCOM_SKIPMIXIN_FOLLOWED,
XCOM_SKIPMIXIN_KEY,
XCOM_SKIPMIXIN_SKIPPED,
SkipMixin,
)
from airflow.utils.state import TaskInstanceState

Expand All @@ -49,46 +49,47 @@ def _get_dep_statuses(self, ti, session, dep_context):
finished_task_ids = {t.task_id for t in finished_tis}

for parent in upstream:
if isinstance(parent, SkipMixin):
Copy link
Contributor Author

@shahar1 shahar1 Dec 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Initially I though of making it:
if isinstance(parent, (SkipMixin, MappedOpeartor))

But then I saw that tests pass without this if - so I decided going with Occam's razor and simply remove it.

if parent.task_id not in finished_task_ids:
# This can happen if the parent task has not yet run.
continue
if parent.task_id not in finished_task_ids:
# This can happen if the parent task has not yet run.
continue

prev_result = ti.xcom_pull(task_ids=parent.task_id, key=XCOM_SKIPMIXIN_KEY, session=session)
prev_result = ti.xcom_pull(
task_ids=parent.task_id, key=XCOM_SKIPMIXIN_KEY, session=session, map_indexes=ti.map_index
)

if prev_result is None:
# This can happen if the parent task has not yet run.
continue
if isinstance(prev_result, LazySelectSequence):
prev_result = next(iter(prev_result))

should_skip = False
if (
XCOM_SKIPMIXIN_FOLLOWED in prev_result
and ti.task_id not in prev_result[XCOM_SKIPMIXIN_FOLLOWED]
):
# Skip any tasks that are not in "followed"
should_skip = True
elif (
XCOM_SKIPMIXIN_SKIPPED in prev_result
and ti.task_id in prev_result[XCOM_SKIPMIXIN_SKIPPED]
):
# Skip any tasks that are in "skipped"
should_skip = True
if prev_result is None:
# This can happen if the parent task has not yet run.
continue

if should_skip:
# If the parent SkipMixin has run, and the XCom result stored indicates this
# ti should be skipped, set ti.state to SKIPPED and fail the rule so that the
# ti does not execute.
if dep_context.wait_for_past_depends_before_skipping:
past_depends_met = ti.xcom_pull(
task_ids=ti.task_id, key=PAST_DEPENDS_MET, session=session, default=False
)
if not past_depends_met:
yield self._failing_status(
reason=("Task should be skipped but the past depends are not met")
)
return
ti.set_state(TaskInstanceState.SKIPPED, session)
yield self._failing_status(
reason=f"Skipping because of previous XCom result from parent task {parent.task_id}"
should_skip = False
if (
XCOM_SKIPMIXIN_FOLLOWED in prev_result
and ti.task_id not in prev_result[XCOM_SKIPMIXIN_FOLLOWED]
):
# Skip any tasks that are not in "followed"
should_skip = True
elif XCOM_SKIPMIXIN_SKIPPED in prev_result and ti.task_id in prev_result[XCOM_SKIPMIXIN_SKIPPED]:
# Skip any tasks that are in "skipped"
should_skip = True

if should_skip:
# If the parent SkipMixin has run, and the XCom result stored indicates this
# ti should be skipped, set ti.state to SKIPPED and fail the rule so that the
# ti does not execute.
if dep_context.wait_for_past_depends_before_skipping:
past_depends_met = ti.xcom_pull(
task_ids=ti.task_id, key=PAST_DEPENDS_MET, session=session, default=False
)
return
if not past_depends_met:
yield self._failing_status(
reason="Task should be skipped but the past depends are not met"
)
return
ti.set_state(TaskInstanceState.SKIPPED, session)
yield self._failing_status(
reason=f"Skipping because of previous XCom result from parent task {parent.task_id}"
)
return
1 change: 1 addition & 0 deletions newsfragments/44925.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix short circuit operator in mapped tasks
57 changes: 56 additions & 1 deletion tests/models/test_mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from datetime import timedelta
from typing import TYPE_CHECKING
from unittest import mock
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import pendulum
import pytest
Expand Down Expand Up @@ -1763,3 +1763,58 @@ def group(n: int) -> None:
"group.last": {0: "success", 1: "skipped", 2: "success"},
}
assert states == expected


class TestMappedOperator:
@pytest.fixture
def mock_operator_class(self):
return MagicMock(spec=type(BaseOperator))

@pytest.fixture
@patch("airflow.serialization.serialized_objects.SerializedBaseOperator")
def mapped_operator(self, _, mock_operator_class):
return MappedOperator(
operator_class=mock_operator_class,
expand_input=MagicMock(),
partial_kwargs={"task_id": "test_task"},
task_id="test_task",
params={},
deps=frozenset(),
operator_extra_links=[],
template_ext=[],
template_fields=[],
template_fields_renderers={},
ui_color="",
ui_fgcolor="",
start_trigger_args=None,
start_from_trigger=False,
dag=None,
task_group=None,
start_date=None,
end_date=None,
is_empty=False,
task_module=MagicMock(),
task_type="taske_type",
operator_name="operator_name",
disallow_kwargs_override=False,
expand_input_attr="expand_input",
)

def test_unmap_with_resolved_kwargs(self, mapped_operator, mock_operator_class):
mapped_operator.upstream_task_ids = ["a"]
mapped_operator.downstream_task_ids = ["b"]
resolve = {"param1": "value1"}
result = mapped_operator.unmap(resolve)
assert result == mock_operator_class.return_value
assert result.task_id == "test_task"
assert result.is_setup is False
assert result.is_teardown is False
assert result.on_failure_fail_dagrun is False
assert result.upstream_task_ids == ["a"]
assert result.downstream_task_ids == ["b"]

def test_unmap_runtime_error(self, mapped_operator):
mapped_operator.upstream_task_ids = ["a"]
mapped_operator.downstream_task_ids = ["b"]
with pytest.raises(RuntimeError):
mapped_operator.unmap(None)
39 changes: 33 additions & 6 deletions tests/models/test_skipmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
from __future__ import annotations

import datetime
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch

import pytest

from airflow import settings
from airflow.decorators import task, task_group
from airflow.exceptions import AirflowException
from airflow.models import DagRun, MappedOperator
from airflow.models.skipmixin import SkipMixin
from airflow.models.taskinstance import TaskInstance as TI
from airflow.operators.empty import EmptyOperator
Expand Down Expand Up @@ -53,6 +54,10 @@ def setup_method(self):
def teardown_method(self):
self.clean_db()

@pytest.fixture
def mock_session(self):
return Mock(spec=settings.Session)

@patch("airflow.utils.timezone.utcnow")
def test_skip(self, mock_now, dag_maker):
session = settings.Session()
Expand All @@ -75,11 +80,33 @@ def test_skip(self, mock_now, dag_maker):
TI.end_date == now,
).one()

def test_skip_none_tasks(self):
session = Mock()
SkipMixin().skip(dag_run=None, tasks=[])
assert not session.query.called
assert not session.commit.called
def test_skip_none_tasks(self, mock_session):
SkipMixin().skip(dag_run=None, tasks=[], session=mock_session)
mock_session.query.assert_not_called()
mock_session.commit.assert_not_called()

def test_skip_mapped_task(self, mock_session):
SkipMixin().skip(
dag_run=MagicMock(spec=DagRun),
tasks=[MagicMock(spec=MappedOperator)],
session=mock_session,
map_index=2,
)
mock_session.execute.assert_not_called()
mock_session.commit.assert_not_called()

@patch("airflow.models.skipmixin.update")
def test_skip_none_mapped_task(self, mock_update, mock_session):
SkipMixin().skip(
dag_run=MagicMock(spec=DagRun),
tasks=[MagicMock(spec=MappedOperator)],
session=mock_session,
map_index=-1,
)
mock_session.execute.assert_called_once_with(
mock_update.return_value.where.return_value.values.return_value.execution_options.return_value
)
mock_session.commit.assert_called_once()

@pytest.mark.parametrize(
"branch_task_ids, expected_states",
Expand Down
45 changes: 45 additions & 0 deletions tests/ti_deps/deps/test_not_previously_skipped_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pendulum
import pytest

from airflow.decorators import task
from airflow.models import DagRun, TaskInstance
from airflow.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import BranchPythonOperator
Expand Down Expand Up @@ -84,6 +85,50 @@ def test_no_skipmixin_parent(session, dag_maker):
assert ti2.state != State.SKIPPED


@pytest.mark.parametrize("condition, final_state", [(True, State.SUCCESS), (False, State.SKIPPED)])
def test_parent_is_mapped_short_circuit(session, dag_maker, condition, final_state):
with dag_maker(session=session):

@task
def op1():
return [1]

@task.short_circuit
def op2(i: int):
return condition

@task
def op3(res: bool):
pass

op3.expand(res=op2.expand(i=op1()))

dr = dag_maker.create_dagrun()

def _one_scheduling_decision_iteration() -> dict[tuple[str, int], TaskInstance]:
decision = dr.task_instance_scheduling_decisions(session=session)
return {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis}

tis = _one_scheduling_decision_iteration()

tis["op1", -1].run()
assert tis["op1", -1].state == State.SUCCESS

tis = _one_scheduling_decision_iteration()
tis["op2", 0].run()

assert tis["op2", 0].state == State.SUCCESS
tis = _one_scheduling_decision_iteration()

if condition:
ti3 = tis["op3", 0]
ti3.run()
else:
ti3 = dr.get_task_instance("op3", map_index=0, session=session)

assert ti3.state == final_state


def test_parent_follow_branch(session, dag_maker):
"""
A simple DAG with a BranchPythonOperator that follows op2. NotPreviouslySkippedDep is met.
Expand Down
Loading