diff --git a/airflow-core/tests/unit/dags/test_dag_test_trigger_dagrun_operator.py b/airflow-core/tests/unit/dags/test_dag_test_trigger_dagrun_operator.py new file mode 100644 index 0000000000000..0b2dc4aad0c59 --- /dev/null +++ b/airflow-core/tests/unit/dags/test_dag_test_trigger_dagrun_operator.py @@ -0,0 +1,45 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime + +from airflow.models.dag import DAG +from airflow.providers.standard.operators.empty import EmptyOperator +from airflow.providers.standard.operators.trigger_dagrun import TriggerDagRunOperator + +DEFAULT_DATE = datetime(2024, 1, 1) + +with DAG( + dag_id="test_dag_test_trigger_parent", + schedule=None, + start_date=DEFAULT_DATE, + is_paused_upon_creation=False, +) as parent_dag: + TriggerDagRunOperator( + task_id="trigger_target_dag", + trigger_dag_id="test_dag_test_trigger_target", + ) + +with DAG( + dag_id="test_dag_test_trigger_target", + schedule=None, + start_date=DEFAULT_DATE, + is_paused_upon_creation=False, +) as target_dag: + EmptyOperator(task_id="target_task") diff --git a/airflow-core/tests/unit/models/test_dag.py b/airflow-core/tests/unit/models/test_dag.py index 00a0f1a1ef2da..248029a31b1a3 100644 --- a/airflow-core/tests/unit/models/test_dag.py +++ b/airflow-core/tests/unit/models/test_dag.py @@ -1372,6 +1372,79 @@ def check_task_2(my_input): dag.test() mock_object.assert_called_with("output of first task") + @conf_vars({("core", "load_examples"): "false"}) + def test_dag_test_with_trigger_dagrun_operator(self, test_dags_bundle, session): + dag_id = "test_dag_test_trigger_parent" + target_dag_id = "test_dag_test_trigger_target" + + dagbag = DagBag(dag_folder=os.fspath(TEST_DAGS_FOLDER), include_examples=False) + dag = dagbag.dags.get(dag_id) + assert dag is not None + + # Ensure target DAG is not serialized before dag.test() runs. + assert DBDagBag().get_latest_version_of_dag(target_dag_id, session=session) is None + + dr = dag.test() + + assert dr.state == DagRunState.SUCCESS + assert session.scalar(select(DagRun).where(DagRun.dag_id == target_dag_id)) is not None + + @conf_vars({("core", "load_examples"): "false"}) + def test_dag_test_with_trigger_dagrun_operator_multi_bundle( + self, tmp_path, configure_dag_bundles, session + ): + # Setup bundle 1 (Parent) + dir1 = tmp_path / "bundle1" + dir1.mkdir() + parent_file = dir1 / "parent.py" + parent_file.write_text(""" +from airflow.sdk import DAG +from airflow.providers.standard.operators.trigger_dagrun import TriggerDagRunOperator +from datetime import datetime + +with DAG(dag_id="parent_dag_multi", schedule=None, start_date=datetime(2024, 1, 1)): + TriggerDagRunOperator(task_id="trigger", trigger_dag_id="target_dag_multi") +""") + + # Setup bundle 2 (Target) + dir2 = tmp_path / "bundle2" + dir2.mkdir() + target_file = dir2 / "target.py" + target_file.write_text(""" +from airflow.sdk import DAG +from airflow.providers.standard.operators.empty import EmptyOperator +from datetime import datetime + +with DAG(dag_id="target_dag_multi", schedule=None, start_date=datetime(2024, 1, 1)): + EmptyOperator(task_id="target_task") +""") + + # Bundle names are sorted alphabetically. parent_bundle comes before target_bundle. + with configure_dag_bundles( + { + "aaa_parent_bundle": dir1, + "bbb_target_bundle": dir2, + } + ): + from airflow.models.dag_version import DagVersion + + # Pre-check: No versions in DB + assert DagVersion.get_version("parent_dag_multi") is None + assert DagVersion.get_version("target_dag_multi") is None + + # Load parent DAG from file manually to call test() on it + dagbag = DagBag(dag_folder=os.fspath(parent_file), include_examples=False) + parent_dag = dagbag.get_dag("parent_dag_multi") + assert parent_dag is not None + + # This works because we now sync ALL bundles even if parent is found early. + dr = parent_dag.test() + assert dr.state == DagRunState.SUCCESS + + # Verify target dag run was created + target_dr = session.scalar(select(DagRun).where(DagRun.dag_id == "target_dag_multi")) + assert target_dr is not None + def test_dag_test_with_fail_handler(self, testing_dag_bundle): mock_handle_object_1 = mock.MagicMock() mock_handle_object_2 = mock.MagicMock() diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py index 3df36692513a7..517416cc3f5f6 100644 --- a/task-sdk/src/airflow/sdk/definitions/dag.py +++ b/task-sdk/src/airflow/sdk/definitions/dag.py @@ -1286,9 +1286,6 @@ def test( if not version: from airflow.dag_processing.bundles.manager import DagBundlesManager from airflow.dag_processing.dagbag import BundleDagBag, sync_bag_to_db - from airflow.sdk.definitions._internal.dag_parsing_context import ( - _airflow_parsing_context_manager, - ) manager = DagBundlesManager() manager.sync_bundles_to_db(session=session) @@ -1298,16 +1295,12 @@ def test( for bundle in manager.get_all_dag_bundles(): if not bundle.is_initialized: bundle.initialize() - with _airflow_parsing_context_manager(dag_id=self.dag_id): - dagbag = BundleDagBag( - dag_folder=bundle.path, - bundle_path=bundle.path, - bundle_name=bundle.name, - ) - sync_bag_to_db(dagbag, bundle.name, bundle.version) - version = DagVersion.get_version(self.dag_id) - if version: - break + dagbag = BundleDagBag( + dag_folder=bundle.path, + bundle_path=bundle.path, + bundle_name=bundle.name, + ) + sync_bag_to_db(dagbag, bundle.name, bundle.version) # Preserve callback functions from original Dag since they're lost during serialization # and yes it is a hack for now! It is a tradeoff for code simplicity.