Skip to content

Commit

Permalink
Fix short circuit in mapped tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
shahar1 committed Dec 14, 2024
1 parent ad3d022 commit dd54028
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 48 deletions.
2 changes: 2 additions & 0 deletions airflow/models/mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,8 @@ def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) ->
op.is_setup = is_setup
op.is_teardown = is_teardown
op.on_failure_fail_dagrun = on_failure_fail_dagrun
op.downstream_task_ids = self.downstream_task_ids
op.upstream_task_ids = self.upstream_task_ids
return op

# After a mapped operator is serialized, there's no real way to actually
Expand Down
7 changes: 5 additions & 2 deletions airflow/models/skipmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,11 @@ def skip(
raise ValueError("dag_run is required")

task_ids_list = [d.task_id for d in task_list]
SkipMixin._set_state_to_skipped(dag_run, task_ids_list, session)
session.commit()

# The following could be applied only for non-mapped tasks
if map_index == -1:
SkipMixin._set_state_to_skipped(dag_run, task_ids_list, session)
session.commit()

if task_id is not None:
from airflow.models.xcom import XCom
Expand Down
77 changes: 38 additions & 39 deletions airflow/ti_deps/deps/not_previously_skipped_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from airflow.models.taskinstance import PAST_DEPENDS_MET
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.utils.db import LazySelectSequence


class NotPreviouslySkippedDep(BaseTIDep):
Expand All @@ -38,7 +39,6 @@ def _get_dep_statuses(self, ti, session, dep_context):
XCOM_SKIPMIXIN_FOLLOWED,
XCOM_SKIPMIXIN_KEY,
XCOM_SKIPMIXIN_SKIPPED,
SkipMixin,
)
from airflow.utils.state import TaskInstanceState

Expand All @@ -49,46 +49,45 @@ def _get_dep_statuses(self, ti, session, dep_context):
finished_task_ids = {t.task_id for t in finished_tis}

for parent in upstream:
if isinstance(parent, SkipMixin):
if parent.task_id not in finished_task_ids:
# This can happen if the parent task has not yet run.
continue
if parent.task_id not in finished_task_ids:
# This can happen if the parent task has not yet run.
continue

prev_result = ti.xcom_pull(task_ids=parent.task_id, key=XCOM_SKIPMIXIN_KEY, session=session)
prev_result = ti.xcom_pull(task_ids=parent.task_id, key=XCOM_SKIPMIXIN_KEY, session=session, map_index=ti.map_index)

if prev_result is None:
# This can happen if the parent task has not yet run.
continue
if isinstance(prev_result, LazySelectSequence):
prev_result = next(iter(prev_result))

should_skip = False
if (
XCOM_SKIPMIXIN_FOLLOWED in prev_result
and ti.task_id not in prev_result[XCOM_SKIPMIXIN_FOLLOWED]
):
# Skip any tasks that are not in "followed"
should_skip = True
elif (
XCOM_SKIPMIXIN_SKIPPED in prev_result
and ti.task_id in prev_result[XCOM_SKIPMIXIN_SKIPPED]
):
# Skip any tasks that are in "skipped"
should_skip = True
if prev_result is None:
# This can happen if the parent task has not yet run.
continue

if should_skip:
# If the parent SkipMixin has run, and the XCom result stored indicates this
# ti should be skipped, set ti.state to SKIPPED and fail the rule so that the
# ti does not execute.
if dep_context.wait_for_past_depends_before_skipping:
past_depends_met = ti.xcom_pull(
task_ids=ti.task_id, key=PAST_DEPENDS_MET, session=session, default=False
)
if not past_depends_met:
yield self._failing_status(
reason=("Task should be skipped but the past depends are not met")
)
return
ti.set_state(TaskInstanceState.SKIPPED, session)
yield self._failing_status(
reason=f"Skipping because of previous XCom result from parent task {parent.task_id}"
should_skip = False
if (
XCOM_SKIPMIXIN_FOLLOWED in prev_result
and ti.task_id not in prev_result[XCOM_SKIPMIXIN_FOLLOWED]
):
# Skip any tasks that are not in "followed"
should_skip = True
elif XCOM_SKIPMIXIN_SKIPPED in prev_result and ti.task_id in prev_result[XCOM_SKIPMIXIN_SKIPPED]:
# Skip any tasks that are in "skipped"
should_skip = True

if should_skip:
# If the parent SkipMixin has run, and the XCom result stored indicates this
# ti should be skipped, set ti.state to SKIPPED and fail the rule so that the
# ti does not execute.
if dep_context.wait_for_past_depends_before_skipping:
past_depends_met = ti.xcom_pull(
task_ids=ti.task_id, key=PAST_DEPENDS_MET, session=session, default=False
)
return
if not past_depends_met:
yield self._failing_status(
reason="Task should be skipped but the past depends are not met"
)
return
ti.set_state(TaskInstanceState.SKIPPED, session)
yield self._failing_status(
reason=f"Skipping because of previous XCom result from parent task {parent.task_id}"
)
return
1 change: 1 addition & 0 deletions newsfragments/44925.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix short circuit operator in mapped tasks
57 changes: 56 additions & 1 deletion tests/models/test_mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from datetime import timedelta
from typing import TYPE_CHECKING
from unittest import mock
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import pendulum
import pytest
Expand Down Expand Up @@ -1763,3 +1763,58 @@ def group(n: int) -> None:
"group.last": {0: "success", 1: "skipped", 2: "success"},
}
assert states == expected


class TestMappedOperator:
@pytest.fixture
def mock_operator_class(self):
return MagicMock(spec=type(BaseOperator))

