From b13c1dcd94307dc78961dc514b47c75165ded1ba Mon Sep 17 00:00:00 2001 From: Shahar Epstein <60007259+shahar1@users.noreply.github.com> Date: Fri, 13 Dec 2024 14:41:35 +0200 Subject: [PATCH] Fix short circuit in mapped tasks --- airflow/models/mappedoperator.py | 2 + airflow/models/skipmixin.py | 7 +- .../deps/not_previously_skipped_dep.py | 77 +++++++++---------- tests/models/test_mappedoperator.py | 57 +++++++++++++- tests/models/test_skipmixin.py | 31 +++++++- .../deps/test_not_previously_skipped_dep.py | 45 +++++++++++ 6 files changed, 176 insertions(+), 43 deletions(-) diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 524415b848f62..7c8d562f78f7c 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -812,6 +812,8 @@ def unmap(self, resolve: None | Mapping[str, Any] | tuple[Context, Session]) -> op.is_setup = is_setup op.is_teardown = is_teardown op.on_failure_fail_dagrun = on_failure_fail_dagrun + op.downstream_task_ids = self.downstream_task_ids + op.upstream_task_ids = self.upstream_task_ids return op # After a mapped operator is serialized, there's no real way to actually diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py index ad5c5d01539cb..9bc4243f46fa2 100644 --- a/airflow/models/skipmixin.py +++ b/airflow/models/skipmixin.py @@ -121,8 +121,11 @@ def skip( raise ValueError("dag_run is required") task_ids_list = [d.task_id for d in task_list] - SkipMixin._set_state_to_skipped(dag_run, task_ids_list, session) - session.commit() + + # The following could be applied only for non-mapped tasks + if map_index == -1: + SkipMixin._set_state_to_skipped(dag_run, task_ids_list, session) + session.commit() if task_id is not None: from airflow.models.xcom import XCom diff --git a/airflow/ti_deps/deps/not_previously_skipped_dep.py b/airflow/ti_deps/deps/not_previously_skipped_dep.py index 92dd2b373acdb..bbee2837bd044 100644 --- a/airflow/ti_deps/deps/not_previously_skipped_dep.py +++ b/airflow/ti_deps/deps/not_previously_skipped_dep.py @@ -19,6 +19,7 @@ from airflow.models.taskinstance import PAST_DEPENDS_MET from airflow.ti_deps.deps.base_ti_dep import BaseTIDep +from airflow.utils.db import LazySelectSequence class NotPreviouslySkippedDep(BaseTIDep): @@ -38,7 +39,6 @@ def _get_dep_statuses(self, ti, session, dep_context): XCOM_SKIPMIXIN_FOLLOWED, XCOM_SKIPMIXIN_KEY, XCOM_SKIPMIXIN_SKIPPED, - SkipMixin, ) from airflow.utils.state import TaskInstanceState @@ -49,46 +49,45 @@ def _get_dep_statuses(self, ti, session, dep_context): finished_task_ids = {t.task_id for t in finished_tis} for parent in upstream: - if isinstance(parent, SkipMixin): - if parent.task_id not in finished_task_ids: - # This can happen if the parent task has not yet run. - continue + if parent.task_id not in finished_task_ids: + # This can happen if the parent task has not yet run. + continue - prev_result = ti.xcom_pull(task_ids=parent.task_id, key=XCOM_SKIPMIXIN_KEY, session=session) + prev_result = ti.xcom_pull(task_ids=parent.task_id, key=XCOM_SKIPMIXIN_KEY, session=session) - if prev_result is None: - # This can happen if the parent task has not yet run. - continue + if isinstance(prev_result, LazySelectSequence): + prev_result = next(iter(prev_result)) - should_skip = False - if ( - XCOM_SKIPMIXIN_FOLLOWED in prev_result - and ti.task_id not in prev_result[XCOM_SKIPMIXIN_FOLLOWED] - ): - # Skip any tasks that are not in "followed" - should_skip = True - elif ( - XCOM_SKIPMIXIN_SKIPPED in prev_result - and ti.task_id in prev_result[XCOM_SKIPMIXIN_SKIPPED] - ): - # Skip any tasks that are in "skipped" - should_skip = True + if prev_result is None: + # This can happen if the parent task has not yet run. + continue - if should_skip: - # If the parent SkipMixin has run, and the XCom result stored indicates this - # ti should be skipped, set ti.state to SKIPPED and fail the rule so that the - # ti does not execute. - if dep_context.wait_for_past_depends_before_skipping: - past_depends_met = ti.xcom_pull( - task_ids=ti.task_id, key=PAST_DEPENDS_MET, session=session, default=False - ) - if not past_depends_met: - yield self._failing_status( - reason=("Task should be skipped but the past depends are not met") - ) - return - ti.set_state(TaskInstanceState.SKIPPED, session) - yield self._failing_status( - reason=f"Skipping because of previous XCom result from parent task {parent.task_id}" + should_skip = False + if ( + XCOM_SKIPMIXIN_FOLLOWED in prev_result + and ti.task_id not in prev_result[XCOM_SKIPMIXIN_FOLLOWED] + ): + # Skip any tasks that are not in "followed" + should_skip = True + elif XCOM_SKIPMIXIN_SKIPPED in prev_result and ti.task_id in prev_result[XCOM_SKIPMIXIN_SKIPPED]: + # Skip any tasks that are in "skipped" + should_skip = True + + if should_skip: + # If the parent SkipMixin has run, and the XCom result stored indicates this + # ti should be skipped, set ti.state to SKIPPED and fail the rule so that the + # ti does not execute. + if dep_context.wait_for_past_depends_before_skipping: + past_depends_met = ti.xcom_pull( + task_ids=ti.task_id, key=PAST_DEPENDS_MET, session=session, default=False ) - return + if not past_depends_met: + yield self._failing_status( + reason="Task should be skipped but the past depends are not met" + ) + return + ti.set_state(TaskInstanceState.SKIPPED, session) + yield self._failing_status( + reason=f"Skipping because of previous XCom result from parent task {parent.task_id}" + ) + return diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index 1ee81cae1c832..f8e1afeb56dcb 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -21,7 +21,7 @@ from datetime import timedelta from typing import TYPE_CHECKING from unittest import mock -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pendulum import pytest @@ -1763,3 +1763,58 @@ def group(n: int) -> None: "group.last": {0: "success", 1: "skipped", 2: "success"}, } assert states == expected + + +class TestMappedOperator: + @pytest.fixture + def mock_operator_class(self): + return MagicMock(spec=type(BaseOperator)) + + @pytest.fixture + @patch("airflow.serialization.serialized_objects.SerializedBaseOperator") + def mapped_operator(self, _, mock_operator_class): + return MappedOperator( + operator_class=mock_operator_class, + expand_input=MagicMock(), + partial_kwargs={"task_id": "test_task"}, + task_id="test_task", + params={}, + deps=frozenset(), + operator_extra_links=[], + template_ext=[], + template_fields=[], + template_fields_renderers={}, + ui_color="", + ui_fgcolor="", + start_trigger_args=None, + start_from_trigger=False, + dag=None, + task_group=None, + start_date=None, + end_date=None, + is_empty=False, + task_module=MagicMock(), + task_type="taske_type", + operator_name="operator_name", + disallow_kwargs_override=False, + expand_input_attr="expand_input", + ) + + def test_unmap_with_resolved_kwargs(self, mapped_operator, mock_operator_class): + mapped_operator.upstream_task_ids = ["a"] + mapped_operator.downstream_task_ids = ["b"] + resolve = {"param1": "value1"} + result = mapped_operator.unmap(resolve) + assert result == mock_operator_class.return_value + assert result.task_id == "test_task" + assert result.is_setup is False + assert result.is_teardown is False + assert result.on_failure_fail_dagrun is False + assert result.upstream_task_ids == ["a"] + assert result.downstream_task_ids == ["b"] + + def test_unmap_runtime_error(self, mapped_operator): + mapped_operator.upstream_task_ids = ["a"] + mapped_operator.downstream_task_ids = ["b"] + with pytest.raises(RuntimeError): + mapped_operator.unmap(None) diff --git a/tests/models/test_skipmixin.py b/tests/models/test_skipmixin.py index 383403b9434ee..d8a6524d3267f 100644 --- a/tests/models/test_skipmixin.py +++ b/tests/models/test_skipmixin.py @@ -18,13 +18,14 @@ from __future__ import annotations import datetime -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import pytest from airflow import settings from airflow.decorators import task, task_group from airflow.exceptions import AirflowException +from airflow.models import DagRun, MappedOperator from airflow.models.skipmixin import SkipMixin from airflow.models.taskinstance import TaskInstance as TI from airflow.operators.empty import EmptyOperator @@ -53,6 +54,11 @@ def setup_method(self): def teardown_method(self): self.clean_db() + @pytest.fixture + def mock_session(self): + return Mock(spec=settings.Session) + + @pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation mode @patch("airflow.utils.timezone.utcnow") def test_skip(self, mock_now, dag_maker): session = settings.Session() @@ -81,6 +87,29 @@ def test_skip_none_tasks(self): assert not session.query.called assert not session.commit.called + def test_skip_mapped_task(self, mock_session): + SkipMixin().skip( + dag_run=MagicMock(spec=DagRun), + tasks=[MagicMock(spec=MappedOperator)], + session=mock_session, + map_index=2, + ) + mock_session.execute.assert_not_called() + mock_session.commit.assert_not_called() + + @patch("airflow.models.skipmixin.update") + def test_skip_none_mapped_task(self, mock_update, mock_session): + SkipMixin().skip( + dag_run=MagicMock(spec=DagRun), + tasks=[MagicMock(spec=MappedOperator)], + session=mock_session, + map_index=-1, + ) + mock_session.execute.assert_called_once_with( + mock_update.return_value.where.return_value.values.return_value.execution_options.return_value + ) + mock_session.commit.assert_called_once() + @pytest.mark.parametrize( "branch_task_ids, expected_states", [ diff --git a/tests/ti_deps/deps/test_not_previously_skipped_dep.py b/tests/ti_deps/deps/test_not_previously_skipped_dep.py index 377d216030aab..493daa7c62d50 100644 --- a/tests/ti_deps/deps/test_not_previously_skipped_dep.py +++ b/tests/ti_deps/deps/test_not_previously_skipped_dep.py @@ -20,6 +20,7 @@ import pendulum import pytest +from airflow.decorators import task from airflow.models import DagRun, TaskInstance from airflow.operators.empty import EmptyOperator from airflow.providers.standard.operators.python import BranchPythonOperator @@ -84,6 +85,50 @@ def test_no_skipmixin_parent(session, dag_maker): assert ti2.state != State.SKIPPED +@pytest.mark.parametrize("condition, final_state", [(True, State.SUCCESS), (False, State.SKIPPED)]) +def test_parent_is_mapped_short_circuit(session, dag_maker, condition, final_state): + with dag_maker(session=session): + + @task + def op1(): + return [1] + + @task.short_circuit + def op2(i: int): + return condition + + @task + def op3(res: bool): + pass + + op3.expand(res=op2.expand(i=op1())) + + dr = dag_maker.create_dagrun() + + def _one_scheduling_decision_iteration() -> dict[tuple[str, int], TaskInstance]: + decision = dr.task_instance_scheduling_decisions(session=session) + return {(ti.task_id, ti.map_index): ti for ti in decision.schedulable_tis} + + tis = _one_scheduling_decision_iteration() + + tis["op1", -1].run() + assert tis["op1", -1].state == State.SUCCESS + + tis = _one_scheduling_decision_iteration() + tis["op2", 0].run() + + assert tis["op2", 0].state == State.SUCCESS + tis = _one_scheduling_decision_iteration() + + if condition: + ti3 = tis["op3", 0] + ti3.run() + else: + ti3 = dr.get_task_instance("op3", session=session) + + assert ti3.state == final_state + + def test_parent_follow_branch(session, dag_maker): """ A simple DAG with a BranchPythonOperator that follows op2. NotPreviouslySkippedDep is met.