Commit c80d1c1 (parent 0db470f): Add files via upload

1 file changed: multioptpy/Wrapper/mapper.py (104 additions, 169 deletions)
@@ -120,32 +120,6 @@ def _autots_worker(config: dict, run_dir: str, workspace: str) -> list[str]:
     return sorted(_glob.glob(pattern, recursive=True))
 
 
-def _autots_worker_with_queue(
-    config: dict,
-    run_dir: str,
-    workspace: str,
-    result_queue: "multiprocessing.Queue[tuple]",
-) -> None:
-    """Thin wrapper around :func:`_autots_worker` for :class:`multiprocessing.Process`.
-
-    Sends the return value (or raised exception) to *result_queue* as a
-    ``(tag, payload)`` tuple:
-
-    * ``("ok", profiles)`` — on success.
-    * ``("err", traceback_str)`` — on failure (full traceback string).
-
-    By not using ProcessPoolExecutor, we avoid the process-spawn/teardown
-    overhead for the sequential path (``n_parallel=1``).
-    """
-    try:
-        profiles = _autots_worker(config, run_dir, workspace)
-        result_queue.put(("ok", profiles))
-    except Exception as exc:  # noqa: BLE001
-        tb_str = traceback.format_exc()
-        if len(tb_str) > 1500:
-            tb_str = "...(truncated)...\n" + tb_str[-1500:]
-        result_queue.put(("err", tb_str))
-
 
 logger = logging.getLogger(__name__)
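The helper deleted above existed only to shuttle a ("ok", profiles) / ("err", traceback) tuple through a multiprocessing.Queue. The commit drops it because ProcessPoolExecutor does this bookkeeping itself: an exception raised in the worker is pickled and re-raised in the parent by future.result(). A minimal sketch of that behavior, with flaky_worker as a hypothetical stand-in for _autots_worker:

from concurrent.futures import ProcessPoolExecutor

def flaky_worker(x: int) -> int:
    # Runs in a child process; any exception is captured by the pool.
    if x < 0:
        raise ValueError("negative input")
    return x * 2

if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=1) as executor:
        print(executor.submit(flaky_worker, 21).result())  # 42
        try:
            executor.submit(flaky_worker, -1).result()
        except ValueError as exc:  # re-raised in the parent process
            print(f"worker failed: {exc}")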
@@ -2610,7 +2584,6 @@ def _run_batch_parallel(
                         task, run_dir, [], "FAILED", iteration,
                         history_log, gamma_sign, atom_i, atom_j,
                     )
-
         except TimeoutError:
             timed_out = True
             logger.error(
@@ -2619,17 +2592,20 @@ def _run_batch_parallel(
                 self.worker_timeout_s,
             )
 
+            # Track every future that was NOT cleanly handled in the
+            # as_completed loop above so we can unconditionally release
+            # their _in_flight keys in the cleanup pass below.
+            cancelled_futures: set = set()
+            killed_futures: set = set()
+
             # ── Step 1: cancel not-yet-started futures ────────────────
             for future in futures_map:
-                if future.cancel():
+                if future.cancel():
+                    cancelled_futures.add(future)
                     task, run_dir, iteration, gamma_sign, atom_i, atom_j = futures_map[future]
                     self.queue.release((task.node_id, tuple(task.afir_params)))
 
             # ── Step 2: force-kill running worker processes ───────────
-            # executor._processes is a {pid: multiprocessing.Process} dict
-            # maintained by ProcessPoolExecutor. No public API exposes
-            # individual worker handles, so this private attribute is the
-            # only reliable way to send SIGKILL to hung external binaries.
             worker_procs = getattr(executor, "_processes", {})
             for pid, proc in list(worker_procs.items()):
                 if proc.is_alive():
@@ -2638,24 +2614,35 @@ def _run_batch_parallel(
                     )
                     proc.kill()
 
-            # ── Step 3: mark all incomplete futures as TIMEOUT ────────
+            # ── Step 3: mark incomplete futures as TIMEOUT ────────────
+            # Use `future not in cancelled_futures` rather than
+            # `not future.done()` because proc.kill() may have already
+            # transitioned the future to done (failed), which would cause
+            # _process_single_result / release() to be silently skipped,
+            # leaving the key permanently stuck in _in_flight.
             for future, meta in futures_map.items():
+                if future in cancelled_futures:
+                    continue  # release() already called in Step 1
                 task, run_dir, iteration, gamma_sign, atom_i, atom_j = meta
-                if not future.done():
+                killed_futures.add(future)
+                logger.error(
+                    "_run_batch_parallel: worker timed out (limit=%ds): %s",
+                    self.worker_timeout_s, run_dir,
+                )
+                try:
+                    self._process_single_result(
+                        task, run_dir, [], "TIMEOUT", iteration,
+                        history_log, gamma_sign, atom_i, atom_j,
+                    )
+                except Exception as exc:
                     logger.error(
-                        "_run_batch_parallel: worker timed out (limit=%ds): %s",
-                        self.worker_timeout_s, run_dir,
+                        "_process_single_result failed after TIMEOUT (%s): %s",
+                        run_dir, exc,
                     )
-                    try:
-                        self._process_single_result(
-                            task, run_dir, [], "TIMEOUT", iteration,
-                            history_log, gamma_sign, atom_i, atom_j,
-                        )
-                    except Exception as exc:
-                        logger.error(
-                            "_process_single_result failed after TIMEOUT (%s): %s",
-                            run_dir, exc,
-                        )
+                    # _process_single_result calls release() internally, but
+                    # if it raises before reaching that line, release here
+                    # as a last-resort safety net to prevent _in_flight leak.
+                    self.queue.release((task.node_id, tuple(task.afir_params)))
 
         finally:
             # Shut down the executor. After force-killing all workers in the
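The Step-3 comment is the crux of this hunk: once proc.kill() lands, the pool notices the dead worker and fails every outstanding future with BrokenProcessPool, so those futures are already "done" and `not future.done()` no longer selects the timed-out tasks. A standalone sketch of that behavior, assuming a POSIX system (os.kill plays the role of proc.kill(); hang stands in for a stuck external binary):

import os
import signal
import time
from concurrent.futures import ProcessPoolExecutor
from concurrent.futures.process import BrokenProcessPool

def hang() -> None:
    time.sleep(3600)  # never returns within the test window

if __name__ == "__main__":
    executor = ProcessPoolExecutor(max_workers=1)
    future = executor.submit(hang)
    time.sleep(1)  # let the worker pick up the task
    for pid in list(getattr(executor, "_processes", {})):
        os.kill(pid, signal.SIGKILL)  # same effect as proc.kill()
    time.sleep(1)  # give the pool a moment to notice the dead worker
    print(future.done())  # True, although the task never completed
    try:
        future.result()
    except BrokenProcessPool:
        print("done, but failed with BrokenProcessPool")
    executor.shutdown(wait=False, cancel_futures=True)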
@@ -2670,23 +2657,8 @@ def _run_batch_parallel(
         except Exception as exc:
             logger.error("Cleanup save/log failed: %s", exc)
 
-    @staticmethod
-    def _kill_proc(proc) -> None:
-        """Terminate a process reliably: terminate → kill."""
-        proc.terminate()
-        proc.join(timeout=10)
-        if proc.is_alive():
-            proc.kill()
-            proc.join(timeout=10)
 
-    @staticmethod
-    def _close_queue(q) -> None:
-        """Safely close a multiprocessing.Queue."""
-        try:
-            q.cancel_join_thread()
-            q.close()
-        except Exception:
-            pass
+
 
     def _process_single_result(
         self,
@@ -3043,21 +3015,26 @@ def _make_autots_config(self, task: ExplorationTask, workspace: str) -> dict:
     def _run_autots(self, task: ExplorationTask, run_dir: str) -> list[str]:
         """Run AutoTSWorkflow in an isolated spawned subprocess.
 
-        The previous implementation called ``os.chdir(run_dir)`` in the mapper
-        process itself, which was not thread-safe and corrupted the CWD when
-        multiple tasks ran in parallel. The CWD change is now confined to the
-        spawned child process, leaving the mapper's working directory unchanged.
-
-        Crash-detection polling loop:
-            The original ``result_q.get(timeout=None)`` would block forever
-            when the worker crashed before calling result_q.put(). Fix: check
-            ``proc.is_alive()`` every poll_interval seconds and raise
-            RuntimeError immediately when the worker is dead. This imposes no
-            time limit on legitimate long-running calculations.
-
-        Deadlock prevention:
-            result_q.get() must be called before proc.join(). The reverse order
-            can deadlock when a large traceback fills the OS pipe buffer.
+        Uses ProcessPoolExecutor(max_workers=1, max_tasks_per_child=1), the
+        same mechanism as _run_batch_parallel, so that crash detection,
+        timeout handling, and CWD isolation are identical between the
+        sequential and parallel paths.
+
+        Crash detection:
+            ProcessPoolExecutor automatically captures any exception raised
+            inside _autots_worker and re-raises it via future.result(), so no
+            manual polling loop or Queue is required.
+
+        Timeout:
+            future.result(timeout=worker_timeout_s) raises concurrent.futures.
+            TimeoutError when the limit is exceeded. The handler cancels
+            pending work, force-kills the live worker via executor._processes
+            (same approach as _run_batch_parallel), then shuts the executor
+            down with wait=False to avoid blocking on a hung binary.
+
+        Return value:
+            _autots_worker returns a sorted list of Step-4 profile directories;
+            future.result() delivers that list directly to the caller.
         """
         if AutoTSWorkflow is None:
             raise RuntimeError(
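A condensed, runnable rendering of the flow this docstring describes: submit, result with a timeout, force-kill through the private _processes dict, shutdown without waiting. Here slow_job, the 2-second limit, and the spawn context are illustrative placeholders; the real code passes self._mp_ctx and self.worker_timeout_s.

import multiprocessing as mp
import time
from concurrent.futures import ProcessPoolExecutor, TimeoutError

def slow_job() -> str:
    time.sleep(60)  # stands in for a hung AutoTS run
    return "done"

if __name__ == "__main__":
    # max_tasks_per_child needs Python 3.11+ and a non-fork start method.
    executor = ProcessPoolExecutor(
        max_workers=1,
        mp_context=mp.get_context("spawn"),
        max_tasks_per_child=1,
    )
    timed_out = False
    try:
        future = executor.submit(slow_job)
        try:
            print(future.result(timeout=2))
        except TimeoutError:
            timed_out = True
            future.cancel()  # no-op for an already-running future
            for pid, proc in list(getattr(executor, "_processes", {}).items()):
                if proc.is_alive():
                    proc.kill()  # SIGKILL the hung worker
            print("worker exceeded the 2 s hard timeout")
    finally:
        # wait=False after a force-kill avoids blocking on a hung binary.
        executor.shutdown(wait=not timed_out, cancel_futures=timed_out)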
@@ -3067,107 +3044,52 @@ def _run_autots(self, task: ExplorationTask, run_dir: str) -> list[str]:
         workspace = os.path.join(run_dir, "autots_workspace")
         config = self._make_autots_config(task, workspace)
 
-        # Written before the workflow starts so it survives a crash
-        config_used_path = os.path.join(run_dir, "config_used.json")
+        # Written before the workflow starts so it survives a crash.
         try:
-            with open(config_used_path, "w", encoding="utf-8") as fh:
+            with open(
+                os.path.join(run_dir, "config_used.json"), "w", encoding="utf-8"
+            ) as fh:
                 json.dump(config, fh, indent=2, default=str)
         except Exception as exc:
             logger.warning("_run_autots: could not write config_used.json: %s", exc)
 
-        result_q: multiprocessing.Queue = self._mp_ctx.Queue()
-        proc = self._mp_ctx.Process(
-            target=_autots_worker_with_queue,
-            args=(config, run_dir, workspace, result_q),
+        # max_tasks_per_child=1 guarantees a fresh process for this task,
+        # isolating os.chdir() inside _autots_worker from the parent process.
+        executor = ProcessPoolExecutor(
+            max_workers=1,
+            mp_context=self._mp_ctx,
+            max_tasks_per_child=1,
         )
+        timed_out = False
         try:
-            proc.start()
-        except OSError:
-            result_q.close()
-            raise
-
-        import queue as _queue_mod
-
-        poll_interval = 60
-        start_time = time.time()
-
-        # ── Step 1: crash-detection polling loop ──────────────────────────
-        while True:
+            future = executor.submit(_autots_worker, config, run_dir, workspace)
             try:
-                tag, payload = result_q.get(timeout=poll_interval)
-                break
-            except _queue_mod.Empty:
-                pass
-
-            if not proc.is_alive():
-                # Worker exited without placing a result = crash
-                try:
-                    tag, payload = result_q.get(timeout=5.0)
-                    break
-                except _queue_mod.Empty:
-                    pass
-                proc.join(timeout=30)
-                try:
-                    result_q.cancel_join_thread()
-                    result_q.close()
-                except Exception:
-                    pass
-                raise RuntimeError(
-                    f"_run_autots: worker process terminated unexpectedly without "
-                    f"returning a result (exit_code={proc.exitcode}). "
-                    "Possible causes: segfault, OOM kill, or unhandled C-level "
-                    "exception in AutoTSWorkflow."
+                return future.result(timeout=self.worker_timeout_s)
+            except TimeoutError:
+                timed_out = True
+                logger.error(
+                    "_run_autots: worker exceeded hard timeout of %ds — "
+                    "force-killing worker process.",
+                    self.worker_timeout_s,
                 )
-
-            elapsed = time.time() - start_time
-
-            # ── Optional hard timeout (disabled by default) ───────────────
-            # Active only when worker_timeout_s is set.
-            # Intended for HPC environments where the job scheduler enforces
-            # a wall-clock limit.
-            if self.worker_timeout_s is not None and elapsed >= self.worker_timeout_s:
-                self._kill_proc(proc)
-                try:
-                    result_q.cancel_join_thread()
-                    result_q.close()
-                except Exception as exc:
-                    logger.debug("_run_autots: result_q cleanup failed: %s", exc)
+                future.cancel()
+                # No public API exposes individual worker handles; use the
+                # private _processes dict (same pattern as _run_batch_parallel).
+                worker_procs = getattr(executor, "_processes", {})
+                for pid, proc in list(worker_procs.items()):
+                    if proc.is_alive():
+                        logger.warning(
+                            "_run_autots: force-killing worker pid=%d", pid
+                        )
+                        proc.kill()
                 raise RuntimeError(
                     f"_run_autots: worker exceeded hard timeout of "
-                    f"{self.worker_timeout_s}s (elapsed={elapsed:.0f}s)."
+                    f"{self.worker_timeout_s}s."
                 )
-
-            logger.debug(
-                "_run_autots: worker still running (elapsed=%.0fs, pid=%d).",
-                elapsed, proc.pid,
-            )
-
-        # ── Step 2: drain any remaining items ────────────────────────────
-        # The worker is designed to put exactly one item, but drain any
-        # residual items to release the child's feeder thread.
-        while True:
-            try:
-                result_q.get_nowait()
-            except (_queue_mod.Empty, OSError):
-                break
-
-        # ── Step 3: join after the queue is empty — deadlock-safe ─────────
-        try:
-            proc.join(timeout=120)
-            if proc.is_alive():
-                self._kill_proc(proc)
-            if tag == "err":
-                raise RuntimeError(
-                    f"AutoTSWorkflow subprocess failed:\n{payload}"
-                )
-            return payload
         finally:
-            try:
-                result_q.cancel_join_thread()
-                result_q.close()
-            except Exception:
-                pass
-
+            # wait=False after a force-kill avoids blocking on a hung binary;
+            # wait=True in the normal path ensures a clean join.
+            executor.shutdown(wait=not timed_out, cancel_futures=timed_out)
 
     # ------------------------------------------------------------------ #
     # Energy back-fill                                                   #
     # ------------------------------------------------------------------ #
@@ -3363,27 +3285,32 @@ def _process_profile(self, profile_dir: str, run_dir: str) -> None:
         ts_energy: float | None = result["ts_energy"]
 
         # ── Step 1: parse new TS geometry ─────────────────────────────────
+        # Failure here must NOT skip EQ endpoint registration (Step 2).
+        # A missing or unreadable TS file means we cannot add a TSEdge, but
+        # the IRC endpoints are still valid EQ structures worth keeping.
         ts_sym: list[str] = []
         ts_coords: np.ndarray = np.empty((0, 3), dtype=float)
         ts_xyz = result.get("ts_xyz_file", "") or ""
+        ts_geom_ok = False
         if ts_xyz and os.path.isfile(ts_xyz):
             try:
                 ts_sym, ts_coords = parse_xyz(ts_xyz)
+                ts_geom_ok = True
                 logger.debug(
                     "Parsed TS geometry: %d atoms from %s", len(ts_sym), ts_xyz
                 )
             except Exception as exc:
                 logger.warning(
-                    "_process_profile: failed to parse TS XYZ %s: %s — profile skipped.",
+                    "_process_profile: failed to parse TS XYZ %s: %s — "
+                    "TSEdge will not be added, but EQ endpoints will still be registered.",
                     ts_xyz, exc,
                 )
-                return
         else:
             logger.warning(
-                "_process_profile: TS XYZ not found (profile_dir=%s, file=%r) — profile skipped.",
+                "_process_profile: TS XYZ not found (profile_dir=%s, file=%r) — "
+                "TSEdge will not be added, but EQ endpoints will still be registered.",
                 profile_dir, ts_xyz,
             )
-            return
 
         # ── Step 2: register EQ endpoint nodes ───────────────────────────
         node_id_1 = self._find_or_register_node(
@@ -3421,6 +3348,14 @@ def _process_profile(self, profile_dir: str, run_dir: str) -> None:
         )
 
         # ── Step 3: TS duplicate check (TS vs TS only) ────────────────────
+        if not ts_geom_ok:
+            logger.info(
+                "_process_profile: TS geometry unavailable — "
+                "skipping TSEdge registration (EQ%s -- EQ%s registered).",
+                node_id_1, node_id_2,
+            )
+            return
+
         ts_candidates: list[TSEdge] = (
             self.graph.edges_in_energy_window(ts_energy, self.energy_tolerance)
             if ts_energy is not None