Skip to content

Commit 1181e43

Browse files
committed
Fix TriggerDagRunOperator 404 in dag.test() by syncing all bundle DAGs
1 parent 6fd0142 commit 1181e43

3 files changed

Lines changed: 68 additions & 10 deletions

File tree

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
from __future__ import annotations
19+
20+
from datetime import datetime
21+
22+
from airflow.models.dag import DAG
23+
from airflow.providers.standard.operators.empty import EmptyOperator
24+
from airflow.providers.standard.operators.trigger_dagrun import TriggerDagRunOperator
25+
26+
DEFAULT_DATE = datetime(2024, 1, 1)
27+
28+
with DAG(
29+
dag_id="test_dag_test_trigger_parent",
30+
schedule=None,
31+
start_date=DEFAULT_DATE,
32+
is_paused_upon_creation=False,
33+
) as parent_dag:
34+
TriggerDagRunOperator(
35+
task_id="trigger_target_dag",
36+
trigger_dag_id="test_dag_test_trigger_target",
37+
)
38+
39+
with DAG(
40+
dag_id="test_dag_test_trigger_target",
41+
schedule=None,
42+
start_date=DEFAULT_DATE,
43+
is_paused_upon_creation=False,
44+
) as target_dag:
45+
EmptyOperator(task_id="target_task")

airflow-core/tests/unit/models/test_dag.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1372,6 +1372,23 @@ def check_task_2(my_input):
13721372
dag.test()
13731373
mock_object.assert_called_with("output of first task")
13741374

1375+
@conf_vars({("core", "load_examples"): "false"})
1376+
def test_dag_test_with_trigger_dagrun_operator(self, test_dags_bundle, session):
1377+
dag_id = "test_dag_test_trigger_parent"
1378+
target_dag_id = "test_dag_test_trigger_target"
1379+
1380+
dagbag = DagBag(dag_folder=os.fspath(TEST_DAGS_FOLDER), include_examples=False)
1381+
dag = dagbag.dags.get(dag_id)
1382+
assert dag is not None
1383+
1384+
# Ensure target DAG is not serialized before dag.test() runs.
1385+
assert DBDagBag().get_latest_version_of_dag(target_dag_id, session=session) is None
1386+
1387+
dr = dag.test()
1388+
1389+
assert dr.state == DagRunState.SUCCESS
1390+
assert session.scalar(select(DagRun).where(DagRun.dag_id == target_dag_id)) is not None
1391+
13751392
def test_dag_test_with_fail_handler(self, testing_dag_bundle):
13761393
mock_handle_object_1 = mock.MagicMock()
13771394
mock_handle_object_2 = mock.MagicMock()

task-sdk/src/airflow/sdk/definitions/dag.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1286,9 +1286,6 @@ def test(
12861286
if not version:
12871287
from airflow.dag_processing.bundles.manager import DagBundlesManager
12881288
from airflow.dag_processing.dagbag import BundleDagBag, sync_bag_to_db
1289-
from airflow.sdk.definitions._internal.dag_parsing_context import (
1290-
_airflow_parsing_context_manager,
1291-
)
12921289

12931290
manager = DagBundlesManager()
12941291
manager.sync_bundles_to_db(session=session)
@@ -1298,13 +1295,12 @@ def test(
12981295
for bundle in manager.get_all_dag_bundles():
12991296
if not bundle.is_initialized:
13001297
bundle.initialize()
1301-
with _airflow_parsing_context_manager(dag_id=self.dag_id):
1302-
dagbag = BundleDagBag(
1303-
dag_folder=bundle.path,
1304-
bundle_path=bundle.path,
1305-
bundle_name=bundle.name,
1306-
)
1307-
sync_bag_to_db(dagbag, bundle.name, bundle.version)
1298+
dagbag = BundleDagBag(
1299+
dag_folder=bundle.path,
1300+
bundle_path=bundle.path,
1301+
bundle_name=bundle.name,
1302+
)
1303+
sync_bag_to_db(dagbag, bundle.name, bundle.version)
13081304
version = DagVersion.get_version(self.dag_id)
13091305
if version:
13101306
break

0 commit comments

Comments
 (0)