@@ -279,3 +279,108 @@ def assert_dict_is_subset(subset, actual):
279279 return True
280280
281281 assert assert_dict_is_subset (expected_resources , job_spec ["resources" ])
282+
283+
284+ def test_vertex_orchestrator_stop_run_graceful (mocker ) -> None :
285+ """Tests that graceful stop returns early without cancelling."""
286+ orchestrator = _get_vertex_orchestrator (
287+ location = "europe-west4" ,
288+ pipeline_root = "gs://my-bucket/pipeline" ,
289+ )
290+
291+ mock_run = mocker .MagicMock ()
292+ mock_run .run_metadata = {"orchestrator_run_id" : "test-run-id" }
293+ mock_run .orchestrator_run_id = "test-run-id"
294+
295+ mock_auth = mocker .patch .object (
296+ orchestrator , "_get_authentication" , return_value = ("creds" , "project" )
297+ )
298+
299+ orchestrator ._stop_run (mock_run , graceful = True )
300+
301+ mock_auth .assert_not_called ()
302+
303+
304+ def test_vertex_orchestrator_stop_run_forceful_static (mocker ) -> None :
305+ """Tests that forceful stop cancels a static PipelineJob."""
306+ from google .cloud import aiplatform
307+
308+ orchestrator = _get_vertex_orchestrator (
309+ location = "europe-west4" ,
310+ pipeline_root = "gs://my-bucket/pipeline" ,
311+ )
312+
313+ mock_run = mocker .MagicMock ()
314+ mock_run .run_metadata = {"orchestrator_run_id" : "test-run-id" }
315+ mock_run .orchestrator_run_id = "test-run-id"
316+ mock_run .snapshot .is_dynamic = False
317+
318+ mocker .patch .object (
319+ orchestrator , "_get_authentication" , return_value = ("creds" , "project" )
320+ )
321+
322+ mock_job = mocker .MagicMock ()
323+ mock_get = mocker .patch .object (
324+ aiplatform .PipelineJob , "get" , return_value = mock_job
325+ )
326+
327+ orchestrator ._stop_run (mock_run , graceful = False )
328+
329+ mock_get .assert_called_once_with (
330+ "test-run-id" ,
331+ project = "project" ,
332+ location = "europe-west4" ,
333+ credentials = "creds" ,
334+ )
335+ mock_job .cancel .assert_called_once ()
336+
337+
338+ def test_vertex_orchestrator_stop_run_forceful_dynamic (mocker ) -> None :
339+ """Tests that forceful stop cancels a dynamic CustomJob."""
340+ from google .cloud import aiplatform
341+
342+ orchestrator = _get_vertex_orchestrator (
343+ location = "europe-west4" ,
344+ pipeline_root = "gs://my-bucket/pipeline" ,
345+ )
346+
347+ mock_run = mocker .MagicMock ()
348+ mock_run .run_metadata = {"orchestrator_run_id" : "test-run-id" }
349+ mock_run .orchestrator_run_id = "test-run-id"
350+ mock_run .snapshot .is_dynamic = True
351+
352+ mocker .patch .object (
353+ orchestrator , "_get_authentication" , return_value = ("creds" , "project" )
354+ )
355+
356+ mock_job = mocker .MagicMock ()
357+ mock_get = mocker .patch .object (
358+ aiplatform .CustomJob , "get" , return_value = mock_job
359+ )
360+
361+ orchestrator ._stop_run (mock_run , graceful = False )
362+
363+ mock_get .assert_called_once_with (
364+ "test-run-id" ,
365+ project = "project" ,
366+ location = "europe-west4" ,
367+ credentials = "creds" ,
368+ )
369+ mock_job .cancel .assert_called_once ()
370+
371+
372+ def test_vertex_orchestrator_stop_run_missing_run_id (mocker ) -> None :
373+ """Tests that stop raises ValueError when run ID is missing."""
374+ orchestrator = _get_vertex_orchestrator (
375+ location = "europe-west4" ,
376+ pipeline_root = "gs://my-bucket/pipeline" ,
377+ )
378+
379+ mock_run = mocker .MagicMock ()
380+ mock_run .run_metadata = {}
381+ mock_run .orchestrator_run_id = None
382+
383+ with pytest .raises (
384+ ValueError , match = "Cannot find the orchestrator run ID"
385+ ):
386+ orchestrator ._stop_run (mock_run , graceful = False )
0 commit comments