@pytest.fixture
@patch("airflow.serialization.serialized_objects.SerializedBaseOperator")
def mapped_operator(self, _, mock_operator_class):
return MappedOperator(
operator_class=mock_operator_class,
expand_input=MagicMock(),
partial_kwargs={"task_id": "test_task"},
task_id="test_task",
params={},
deps=frozenset(),
operator_extra_links=[],
template_ext=[],
template_fields=[],
template_fields_renderers={},
ui_color="",
ui_fgcolor="",
start_trigger_args=None,
start_from_trigger=False,
dag=None,
task_group=None,
start_date=None,
end_date=None,
is_empty=False,
task_module=MagicMock(),
task_type="taske_type",
operator_name="operator_name",
disallow_kwargs_override=False,
expand_input_attr="expand_input",
)

def test_unmap_with_resolved_kwargs(self, mapped_operator, mock_operator_class):
mapped_operator.upstream_task_ids = ["a"]
mapped_operator.downstream_task_ids = ["b"]
resolve = {"param1": "value1"}
result = mapped_operator.unmap(resolve)
assert result == mock_operator_class.return_value
assert result.task_id == "test_task"
assert result.is_setup is False
assert result.is_teardown is False
assert result.on_failure_fail_dagrun is False
assert result.upstream_task_ids == ["a"]
assert result.downstream_task_ids == ["b"]

def test_unmap_runtime_error(self, mapped_operator):
mapped_operator.upstream_task_ids = ["a"]
mapped_operator.downstream_task_ids = ["b"]
with pytest.raises(RuntimeError):
mapped_operator.unmap(None)
39 changes: 33 additions & 6 deletions tests/models/test_skipmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
from __future__ import annotations

import datetime
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch

import pytest

from airflow import settings
from airflow.decorators import task, task_group
from airflow.exceptions import AirflowException
from airflow.models import DagRun, MappedOperator
from airflow.models.skipmixin import SkipMixin
from airflow.models.taskinstance import TaskInstance as TI
from airflow.operators.empty import EmptyOperator
Expand Down Expand Up @@ -53,6 +54,10 @@ def setup_method(self):
def teardown_method(self):
self.clean_db()

@pytest.fixture
def mock_session(self):
return Mock(spec=settings.Session)

@patch("airflow.utils.timezone.utcnow")
def test_skip(self, mock_now, dag_maker):
session = settings.Session()
Expand All @@ -75,11 +80,33 @@ def test_skip(self, mock_now, dag_maker):
TI.end_date == now,
).one()

def test_skip_none_tasks(self):
session = Mock()
SkipMixin().skip(dag_run=None, tasks=[])
assert not session.query.called
assert not session.commit.called
def test_skip_none_tasks(self, mock_session):
SkipMixin().skip(dag_run=None, tasks=[], session=mock_session)
mock_session.query.assert_not_called()
mock_session.commit.assert_not_called()

def test_skip_mapped_task(self, mock_session):
SkipMixin().skip(
dag_run=MagicMock(spec=DagRun),
tasks=[MagicMock(spec=MappedOperator)],
session=mock_session,
map_index=2,
)
mock_session.execute.assert_not_called()
mock_session.commit.assert_not_called()

@patch("airflow.models.skipmixin.update")
def test_skip_none_mapped_task(self, mock_update, mock_session):
SkipMixin().skip(
dag_run=MagicMock(spec=DagRun),
tasks=[MagicMock(spec=MappedOperator)],
session=mock_session,
map_index=-1,
)
mock_session.execute.assert_called_once_with(
mock_update.return_value.where.return_value.values.return_value.execution_options.return_value
)
mock_session.commit.assert_called_once()

@pytest.mark.parametrize(
"branch_task_ids, expected_states",
Expand Down
45 changes: 45 additions & 0 deletions tests/ti_deps/deps/test_not_previously_skipped_dep.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pendulum
import pytest

from airflow.decorators import task
from airflow.models import DagRun, TaskInstance
from airflow.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import BranchPythonOperator
Expand Down Expand Up @@ -84,6 +85,50 @@ def test_no_skipmixin_parent(session, dag_maker):
assert ti2.state != State.SKIPPED


@pytest.mark.parametrize("condition, final_state", [(True, State.SUCCESS), (False, State.SKIPPED)])
def test_parent_is_mapped_short_circuit(session, dag_maker, condition, final_state):
with dag_maker(session=session):

@task
def op1():
return [1]

@task.short_circuit
def op2(i: int):
return condition

@task
def op3(res: bool):
pass

op3.expand(res=op2.expand(i=op1()))

dr = dag_maker.create_dagrun()

def _one_scheduling_decision_iteration() -> dict[tuple[str, int], TaskInstance]:
decision = dr.task_instance_scheduling_decisions(session=session)
return {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis}

tis = _one_scheduling_decision_iteration()

tis["op1", -1].run()
assert tis["op1", -1].state == State.SUCCESS

tis = _one_scheduling_decision_iteration()
tis["op2", 0].run()

assert tis["op2", 0].state == State.SUCCESS
tis = _one_scheduling_decision_iteration()

if condition:
ti3 = tis["op3", 0]
ti3.run()
else:
ti3 = dr.get_task_instance("op3", map_idnex=0, session=session)

assert ti3.state == final_state


def test_parent_follow_branch(session, dag_maker):
"""
A simple DAG with a BranchPythonOperator that follows op2. NotPreviouslySkippedDep is met.
Expand Down

0 comments on commit dd54028

Please sign in to comment.