"""Tutorial: run batched torch-sim Nudged Elastic Band calculations."""
# ruff: noqa: D101, D102, D103, D107

# %%
# /// script
# dependencies = [
#     "ase",
#     "matplotlib",
# ]
# ///

# %%
from dataclasses import dataclass
from time import perf_counter
from typing import ClassVar

import matplotlib.pyplot as plt
import numpy as np
import torch
from ase import Atoms
from ase.calculators.calculator import Calculator, all_changes
from ase.mep import NEB as ASENEB
from ase.optimize import FIRE

import torch_sim as ts
from torch_sim.models.interface import ModelInterface
from torch_sim.optimizers import fire_init, fire_step
from torch_sim.workflows.neb import (
    as_sim_state,
    assemble_path,
    calculate_neb_forces,
    interpolate_path,
)

# %% [markdown]
"""
# Nudged Elastic Band

The nudged elastic band (NEB) method finds a minimum-energy pathway between two
endpoints. The useful TorchSim pattern is not just "run one NEB", but "run many
related NEBs together": each path is an optimizer group inside one batched
`SimState`, so model evaluations and optimizer bookkeeping happen together.

ASE is kept in this tutorial as a reference implementation. It makes the result
easier to trust, and it gives a clear baseline for reporting the speedup from
batching several independent paths in TorchSim.
"""


# %%
@dataclass(frozen=True)
class NEBCase:
    """Parameters of one curved double-well surface, i.e. one independent NEB."""

    # Label used when printing and plotting results for this case.
    name: str
    # Coefficient of the (x**2 - 1)**2 term in ``curved_double_well``.
    barrier_height: float
    # Coefficient of the squared distance from the curved valley floor.
    valley_scale: float
    # How strongly the valley floor bends in y as x moves away from +/-1.
    valley_curve: float


def curved_double_well(
    positions: torch.Tensor,
    barrier_height: torch.Tensor | float,
    valley_scale: torch.Tensor | float,
    valley_curve: torch.Tensor | float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Evaluate the analytic curved double-well energy and forces per atom.

    With ``u = x**2 - 1`` and ``v = y - valley_curve * u`` the energy is
    ``barrier_height * u**2 + valley_scale * v**2 + z**2``.

    Args:
        positions: Cartesian coordinates indexed as ``positions[:, 0..2]``,
            i.e. one row of ``(x, y, z)`` per atom.
        barrier_height: Scalar, or a tensor with one value per atom.
        valley_scale: Scalar, or a tensor with one value per atom.
        valley_curve: Scalar, or a tensor with one value per atom.

    Returns:
        ``(energy, forces)`` where ``energy`` has one entry per atom and
        ``forces`` stacks the negative analytic gradient along ``dim=1``.
    """
    x = positions[:, 0]
    y = positions[:, 1]
    z = positions[:, 2]
    u = x**2 - 1.0
    v = y - valley_curve * u
    energy = barrier_height * u**2 + valley_scale * v**2 + z**2
    # Chain rule: du/dx = 2x and dv/dx = -2 * valley_curve * x.
    dE_dx = 4.0 * barrier_height * x * u - 4.0 * valley_scale * valley_curve * x * v
    dE_dy = 2.0 * valley_scale * v
    dE_dz = 2.0 * z
    forces = -torch.stack([dE_dx, dE_dy, dE_dz], dim=1)
    return energy, forces


# %% [markdown]
"""
The analytic surface above keeps the tutorial fast and reproducible. Every case
has minima at `x = -1` and `x = 1`, but each one has a different valley shape.
That gives us several independent NEB calculations with different barriers and
curvatures.
"""


# %%
class TorchBatchedDoubleWellModel(ModelInterface):
    """Vectorized double-well model that picks parameters from ``group_idx``."""

    def __init__(
        self,
        cases: list[NEBCase],
        *,
        device: torch.device,
        dtype: torch.dtype,
    ) -> None:
        """Store one parameter tensor per surface coefficient, indexed by case."""
        super().__init__()
        self._device = device
        self._dtype = dtype
        self._compute_forces = True
        self._compute_stress = True
        self.cases = cases
        self.valley_scale = torch.tensor(
            [case.valley_scale for case in cases], device=device, dtype=dtype
        )
        self.valley_curve = torch.tensor(
            [case.valley_curve for case in cases], device=device, dtype=dtype
        )
        self.barrier_height = torch.tensor(
            [case.barrier_height for case in cases], device=device, dtype=dtype
        )

    def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]:
        """Return per-system energy, per-atom forces and an all-zero stress.

        The group index of each atom's system selects which ``NEBCase``
        parameters apply to that atom, so ``state.group_idx`` values must be
        valid indices into ``cases``.
        """
        del kwargs
        # Per-atom case index: atom -> system -> group.
        case_idx = state.group_idx[state.system_idx]
        per_atom_energy, forces = curved_double_well(
            state.positions,
            self.barrier_height[case_idx],
            self.valley_scale[case_idx],
            self.valley_curve[case_idx],
        )
        # Sum per-atom energies into one total per system.
        energy = torch.zeros(state.n_systems, device=state.device, dtype=state.dtype)
        energy.scatter_add_(0, state.system_idx, per_atom_energy)
        return {
            "energy": energy,
            "forces": forces,
            "stress": torch.zeros(
                state.n_systems, 3, 3, device=state.device, dtype=state.dtype
            ),
        }


class ASEDoubleWellCalculator(Calculator):
    """ASE calculator exposing the same double-well surface for one case."""

    implemented_properties: ClassVar[list[str]] = ["energy", "forces"]

    def __init__(self, case: NEBCase) -> None:
        """Remember which case's parameters this calculator evaluates."""
        super().__init__()
        self.case = case

    def calculate(
        self,
        atoms: Atoms | None = None,
        properties: list[str] | None = None,
        system_changes: list[str] = all_changes,
    ) -> None:
        """Fill ``self.results`` with the total energy and per-atom forces."""
        super().calculate(atoms, properties, system_changes)
        positions = torch.tensor(self.atoms.positions, dtype=torch.float64)
        per_atom_energy, forces = curved_double_well(
            positions,
            self.case.barrier_height,
            self.case.valley_scale,
            self.case.valley_curve,
        )
        self.results["energy"] = float(per_atom_energy.sum().item())
        self.results["forces"] = forces.detach().cpu().numpy()


# %% [markdown]
"""
The same potential is exposed through two interfaces. TorchSim gets a vectorized
model that chooses the right parameters from `group_idx`; ASE gets one calculator
per case. In a real workflow, the TorchSim model would usually be an ML
interatomic potential and the groups would be different reactions, defects, or
starting guesses.
"""


# %%
def make_state(
    position: tuple[float, float, float],
    device: torch.device,
    *,
    dtype: torch.dtype = torch.float64,
) -> ts.SimState:
    """Build a one-atom, non-periodic ``SimState`` at ``position``.

    Args:
        position: Cartesian ``(x, y, z)`` of the single atom.
        device: Device on which every tensor of the state is created.
        dtype: Floating-point precision of positions, masses and cell. Should
            match the dtype of the model the state is evaluated with.

    Returns:
        A state holding one atom (atomic number 18, unit mass) in a
        ``10 * identity`` cell with ``pbc=False``.
    """
    return ts.SimState(
        positions=torch.tensor([position], device=device, dtype=dtype),
        masses=torch.ones(1, device=device, dtype=dtype),
        cell=torch.eye(3, device=device, dtype=dtype).unsqueeze(0) * 10.0,
        pbc=False,
        atomic_numbers=torch.tensor([18], device=device),
        system_idx=torch.zeros(1, device=device, dtype=torch.long),
    )


def make_endpoint_batch(
    cases: list[NEBCase],
    position: tuple[float, float, float],
    device: torch.device,
    *,
    dtype: torch.dtype = torch.float64,
) -> ts.SimState:
    """Concatenate one identical endpoint state per case into a single state.

    Only the number of ``cases`` is used; every endpoint sits at ``position``.
    """
    states = [make_state(position, device, dtype=dtype) for _ in cases]
    return ts.concatenate_states(states)


def state_for_group(state: ts.SimState, group_idx: int) -> ts.SimState:
    """Return the sub-state made of the systems whose group equals ``group_idx``."""
    system_indices = torch.where(state.group_idx == group_idx)[0]
    return state[system_indices]


def interpolate_batched_paths(
    initial_state: ts.SimState,
    final_state: ts.SimState,
    n_images: int,
) -> ts.SimState:
    """Interpolate ``n_images`` images per group and concatenate all paths.

    Group ``g`` of ``initial_state`` is paired with group ``g`` of
    ``final_state``, for ``g`` in ``range(initial_state.n_groups)``.
    """
    paths = [
        interpolate_path(
            state_for_group(initial_state, group_idx),
            state_for_group(final_state, group_idx),
            n_images,
        )
        for group_idx in range(initial_state.n_groups)
    ]
    return ts.concatenate_states(paths)


def assemble_batched_paths(
    initial_state: ts.SimState,
    movable_state: ts.SimState,
    final_state: ts.SimState,
) -> ts.SimState:
    """Join endpoints and movable images into one full path per group.

    Each group's path is assembled separately, relabelled as a single group,
    and the paths are then concatenated in group order.
    """
    paths = []
    for group_idx in range(initial_state.n_groups):
        path = assemble_path(
            state_for_group(initial_state, group_idx),
            state_for_group(movable_state, group_idx),
            state_for_group(final_state, group_idx),
        )
        # Mark the whole path as one group before concatenating.
        # NOTE(review): presumably concatenate_states then offsets these ids
        # so every path ends up in its own group - confirm.
        path.group_idx = torch.zeros(path.n_systems, device=path.device, dtype=torch.long)
        paths.append(path)
    return ts.concatenate_states(paths)


def relative_energies_by_group(
    state: ts.SimState,
    model: ModelInterface,
    n_groups: int,
) -> np.ndarray:
    """Evaluate ``model`` and return energies relative to each path's first image.

    The flat per-system energies are reshaped to ``(n_groups, -1)``, so the
    systems must be ordered group by group and every group must hold the same
    number of systems.
    """
    energies = model(state)["energy"].detach().cpu().numpy()
    profiles = energies.reshape(n_groups, -1)
    return profiles - profiles[:, :1]


def store_batched_neb_forces(
    state: ts.OptimState,
    neb_forces: torch.Tensor,
) -> None:
    """Attach NEB forces and per-group convergence metrics to ``state``.

    Sets ``state.forces`` and ``state.neb_forces`` to ``neb_forces``,
    ``state.neb_max_force_by_group`` to the largest per-atom force norm in each
    group, and ``state.neb_max_force`` to the maximum over all groups.
    """
    atom_group_idx = state.group_idx[state.system_idx]
    atom_force_norms = torch.linalg.norm(neb_forces, dim=1).to(state.dtype)
    # One segmented max instead of a Python loop with a mask per group. Norms
    # are non-negative, so starting from zero leaves the maximum of any
    # populated group unchanged; a group without atoms reports 0.
    max_force_by_group = torch.zeros(
        state.n_groups, device=state.device, dtype=state.dtype
    )
    max_force_by_group.scatter_reduce_(
        0, atom_group_idx, atom_force_norms, reduce="amax", include_self=True
    )
    state.forces = neb_forces
    state.neb_forces = neb_forces
    state.neb_max_force_by_group = max_force_by_group
    state.neb_max_force = max_force_by_group.max()


def calculate_batched_neb_forces(
    state: ts.SimState,
    true_forces: torch.Tensor,
    true_energies: torch.Tensor,
    initial_state: ts.SimState,
    final_state: ts.SimState,
    initial_energies: torch.Tensor,
    final_energies: torch.Tensor,
    *,
    spring_constant: float,
    use_climbing_image: bool,
) -> torch.Tensor:
    """Compute NEB forces for every group's movable images.

    Args:
        state: Batched movable images of all paths.
        true_forces: Model forces for the atoms of ``state``.
        true_energies: Model energies for the systems of ``state``.
        initial_state: Batched initial endpoints, one group per path.
        final_state: Batched final endpoints, one group per path.
        initial_energies: Endpoint energies indexed by group.
        final_energies: Endpoint energies indexed by group.
        spring_constant: Forwarded to ``calculate_neb_forces``.
        use_climbing_image: Forwarded to ``calculate_neb_forces``.

    Returns:
        A tensor shaped like ``true_forces`` holding the NEB force of each atom.
        Atoms whose group is not in ``range(initial_state.n_groups)`` keep zero.
    """
    neb_forces = torch.zeros_like(true_forces)
    atom_group_idx = state.group_idx[state.system_idx]
    for group_idx in range(initial_state.n_groups):
        system_indices = torch.where(state.group_idx == group_idx)[0]
        atom_mask = atom_group_idx == group_idx
        # The NEB force of an image needs its neighbours, so rebuild the full
        # path (endpoints included) for this group only.
        path_state = assemble_path(
            state_for_group(initial_state, group_idx),
            state[system_indices],
            state_for_group(final_state, group_idx),
        )
        neb_forces[atom_mask] = calculate_neb_forces(
            path_state,
            true_forces[atom_mask],
            true_energies[system_indices],
            initial_energies[group_idx],
            final_energies[group_idx],
            spring_constant=spring_constant,
            use_climbing_image=use_climbing_image,
        )
    return neb_forces


+def batched_neb_init( + state: ts.SimState, + model: ModelInterface, + *, + initial_state: ts.SimState, + final_state: ts.SimState, + initial_energies: torch.Tensor, + final_energies: torch.Tensor, + spring_constant: float, + use_climbing_image: bool, +) -> ts.OptimState: + opt_state = fire_init(state, model, fire_flavor="ase_fire") + neb_forces = calculate_batched_neb_forces( + opt_state, + opt_state.forces, + opt_state.energy, + initial_state, + final_state, + initial_energies, + final_energies, + spring_constant=spring_constant, + use_climbing_image=use_climbing_image, + ) + store_batched_neb_forces(opt_state, neb_forces) + return opt_state + + +def batched_neb_step( + state: ts.OptimState, + model: ModelInterface, + *, + initial_state: ts.SimState, + final_state: ts.SimState, + initial_energies: torch.Tensor, + final_energies: torch.Tensor, + spring_constant: float, + use_climbing_image: bool, +) -> ts.OptimState: + state = fire_step(state, model, fire_flavor="ase_fire") + true_forces = state.forces.clone() + neb_forces = calculate_batched_neb_forces( + state, + true_forces, + state.energy, + initial_state, + final_state, + initial_energies, + final_energies, + spring_constant=spring_constant, + use_climbing_image=use_climbing_image, + ) + state.true_forces = true_forces + store_batched_neb_forces(state, neb_forces) + return state + + +def run_torch_sim_batched_neb( + initial_state: ts.SimState, + final_state: ts.SimState, + model: ModelInterface, + *, + n_images: int, + spring_constant: float, + max_steps: int, + fmax: float, +) -> tuple[ts.SimState, list[np.ndarray], list[np.ndarray], float]: + start_time = perf_counter() + movable_images = interpolate_batched_paths(initial_state, final_state, n_images) + initial_energies = model(initial_state)["energy"] + final_energies = model(final_state)["energy"] + energy_history: list[np.ndarray] = [] + max_force_history: list[np.ndarray] = [] + + state = batched_neb_init( + movable_images, + model, + 
initial_state=as_sim_state(initial_state), + final_state=as_sim_state(final_state), + initial_energies=initial_energies, + final_energies=final_energies, + spring_constant=spring_constant, + use_climbing_image=True, + ) + + def record(current_state: ts.OptimState) -> None: + full_path = assemble_batched_paths(initial_state, current_state, final_state) + energy_history.append( + relative_energies_by_group(full_path, model, initial_state.n_groups) + ) + max_force_history.append( + current_state.neb_max_force_by_group.detach().cpu().numpy() + ) + + record(state) + for _ in range(max_steps): + state = batched_neb_step( + state, + model, + initial_state=as_sim_state(initial_state), + final_state=as_sim_state(final_state), + initial_energies=initial_energies, + final_energies=final_energies, + spring_constant=spring_constant, + use_climbing_image=True, + ) + record(state) + if bool((state.neb_max_force_by_group < fmax).all()): + break + + elapsed = perf_counter() - start_time + final_path = assemble_batched_paths(initial_state, state, final_state) + return final_path, energy_history, max_force_history, elapsed + + +def run_ase_neb( + initial_atoms: Atoms, + final_atoms: Atoms, + *, + case: NEBCase, + n_images: int, + spring_constant: float, + max_steps: int, + fmax: float, +) -> tuple[list[Atoms], list[np.ndarray], list[float], float]: + start_time = perf_counter() + images = [initial_atoms.copy()] + images.extend(initial_atoms.copy() for _ in range(n_images)) + images.append(final_atoms.copy()) + for image in images: + image.calc = ASEDoubleWellCalculator(case) + + neb = ASENEB(images, k=spring_constant, climb=True, method="improvedtangent") + neb.interpolate(mic=True) + optimizer = FIRE(neb, logfile=None) + energy_history: list[np.ndarray] = [] + max_force_history: list[float] = [] + + def record() -> None: + energies = np.array([image.get_potential_energy() for image in images]) + energy_history.append(energies - energies[0]) + forces = neb.get_forces().reshape(-1, 3) 
+ max_force_history.append(float(np.linalg.norm(forces, axis=1).max())) + + optimizer.attach(record, interval=1) + optimizer.run(fmax=fmax, steps=max_steps) + elapsed = perf_counter() - start_time + return images, energy_history, max_force_history, elapsed + + +# %% [markdown] +""" +Now set up a small batch. Each case is one independent NEB calculation with the +same endpoints and a different curved valley. TorchSim will optimize all of +these paths together; ASE will run the same cases sequentially. +""" + +# %% +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +cases = [ + NEBCase("0.8 eV", barrier_height=0.8, valley_scale=5.0, valley_curve=0.50), + NEBCase("0.9 eV", barrier_height=0.9, valley_scale=3.0, valley_curve=0.25), + NEBCase("1.0 eV", barrier_height=1.0, valley_scale=8.0, valley_curve=0.35), + NEBCase("1.1 eV", barrier_height=1.1, valley_scale=5.5, valley_curve=-0.45), + NEBCase("1.2 eV", barrier_height=1.2, valley_scale=7.0, valley_curve=0.70), + NEBCase("1.3 eV", barrier_height=1.3, valley_scale=4.0, valley_curve=-0.75), + NEBCase("1.4 eV", barrier_height=1.4, valley_scale=2.5, valley_curve=0.60), + NEBCase("1.5 eV", barrier_height=1.5, valley_scale=9.0, valley_curve=-0.25), +] +n_images = 7 +spring_constant = 0.1 +max_steps = 200 +fmax = 0.03 + +initial_state = make_endpoint_batch(cases, (-1.0, 0.0, 0.0), device) +final_state = make_endpoint_batch(cases, (1.0, 0.0, 0.0), device) +model = TorchBatchedDoubleWellModel(cases, device=device, dtype=torch.float64) + +initial_atoms = Atoms("Ar", positions=[[-1.0, 0.0, 0.0]], cell=np.eye(3) * 10.0) +final_atoms = Atoms("Ar", positions=[[1.0, 0.0, 0.0]], cell=np.eye(3) * 10.0) + + +# %% [markdown] +""" +The TorchSim call below is the main workflow. The movable images for every case +are stored in one state, and FIRE uses `group_idx` so each NEB keeps its own +optimizer state while still sharing one batched model call per step. 
+""" + +# %% +torch_path, torch_energy_history, torch_fmax, torch_elapsed = run_torch_sim_batched_neb( + initial_state, + final_state, + model, + n_images=n_images, + spring_constant=spring_constant, + max_steps=max_steps, + fmax=fmax, +) + + +# %% [markdown] +""" +ASE is the validation baseline. It does not batch these independent NEBs here, +so we run the same cases one after another and compare final profiles. +""" + +# %% +ase_energy_history: list[list[np.ndarray]] = [] +ase_fmax: list[list[float]] = [] +ase_elapsed_by_case: list[float] = [] +ase_start = perf_counter() +for case in cases: + _, case_energy_history, case_fmax, case_elapsed = run_ase_neb( + initial_atoms, + final_atoms, + case=case, + n_images=n_images, + spring_constant=spring_constant, + max_steps=max_steps, + fmax=fmax, + ) + ase_energy_history.append(case_energy_history) + ase_fmax.append(case_fmax) + ase_elapsed_by_case.append(case_elapsed) +ase_elapsed = perf_counter() - ase_start + + +# %% [markdown] +""" +Finally, compare accuracy and runtime. The per-case barrier differences should +be tiny; once that is true, the runtime comparison shows why batching many NEBs +together is the more relevant TorchSim workflow. 
+""" + +# %% +n_cases = len(cases) +reaction_coordinate = np.linspace(0.0, 1.0, n_images + 2) +torch_final = relative_energies_by_group(torch_path, model, n_cases) +ase_final = np.stack([history[-1] for history in ase_energy_history]) +barrier_difference = np.abs(torch_final.max(axis=1) - ase_final.max(axis=1)) +max_barrier_difference = barrier_difference.max() +speedup = ase_elapsed / torch_elapsed if torch_elapsed > 0 else float("inf") + +print("Final barrier comparison") +print("case torch-sim ASE abs diff") +for case, torch_profile, ase_profile, diff in zip( + cases, torch_final, ase_final, barrier_difference, strict=True +): + print( + f"{case.name:10s} {torch_profile.max(): .8f} {ase_profile.max(): .8f} " + f"{diff:.3e}" + ) +print(f"Max barrier difference: {max_barrier_difference:.3e} eV") +print(f"torch-sim batched runtime: {torch_elapsed:.3f} s") +print(f"ASE sequential runtime: {ase_elapsed:.3f} s") +print(f"torch-sim speedup vs sequential ASE: {speedup:.2f}x") + + +# %% +fig, axes = plt.subplots(2, 2, figsize=(11, 8)) +colors = plt.cm.viridis(np.linspace(0.05, 0.95, n_cases)) + +for idx, (case, color) in enumerate(zip(cases, colors, strict=True)): + axes[0, 0].plot( + reaction_coordinate, + torch_final[idx], + color=color, + linewidth=2, + label=case.name, + ) + axes[0, 0].plot(reaction_coordinate, ase_final[idx], "o", color=color, markersize=3) +axes[0, 0].set_ylabel("Relative energy") +axes[0, 0].set_title("Final profiles: lines are torch-sim, dots are ASE") +axes[0, 0].legend(fontsize=8) + +axes[0, 1].bar([case.name for case in cases], barrier_difference) +axes[0, 1].set_ylabel("Barrier |torch-sim - ASE|") +axes[0, 1].set_yscale("log") +axes[0, 1].tick_params(axis="x", rotation=45) +axes[0, 1].set_title("Validation error by NEB case") + +torch_force_history = np.stack(torch_fmax) +axes[1, 0].plot( + torch_force_history.max(axis=1), + color="tab:blue", + linestyle="--", + linewidth=2, + label="torch-sim batch max", +) +for case, case_fmax in zip(cases, 
ase_fmax, strict=True): + axes[1, 0].plot(case_fmax, alpha=0.35, linewidth=1, label=f"ASE {case.name}") +axes[1, 0].axhline(fmax, color="k", linestyle=":", label="fmax") +axes[1, 0].set_xlabel("Optimization step") +axes[1, 0].set_ylabel("Max NEB force") +axes[1, 0].set_yscale("log") +axes[1, 0].set_title("Convergence") +axes[1, 0].legend(fontsize=7) + +axes[1, 1].bar(["torch-sim batched", "ASE sequential"], [torch_elapsed, ase_elapsed]) +axes[1, 1].set_ylabel("Runtime (s)") +axes[1, 1].set_title(f"Batching speedup: {speedup:.2f}x") + +fig.tight_layout() +fig.savefig("neb_ase_torchsim_comparison.png", dpi=200) +print("Saved comparison plot to neb_ase_torchsim_comparison.png") diff --git a/pyproject.toml b/pyproject.toml index 9f473cd5d..c7a14f939 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -235,6 +235,7 @@ include = ["docs/**/*.py", "docs/**/*.ipynb", "examples/**/*.py"] [tool.ty.overrides.rules] invalid-argument-type = "ignore" invalid-assignment = "ignore" +invalid-attribute-override = "ignore" not-iterable = "ignore" not-subscriptable = "ignore" unresolved-attribute = "ignore" diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index a05e06089..9b802c9ec 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -399,6 +399,95 @@ def test_binning_auto_batcher_restore_order_with_split_states( assert torch.all(restored_states[1].atomic_numbers == states[1].atomic_numbers) +def test_binning_auto_batcher_keeps_group_together( + si_sim_state: ts.SimState, + fe_supercell_sim_state: ts.SimState, + lj_model: LennardJonesModel, +) -> None: + """A multi-system group is treated as one unit and stays in a single batch.""" + grouped = ts.concatenate_states([si_sim_state, fe_supercell_sim_state]) + grouped.group_idx = torch.zeros( + grouped.n_systems, device=grouped.device, dtype=torch.long + ) + + batcher = BinningAutoBatcher( + model=lj_model, + memory_scales_with="n_atoms", + max_memory_scaler=float(grouped.n_atoms), + ) + 
batcher.load_states(grouped) + + # one group => one unit whose scaler is the sum over its systems + assert len(batcher.memory_scalers) == 1 + assert batcher.memory_scalers[0] == grouped.n_atoms + + batches = [batch for batch, _ in batcher] + assert len(batches) == 1 + assert batches[0].n_systems == grouped.n_systems + assert batches[0].n_groups == 1 + + restored = batcher.restore_original_order(batches) + assert len(restored) == 1 + assert restored[0].n_systems == grouped.n_systems + + +def test_binning_auto_batcher_packs_multiple_groups( + si_sim_state: ts.SimState, + fe_supercell_sim_state: ts.SimState, + lj_model: LennardJonesModel, +) -> None: + """Multiple groups pack into one bin when memory allows (no throttling).""" + grouped = ts.concatenate_states([si_sim_state, si_sim_state, fe_supercell_sim_state]) + grouped.group_idx = torch.tensor([0, 0, 1], device=grouped.device, dtype=torch.long) + group0 = 2 * si_sim_state.n_atoms + group1 = fe_supercell_sim_state.n_atoms + + batcher = BinningAutoBatcher( + model=lj_model, + memory_scales_with="n_atoms", + max_memory_scaler=float(group0 + group1), + ) + batcher.load_states(grouped) + + assert batcher.memory_scalers == [group0, group1] + + batches = [batch for batch, _ in batcher] + assert len(batches) == 1 + assert batches[0].n_systems == 3 + assert batches[0].n_groups == 2 + + restored = batcher.restore_original_order(batches) + assert len(restored) == 2 + assert restored[0].n_systems == 2 + assert restored[1].n_systems == 1 + + +def test_binning_auto_batcher_does_not_split_group( + si_sim_state: ts.SimState, + fe_supercell_sim_state: ts.SimState, + lj_model: LennardJonesModel, +) -> None: + """A group never spans bins; tight memory packs one group per batch.""" + grouped = ts.concatenate_states([si_sim_state, si_sim_state, fe_supercell_sim_state]) + grouped.group_idx = torch.tensor([0, 0, 1], device=grouped.device, dtype=torch.long) + group0 = 2 * si_sim_state.n_atoms + group1 = fe_supercell_sim_state.n_atoms + + 
batcher = BinningAutoBatcher( + model=lj_model, + memory_scales_with="n_atoms", + max_memory_scaler=float(max(group0, group1)), + ) + batcher.load_states(grouped) + + batches = [batch for batch, _ in batcher] + assert len(batches) == 2 + assert sorted(batch.n_systems for batch in batches) == [1, 2] + + multi = next(batch for batch in batches if batch.n_systems == 2) + assert multi.n_groups == 1 + + def test_in_flight_max_metric_too_small( si_sim_state: ts.SimState, fe_supercell_sim_state: ts.SimState, @@ -715,6 +804,35 @@ def test_in_flight_max_iterations( assert batcher.iteration_count[idx] == max_iterations +def test_in_flight_max_iterations_completes_whole_group( + si_double_sim_state: ts.SimState, + lj_model: LennardJonesModel, +) -> None: + grouped_state = si_double_sim_state.clone() + grouped_state.group_idx = torch.zeros( + grouped_state.n_systems, device=grouped_state.device, dtype=torch.long + ) + batcher = InFlightAutoBatcher( + model=lj_model, + memory_scales_with="n_atoms", + max_memory_scaler=800.0, + max_iterations=1, + ) + batcher.load_states(grouped_state) + + state, [] = batcher.next_batch(None, None) + assert state is not None + assert state.n_systems == grouped_state.n_systems + assert state.n_groups == 1 + + convergence_tensor = torch.zeros(state.n_systems, dtype=torch.bool) + next_state, completed_states = batcher.next_batch(state, convergence_tensor) + + assert next_state is None + assert len(completed_states) == 1 + assert completed_states[0].n_systems == grouped_state.n_systems + + @pytest.mark.parametrize( "num_steps_per_batch", [ diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 3f419f688..78b37e151 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -162,6 +162,36 @@ def test_fire_optimization( ) +@pytest.mark.parametrize("fire_flavor", get_args(FireFlavor)) +def test_fire_uses_group_scoped_adaptive_state( + ar_double_sim_state: SimState, lj_model: ModelInterface, fire_flavor: FireFlavor +) -> 
None: + ar_double_sim_state.group_idx = torch.zeros( + ar_double_sim_state.n_systems, + device=ar_double_sim_state.device, + dtype=torch.int64, + ) + + state = ts.fire_init( + ar_double_sim_state, + lj_model, + fire_flavor=fire_flavor, + dt_start=0.1, + alpha_start=0.1, + ) + + assert state.n_groups == 1 + assert state.dt.shape == (1,) + assert state.alpha.shape == (1,) + assert state.n_pos.shape == (1,) + + updated = ts.fire_step(state=state, model=lj_model, dt_max=0.3) + + assert updated.dt.shape == (1,) + assert updated.alpha.shape == (1,) + assert updated.n_pos.shape == (1,) + + def test_bfgs_optimization( ar_supercell_sim_state: SimState, lj_model: ModelInterface ) -> None: diff --git a/tests/test_state.py b/tests/test_state.py index 7c73a6498..33f9c306d 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -37,11 +37,36 @@ def test_get_attrs_for_scope(si_sim_state: SimState) -> None: per_atom_attrs = dict(get_attrs_for_scope(si_sim_state, "per-atom")) assert set(per_atom_attrs) == {"positions", "masses", "atomic_numbers", "system_idx"} per_system_attrs = dict(get_attrs_for_scope(si_sim_state, "per-system")) - assert set(per_system_attrs) == {"cell"} + assert set(per_system_attrs) == {"cell", "group_idx"} + per_group_attrs = dict(get_attrs_for_scope(si_sim_state, "per-group")) + assert set(per_group_attrs) == set() global_attrs = dict(get_attrs_for_scope(si_sim_state, "global")) assert set(global_attrs) == {"pbc", "_rng"} +def test_group_idx_defaults_to_one_group_per_system( + si_double_sim_state: SimState, +) -> None: + assert torch.equal( + si_double_sim_state.group_idx, + torch.arange(si_double_sim_state.n_systems, device=si_double_sim_state.device), + ) + assert si_double_sim_state.n_groups == si_double_sim_state.n_systems + + +def test_slice_remaps_group_idx(si_double_sim_state: SimState) -> None: + state = si_double_sim_state.clone() + state.group_idx = torch.zeros(state.n_systems, device=state.device, dtype=torch.int64) + + sliced = 
_slice_state(state, [1, 0]) + + assert torch.equal( + sliced.group_idx, + torch.zeros(sliced.n_systems, device=sliced.device, dtype=torch.int64), + ) + assert sliced.n_groups == 1 + + def test_all_attributes_must_be_specified_in_scopes() -> None: """Test that an error is raised when we forget to specify the scope for an attribute in a child SimState class.""" @@ -308,6 +333,68 @@ def test_split_many_states( assert len(states) == 3 +def test_split_by_group_matches_systems_for_default_groups( + si_sim_state: SimState, + ar_supercell_sim_state: SimState, + fe_supercell_sim_state: SimState, +) -> None: + """With default single-system groups, split(by='group') == split(by='system').""" + concatenated = ts.concatenate_states( + [si_sim_state, ar_supercell_sim_state, fe_supercell_sim_state] + ) + by_system = concatenated.split_systems() + by_group = concatenated.split_groups() + assert len(by_system) == len(by_group) == 3 + for sys_state, group_state in zip(by_system, by_group, strict=True): + assert torch.allclose(sys_state.positions, group_state.positions) + assert torch.allclose(sys_state.system_idx, group_state.system_idx) + assert group_state.n_groups == 1 + + +def test_split_by_group_keeps_multi_system_groups_together( + si_sim_state: SimState, + ar_supercell_sim_state: SimState, + fe_supercell_sim_state: SimState, +) -> None: + """split(by='group') returns one state per group with systems intact.""" + concatenated = ts.concatenate_states( + [si_sim_state, ar_supercell_sim_state, fe_supercell_sim_state] + ) + # group 0 = {si, ar}, group 1 = {fe} + concatenated.group_idx = torch.tensor( + [0, 0, 1], device=concatenated.device, dtype=torch.long + ) + groups = concatenated.split(by="group") + + assert len(groups) == 2 + assert groups[0].n_systems == 2 + assert groups[1].n_systems == 1 + assert groups[0].n_groups == groups[1].n_groups == 1 + # first group holds si + ar atoms with locally re-based system_idx + assert groups[0].n_atoms == si_sim_state.n_atoms + 
ar_supercell_sim_state.n_atoms + assert torch.equal( + torch.unique(groups[0].system_idx), + torch.tensor([0, 1], device=concatenated.device), + ) + assert torch.allclose(groups[1].positions, fe_supercell_sim_state.positions) + + +def test_split_by_group_rejects_non_contiguous_groups( + si_sim_state: SimState, + ar_supercell_sim_state: SimState, + fe_supercell_sim_state: SimState, +) -> None: + """Non-decreasing group_idx is required for splitting by group.""" + concatenated = ts.concatenate_states( + [si_sim_state, ar_supercell_sim_state, fe_supercell_sim_state] + ) + concatenated.group_idx = torch.tensor( + [0, 1, 0], device=concatenated.device, dtype=torch.long + ) + with pytest.raises(ValueError, match="non-decreasing group_idx"): + concatenated.split_groups() + + def test_pop_states( si_sim_state: SimState, ar_supercell_sim_state: SimState, diff --git a/tests/workflows/test_neb.py b/tests/workflows/test_neb.py new file mode 100644 index 000000000..bb3a4277f --- /dev/null +++ b/tests/workflows/test_neb.py @@ -0,0 +1,201 @@ +import numpy as np +import torch +from ase import Atoms +from ase.mep import NEB as ASENEB +from ase.mep.neb import ImprovedTangentMethod, NEBState + +import torch_sim as ts +from tests.conftest import DEVICE, DTYPE +from torch_sim.models.interface import ModelInterface +from torch_sim.workflows.neb import ( + NEB, + assemble_path, + calculate_neb_forces, + interpolate_path, +) + + +class HarmonicModel(ModelInterface): + def __init__(self) -> None: + super().__init__() + self._device = DEVICE + self._dtype = DTYPE + self._compute_forces = True + self._compute_stress = True + + def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]: + del kwargs + per_atom_energy = 0.5 * (state.positions**2).sum(dim=1) + energy = torch.zeros(state.n_systems, device=state.device, dtype=state.dtype) + energy.scatter_add_(0, state.system_idx, per_atom_energy) + return { + "energy": energy, + "forces": -state.positions, + "stress": 
torch.zeros( + state.n_systems, 3, 3, device=state.device, dtype=state.dtype + ), + } + + +class GroupIndexedModel(HarmonicModel): + def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]: + if state.group_idx.max() >= 1: + raise AssertionError("NEB endpoint/path assembly should preserve one group.") + return super().forward(state, **kwargs) + + +def _single_atom_state(position: float) -> ts.SimState: + return ts.SimState( + positions=torch.tensor([[position, 0.0, 0.0]], device=DEVICE, dtype=DTYPE), + masses=torch.ones(1, device=DEVICE, dtype=DTYPE), + cell=torch.eye(3, device=DEVICE, dtype=DTYPE).unsqueeze(0) * 10.0, + pbc=False, + atomic_numbers=torch.tensor([18], device=DEVICE), + system_idx=torch.zeros(1, device=DEVICE, dtype=torch.long), + ) + + +def test_assemble_path_preserves_one_optimizer_group() -> None: + initial = _single_atom_state(0.0) + final = _single_atom_state(1.0) + movable = interpolate_path(initial, final, n_images=3) + + path = assemble_path(initial, movable, final) + + assert path.n_systems == 5 + assert path.n_groups == 1 + assert torch.equal(path.group_idx, torch.zeros(5, device=DEVICE, dtype=torch.long)) + + +def test_interpolate_path_uses_movable_images_only() -> None: + initial = _single_atom_state(0.0) + final = _single_atom_state(1.0) + + path = interpolate_path(initial, final, n_images=3) + + assert path.n_systems == 3 + assert path.n_groups == 1 + assert torch.equal(path.group_idx, torch.zeros(3, device=DEVICE, dtype=torch.long)) + assert torch.allclose( + path.positions[:, 0], + torch.tensor([0.25, 0.5, 0.75], device=DEVICE, dtype=DTYPE), + ) + + +def test_calculate_neb_forces_matches_ase_step0_components() -> None: + n_images = 5 + n_atoms = 2 + spring_constant = 0.1 + positions = torch.tensor( + [ + [[0.0, 0.0, 0.0], [0.0, 0.8, 0.0]], + [[0.2, 0.1, 0.0], [0.1, 0.9, 0.0]], + [[0.5, 0.25, 0.0], [0.2, 1.0, 0.1]], + [[0.8, 0.35, 0.0], [0.25, 1.05, 0.2]], + [[1.0, 0.5, 0.0], [0.4, 1.2, 0.3]], + ], + 
device=DEVICE, + dtype=DTYPE, + ) + energies = torch.tensor([0.0, 0.2, 0.7, 0.4, 0.1], device=DEVICE, dtype=DTYPE) + true_forces = torch.tensor( + [ + [[0.1, -0.2, 0.0], [0.0, 0.3, -0.1]], + [[-0.2, 0.1, 0.2], [0.2, -0.1, 0.0]], + [[0.3, 0.0, -0.1], [-0.1, 0.2, 0.1]], + ], + device=DEVICE, + dtype=DTYPE, + ) + path_state = ts.SimState( + positions=positions.reshape(-1, 3), + masses=torch.ones(n_images * n_atoms, device=DEVICE, dtype=DTYPE), + cell=torch.eye(3, device=DEVICE, dtype=DTYPE).unsqueeze(0).repeat(n_images, 1, 1) + * 10.0, + pbc=False, + atomic_numbers=torch.tensor([18, 18], device=DEVICE).repeat(n_images), + system_idx=torch.repeat_interleave( + torch.arange(n_images, device=DEVICE), repeats=n_atoms + ), + ) + + torch_forces = calculate_neb_forces( + path_state, + true_forces.reshape(-1, 3), + energies[1:-1], + energies[0], + energies[-1], + spring_constant=spring_constant, + use_climbing_image=True, + ).reshape(n_images - 2, n_atoms, 3) + + ase_images = [ + Atoms( + "Ar2", + positions=image_positions.detach().cpu().numpy(), + cell=np.eye(3) * 10.0, + pbc=False, + ) + for image_positions in positions + ] + ase_neb = ASENEB(ase_images, k=spring_constant, climb=True, method="improvedtangent") + ase_state = NEBState(ase_neb, ase_neb.images, energies.detach().cpu().numpy()) + tangent_method = ImprovedTangentMethod(ase_neb) + ase_forces = [] + true_forces_np = true_forces.detach().cpu().numpy() + for image_index in range(1, n_images - 1): + spring1 = ase_state.spring(image_index - 1) + spring2 = ase_state.spring(image_index) + tangent = tangent_method.get_tangent(ase_state, spring1, spring2, image_index) + tangent_norm = np.linalg.norm(tangent) + if tangent_norm > 1e-15: + tangent = tangent / tangent_norm + force = true_forces_np[image_index - 1] + force_dot_tangent = np.vdot(force, tangent) + if ase_neb.climb and image_index == ase_state.imax: + ase_forces.append(force - 2 * force_dot_tangent * tangent) + else: + spring_force = (spring2.nt * spring2.k - 
def _group_memory_scalers(
    state: SimState,
    memory_scales_with: MemoryScaling,
    cutoff: float,
) -> list[float]:
    """Return one memory scaler per optimizer group of ``state``.

    The per-system scalers from ``calculate_memory_scalers`` are accumulated by
    ``state.group_idx``, so a group's scaler is the sum over the systems assigned
    to it. No per-group state is built along the way.

    Args:
        state: Batched state whose systems are assigned to groups by ``group_idx``.
        memory_scales_with: Metric forwarded to ``calculate_memory_scalers``.
        cutoff: Cutoff forwarded to ``calculate_memory_scalers``.

    Returns:
        list[float]: Summed scalers, indexed by group.
    """
    system_scalers = calculate_memory_scalers(state, memory_scales_with, cutoff)
    weights = torch.tensor(system_scalers, dtype=torch.float64, device=state.device)
    # bincount with weights performs the segment sum; minlength keeps one slot
    # for every group.
    group_totals = torch.bincount(
        state.group_idx,
        weights=weights,
        minlength=state.n_groups,
    )
    return group_totals.tolist()
Either a - list of individual SimState objects or a single batched SimState that - will be split into individual states. Each SimState has shape - information specific to its instance. + list of individual SimState objects or a single batched SimState. The + batcher works in units of optimizer groups (one per system by default), + so multi-system groups are kept together in the same batch. Returns: float: Maximum memory scaling metric that fits in GPU memory. @@ -575,7 +615,8 @@ def load_states(self, states: T | Sequence[T]) -> float: batched = ( states if isinstance(states, SimState) else ts.concatenate_states(states) ) - self.memory_scalers = calculate_memory_scalers( + self._batched = batched + self.memory_scalers = _group_memory_scalers( batched, self.memory_scales_with, self.cutoff ) if not self.max_memory_scaler: @@ -605,14 +646,25 @@ def load_states(self, states: T | Sequence[T]) -> float: index_bins = to_constant_volume_bins( self.index_to_scaler, max_volume=self.max_memory_scaler ) # list[dict[original_index: int, memory_scale:float]] - # Convert to list of lists of indices + # Convert to list of lists of group indices self.index_bins = [list(batch.keys()) for batch in index_bins] - self.batched_states = [[batched[index_bin]] for index_bin in self.index_bins] + # Per bin, the system indices for that bin's groups, ordered by group so the + # remapped group_idx after indexing matches the bin's group order. Materialized + # lazily in next_batch to avoid eagerly splitting one SimState per group. 
+ self.batched_states = [ + torch.cat( + [ + torch.where(batched.group_idx == group_idx)[0] + for group_idx in index_bin + ] + ) + for index_bin in self.index_bins + ] self.current_state_bin = 0 logger.info( "BinningAutoBatcher: %d systems → %d batch(es), max_memory_scaler=%.3g", - len(self.memory_scalers), + batched.n_systems, len(self.index_bins), self.max_memory_scaler, ) @@ -641,8 +693,8 @@ def next_batch(self) -> tuple[T | None, list[int]]: # TODO: need to think about how this intersects with reporting too # TODO: definitely a clever treatment to be done with iterators here if self.current_state_bin < len(self.batched_states): - state_bin = self.batched_states[self.current_state_bin] - state = ts.concatenate_states(state_bin) + system_indices = self.batched_states[self.current_state_bin] + state = self._batched[system_indices] indices = ( self.index_bins[self.current_state_bin] if self.current_state_bin < len(self.index_bins) @@ -725,7 +777,7 @@ def restore_original_order(self, batched_states: Sequence[T]) -> list[T]: ordered_results = batcher.restore_original_order(results) """ - all_states = [state.split() for state in batched_states] + all_states = [state.split_groups() for state in batched_states] all_states = list(chain.from_iterable(all_states)) original_indices = list(chain.from_iterable(self.index_bins)) @@ -881,9 +933,12 @@ def load_states(self, states: Sequence[T] | Iterator[T] | T) -> float | None: This method resets the current state indices and completed state tracking, so any ongoing processing will be restarted when this method is called. 
""" + state_units: Sequence[T] | Iterator[T] if isinstance(states, SimState): - states = states.split() - self.states_iterator = iter(states) + state_units = cast("T", states).split_groups() + else: + state_units = states + self.states_iterator = iter(state_units) self.current_scalers = [] self.current_idx = [] @@ -909,9 +964,7 @@ def _get_next_states(self) -> list[T]: new_idx: list[int] = [] new_states: list[T] = [] for state in self.states_iterator: - metric = calculate_memory_scalers( - state, self.memory_scales_with, self.cutoff - )[0] + metric = _unit_memory_scaler(state, self.memory_scales_with, self.cutoff) if metric > self.max_memory_scaler: raise ValueError( f"State {metric=} is greater than max_metric {self.max_memory_scaler}" @@ -967,7 +1020,9 @@ def _get_first_batch(self) -> T: # we need to sample a state and use it to estimate the max metric # for the first batch first_state = next(self.states_iterator) - first_metric = calculate_memory_scalers(first_state, self.memory_scales_with)[0] + first_metric = _unit_memory_scaler( + first_state, self.memory_scales_with, self.cutoff + ) self.current_scalers += [first_metric] self.current_idx += [0] self.iteration_count.append(0) # Initialize attempt counter for first state @@ -1086,12 +1141,23 @@ def next_batch( # noqa: C901 if self.max_iterations is not None and ( self.iteration_count[abs_idx] >= self.max_iterations ): - # Force convergence for states that have reached max attempts - convergence_tensor[cur_idx] = torch.tensor(True) # noqa: FBT003 - - completed_idx = torch.where(convergence_tensor)[0].tolist() - - completed_states = updated_state.pop(completed_idx) + convergence_tensor[updated_state.group_idx == cur_idx] = True + + completed_idx = [] + completed_system_indices = [] + for group_idx in range(updated_state.n_groups): + system_mask = updated_state.group_idx == group_idx + if convergence_tensor[system_mask].all(): + completed_idx.append(group_idx) + 
completed_system_indices.extend(torch.where(system_mask)[0].tolist()) + + completed_states = ( + updated_state[completed_system_indices].split_groups() + if completed_system_indices + else [] + ) + if completed_system_indices: + updated_state.pop(completed_system_indices) # necessary to ensure states that finish at the same time are ordered properly completed_states.reverse() diff --git a/torch_sim/optimizers/cell_filters.py b/torch_sim/optimizers/cell_filters.py index 8a58f8d40..b82701fad 100644 --- a/torch_sim/optimizers/cell_filters.py +++ b/torch_sim/optimizers/cell_filters.py @@ -462,9 +462,12 @@ class CellFireState(CellOptimState, FireState): _system_attributes = ( CellOptimState._system_attributes # noqa: SLF001 - | FireState._system_attributes # noqa: SLF001 | {"cell_velocities"} ) + _group_attributes = ( + CellOptimState._group_attributes # noqa: SLF001 + | FireState._group_attributes # noqa: SLF001 + ) @dataclass(kw_only=True) diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py index d92da8d41..03b1d96e4 100644 --- a/torch_sim/optimizers/fire.py +++ b/torch_sim/optimizers/fire.py @@ -18,6 +18,13 @@ from torch_sim.optimizers.cell_filters import CellFilter, CellFilterFuncs +def _group_sum( + values: torch.Tensor, group_idx: torch.Tensor, n_groups: int +) -> torch.Tensor: + summed = torch.zeros(n_groups, device=values.device, dtype=values.dtype) + return summed.scatter_add(0, group_idx, values) + + @dcite("10.1103/PhysRevLett.97.170201") def fire_init( state: SimState, @@ -60,7 +67,7 @@ def fire_init( device: torch.device = model.device dtype: torch.dtype = model.dtype - n_systems = state.n_systems + n_groups = state.n_groups # Get initial forces and energy from model model_output = model(state) @@ -72,12 +79,20 @@ def fire_init( dt_start_t = torch.as_tensor(dt_start, device=device, dtype=dtype) if dt_start_t.ndim == 0: # NOTE: clone needed as this is overwritten/assigned later by a masked_fill - dt_start_t = 
dt_start_t.expand(n_systems).clone() + dt_start_t = dt_start_t.expand(n_groups).clone() + elif dt_start_t.numel() != n_groups: + raise ValueError( + f"dt_start must have {n_groups} values, got {dt_start_t.numel()}" + ) alpha_start_t = torch.as_tensor(alpha_start, device=device, dtype=dtype) if alpha_start_t.ndim == 0: # NOTE: clone needed as this is overwritten/assigned later by a masked_fill - alpha_start_t = alpha_start_t.expand(n_systems).clone() + alpha_start_t = alpha_start_t.expand(n_groups).clone() + elif alpha_start_t.numel() != n_groups: + raise ValueError( + f"alpha_start must have {n_groups} values, got {alpha_start_t.numel()}" + ) # FIRE-specific additional attributes fire_attrs = { @@ -89,7 +104,7 @@ def fire_init( ), "dt": dt_start_t, "alpha": alpha_start_t, - "n_pos": torch.zeros((n_systems,), device=model.device, dtype=torch.int32), + "n_pos": torch.zeros((n_groups,), device=model.device, dtype=torch.int32), } if cell_filter is not None: # Create cell optimization state @@ -189,7 +204,9 @@ def _vv_fire_step[T: "FireState | CellFireState"]( # noqa: PLR0915 eps: float, ) -> T: """Perform one Velocity-Verlet based FIRE optimization step.""" - n_systems, device, dtype = state.n_systems, state.device, state.dtype + n_systems, n_groups = state.n_systems, state.n_groups + device, dtype = state.device, state.dtype + atom_group_idx = state.group_idx[state.system_idx] # Initialize velocities if NaN nan_velocities = state.velocities.isnan().any(dim=1) @@ -200,11 +217,11 @@ def _vv_fire_step[T: "FireState | CellFireState"]( # noqa: PLR0915 state.cell_velocities[nan_cell_vel] = 0 alpha_start_system = torch.full( - (n_systems,), alpha_start.item(), device=device, dtype=dtype + (n_groups,), alpha_start.item(), device=device, dtype=dtype ) # First half of velocity update - atom_wise_dt = state.dt[state.system_idx].unsqueeze(-1) + atom_wise_dt = state.dt[atom_group_idx].unsqueeze(-1) state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) 
# Position update @@ -227,15 +244,19 @@ def _vv_fire_step[T: "FireState | CellFireState"]( # noqa: PLR0915 # Second half of velocity update state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) if isinstance(state, CellFireState): - cell_wise_dt = state.dt.view(n_systems, 1, 1) + cell_wise_dt = state.dt[state.group_idx].view(n_systems, 1, 1) state.cell_velocities += ( 0.5 * cell_wise_dt * state.cell_forces / state.cell_masses.unsqueeze(-1) ) # Calculate power - system_power = tsm.batched_vdot(state.forces, state.velocities, state.system_idx) + system_power = tsm.batched_vdot(state.forces, state.velocities, atom_group_idx) if isinstance(state, CellFireState): - system_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) + system_power += _group_sum( + (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)), + state.group_idx, + n_groups, + ) # Update dt, alpha, n_pos pos_mask_system = system_power > 0.0 @@ -252,32 +273,44 @@ def _vv_fire_step[T: "FireState | CellFireState"]( # noqa: PLR0915 # Velocity mixing v_scaling_system = tsm.batched_vdot( - state.velocities, state.velocities, state.system_idx + state.velocities, state.velocities, atom_group_idx ) - f_scaling_system = tsm.batched_vdot(state.forces, state.forces, state.system_idx) + f_scaling_system = tsm.batched_vdot(state.forces, state.forces, atom_group_idx) if isinstance(state, CellFireState): - v_scaling_system += state.cell_velocities.pow(2).sum(dim=(1, 2)) - f_scaling_system += state.cell_forces.pow(2).sum(dim=(1, 2)) + v_scaling_system += _group_sum( + state.cell_velocities.pow(2).sum(dim=(1, 2)), + state.group_idx, + n_groups, + ) + f_scaling_system += _group_sum( + state.cell_forces.pow(2).sum(dim=(1, 2)), + state.group_idx, + n_groups, + ) - v_scaling_cell = torch.sqrt(v_scaling_system.view(n_systems, 1, 1)) - f_scaling_cell = torch.sqrt(f_scaling_system.view(n_systems, 1, 1)) + v_scaling_cell = torch.sqrt( + 
v_scaling_system[state.group_idx].view(n_systems, 1, 1) + ) + f_scaling_cell = torch.sqrt( + f_scaling_system[state.group_idx].view(n_systems, 1, 1) + ) v_mixing_cell = state.cell_forces / (f_scaling_cell + eps) * v_scaling_cell - alpha_cell_bc = state.alpha.view(n_systems, 1, 1) + alpha_cell_bc = state.alpha[state.group_idx].view(n_systems, 1, 1) state.cell_velocities = torch.where( - pos_mask_system.view(n_systems, 1, 1), + pos_mask_system[state.group_idx].view(n_systems, 1, 1), (1.0 - alpha_cell_bc) * state.cell_velocities + alpha_cell_bc * v_mixing_cell, torch.zeros_like(state.cell_velocities), ) - v_scaling_atom = torch.sqrt(v_scaling_system[state.system_idx].unsqueeze(-1)) - f_scaling_atom = torch.sqrt(f_scaling_system[state.system_idx].unsqueeze(-1)) + v_scaling_atom = torch.sqrt(v_scaling_system[atom_group_idx].unsqueeze(-1)) + f_scaling_atom = torch.sqrt(f_scaling_system[atom_group_idx].unsqueeze(-1)) v_mixing_atom = state.forces * (v_scaling_atom / (f_scaling_atom + eps)) - alpha_atom = state.alpha[state.system_idx].unsqueeze(-1) + alpha_atom = state.alpha[atom_group_idx].unsqueeze(-1) state.velocities = torch.where( - pos_mask_system[state.system_idx].unsqueeze(-1), + pos_mask_system[atom_group_idx].unsqueeze(-1), (1.0 - alpha_atom) * state.velocities + alpha_atom * v_mixing_atom, torch.zeros_like(state.velocities), ) @@ -301,7 +334,9 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 """Perform one ASE-style FIRE optimization step.""" from torch_sim.optimizers import CellFireState - n_systems, device, dtype = state.n_systems, state.device, state.dtype + n_systems, n_groups = state.n_systems, state.n_groups + device, dtype = state.device, state.dtype + atom_group_idx = state.group_idx[state.system_idx] # Per-atom NaN detection before zeroing: needed to decide whether to skip # FIRE mixing (all NaN = first step) vs run it (partial NaN = autobatcher swap). 
@@ -315,7 +350,7 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 forces = state.forces else: alpha_start_system = torch.full( - (n_systems,), alpha_start.item(), device=device, dtype=dtype + (n_groups,), alpha_start.item(), device=device, dtype=dtype ) # Transform forces for cell optimization @@ -331,9 +366,13 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 forces = state.forces # Calculate power (newly zeroed systems will have power=0 → neg_mask) - system_power = tsm.batched_vdot(forces, state.velocities, state.system_idx) + system_power = tsm.batched_vdot(forces, state.velocities, atom_group_idx) if isinstance(state, CellFireState): - system_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) + system_power += _group_sum( + (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)), + state.group_idx, + n_groups, + ) # Update dt, alpha, n_pos pos_mask_system = system_power > 0.0 @@ -350,55 +389,74 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 # Velocity mixing BEFORE acceleration (ASE ordering) v_scaling_system = tsm.batched_vdot( - state.velocities, state.velocities, state.system_idx + state.velocities, state.velocities, atom_group_idx ) - f_scaling_system = tsm.batched_vdot(forces, forces, state.system_idx) + f_scaling_system = tsm.batched_vdot(forces, forces, atom_group_idx) if isinstance(state, CellFireState): - v_scaling_system += state.cell_velocities.pow(2).sum(dim=(1, 2)) - f_scaling_system += state.cell_forces.pow(2).sum(dim=(1, 2)) + v_scaling_system += _group_sum( + state.cell_velocities.pow(2).sum(dim=(1, 2)), + state.group_idx, + n_groups, + ) + f_scaling_system += _group_sum( + state.cell_forces.pow(2).sum(dim=(1, 2)), + state.group_idx, + n_groups, + ) - v_scaling_cell = torch.sqrt(v_scaling_system.view(n_systems, 1, 1)) - f_scaling_cell = torch.sqrt(f_scaling_system.view(n_systems, 1, 1)) + v_scaling_cell = torch.sqrt( + 
v_scaling_system[state.group_idx].view(n_systems, 1, 1) + ) + f_scaling_cell = torch.sqrt( + f_scaling_system[state.group_idx].view(n_systems, 1, 1) + ) v_mixing_cell = state.cell_forces / (f_scaling_cell + eps) * v_scaling_cell - alpha_cell_bc = state.alpha.view(n_systems, 1, 1) + alpha_cell_bc = state.alpha[state.group_idx].view(n_systems, 1, 1) state.cell_velocities = torch.where( - pos_mask_system.view(n_systems, 1, 1), + pos_mask_system[state.group_idx].view(n_systems, 1, 1), (1.0 - alpha_cell_bc) * state.cell_velocities + alpha_cell_bc * v_mixing_cell, torch.zeros_like(state.cell_velocities), ) - v_scaling_atom = torch.sqrt(v_scaling_system[state.system_idx].unsqueeze(-1)) - f_scaling_atom = torch.sqrt(f_scaling_system[state.system_idx].unsqueeze(-1)) + v_scaling_atom = torch.sqrt(v_scaling_system[atom_group_idx].unsqueeze(-1)) + f_scaling_atom = torch.sqrt(f_scaling_system[atom_group_idx].unsqueeze(-1)) v_mixing_atom = forces * (v_scaling_atom / (f_scaling_atom + eps)) - alpha_atom = state.alpha[state.system_idx].unsqueeze(-1) + alpha_atom = state.alpha[atom_group_idx].unsqueeze(-1) state.velocities = torch.where( - pos_mask_system[state.system_idx].unsqueeze(-1), + pos_mask_system[atom_group_idx].unsqueeze(-1), (1.0 - alpha_atom) * state.velocities + alpha_atom * v_mixing_atom, torch.zeros_like(state.velocities), ) # Acceleration (single forward-Euler, no mass for ASE FIRE) - state.velocities += forces * state.dt[state.system_idx].unsqueeze(-1) - dr_atom = state.velocities * state.dt[state.system_idx].unsqueeze(-1) - dr_scaling_system = tsm.batched_vdot(dr_atom, dr_atom, state.system_idx) + state.velocities += forces * state.dt[atom_group_idx].unsqueeze(-1) + dr_atom = state.velocities * state.dt[atom_group_idx].unsqueeze(-1) + dr_scaling_system = tsm.batched_vdot(dr_atom, dr_atom, atom_group_idx) if isinstance(state, CellFireState): - state.cell_velocities += state.cell_forces * state.dt.view(n_systems, 1, 1) - dr_cell = state.cell_velocities * 
state.dt.view(n_systems, 1, 1) - - dr_scaling_system += dr_cell.pow(2).sum(dim=(1, 2)) - dr_scaling_cell = torch.sqrt(dr_scaling_system).view(n_systems, 1, 1) + cell_wise_dt = state.dt[state.group_idx].view(n_systems, 1, 1) + state.cell_velocities += state.cell_forces * cell_wise_dt + dr_cell = state.cell_velocities * cell_wise_dt + + dr_scaling_system += _group_sum( + dr_cell.pow(2).sum(dim=(1, 2)), + state.group_idx, + n_groups, + ) + dr_scaling_cell = torch.sqrt( + dr_scaling_system[state.group_idx].view(n_systems, 1, 1) + ) dr_cell = torch.where( dr_scaling_cell > max_step, max_step * dr_cell / (dr_scaling_cell + eps), dr_cell, ) - dr_scaling_atom = torch.sqrt(dr_scaling_system)[state.system_idx].unsqueeze(-1) + dr_scaling_atom = torch.sqrt(dr_scaling_system)[atom_group_idx].unsqueeze(-1) dr_atom = torch.where( dr_scaling_atom > max_step, max_step * dr_atom / (dr_scaling_atom + eps), diff --git a/torch_sim/optimizers/state.py b/torch_sim/optimizers/state.py index fb09795d0..2b0bc480b 100644 --- a/torch_sim/optimizers/state.py +++ b/torch_sim/optimizers/state.py @@ -48,7 +48,7 @@ class FireState(OptimState): n_pos: torch.Tensor _atom_attributes = OptimState._atom_attributes | {"velocities"} # noqa: SLF001 - _system_attributes = OptimState._system_attributes | {"dt", "alpha", "n_pos"} # noqa: SLF001 + _group_attributes = OptimState._group_attributes | {"dt", "alpha", "n_pos"} # noqa: SLF001 @dataclass(kw_only=True) diff --git a/torch_sim/state.py b/torch_sim/state.py index 2c3e28b89..8b8adabe6 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -146,6 +146,8 @@ class SimState: atomic_numbers (torch.Tensor): Atomic numbers with shape (n_atoms,) system_idx (torch.Tensor): Maps each atom index to its system index. Has shape (n_atoms,), must be unique consecutive integers starting from 0. + group_idx (torch.Tensor): Maps each system index to its optimizer group index. + Has shape (n_systems,). Defaults to one group per system. 
constraints (list["Constraint"] | None): List of constraints applied to the system. Constraints affect degrees of freedom and modify positions. @@ -180,6 +182,7 @@ class SimState: pbc: torch.Tensor # coerced from bool/list[bool] by __setattr__ atomic_numbers: torch.Tensor system_idx: torch.Tensor = field(default=None) # type: ignore[assignment] # coerced from None by __setattr__ + group_idx: torch.Tensor = field(default=None) # type: ignore[assignment] # coerced from None by __setattr__ _constraints: list["Constraint"] = field(default_factory=list) _system_extras: dict[str, torch.Tensor] = field(default_factory=dict) _atom_extras: dict[str, torch.Tensor] = field(default_factory=dict) @@ -196,6 +199,7 @@ def __init__( # noqa: D107 pbc: torch.Tensor | list[bool] | bool, atomic_numbers: torch.Tensor, system_idx: torch.Tensor | None = None, + group_idx: torch.Tensor | None = None, _constraints: list[Constraint] | None = None, _rng: PRNGLike = None, **kwargs: Any, @@ -207,7 +211,8 @@ def __init__( # noqa: D107 "atomic_numbers", "system_idx", } - _system_attributes: ClassVar[set[str]] = {"cell"} + _system_attributes: ClassVar[set[str]] = {"cell", "group_idx"} + _group_attributes: ClassVar[set[str]] = set() _global_attributes: ClassVar[set[str]] = {"pbc", "_rng"} @property @@ -251,6 +256,13 @@ def __setattr__(self, name: str, value: object) -> None: # noqa: C901 _, counts = torch.unique_consecutive(value, return_counts=True) if not torch.all(counts == torch.bincount(value)): raise ValueError("System indices must be unique consecutive integers") + elif name == "group_idx": + if isinstance(value, torch.Tensor): + value = value.to(dtype=torch.int64) + if value.ndim != 1: + raise ValueError("Group indices must be a 1D tensor") + if value.numel() > 0 and value.min() < 0: + raise ValueError("Group indices must be non-negative") super().__setattr__(name, value) def __post_init__(self) -> None: # noqa: C901 @@ -270,6 +282,10 @@ def __post_init__(self) -> None: # noqa: C901 # Get 
n_systems from system_idx (now guaranteed to be non-None) _, counts = torch.unique_consecutive(self.system_idx, return_counts=True) n_systems = len(counts) + if self.group_idx is None: + self.group_idx = torch.arange( + n_systems, device=self.device, dtype=torch.int64 + ) if self.constraints: validate_constraints(self.constraints, state=self) @@ -285,6 +301,17 @@ def __post_init__(self) -> None: # noqa: C901 f"Cell must have shape (n_systems={n_systems}, 3, 3), " f"got {self.cell.shape}" ) + if self.group_idx.shape[0] != n_systems: + raise ValueError( + f"Group indices must have shape (n_systems={n_systems},), " + f"got {self.group_idx.shape}" + ) + unique_groups = torch.unique(self.group_idx) + expected_groups = torch.arange( + unique_groups.numel(), device=self.device, dtype=self.group_idx.dtype + ) + if not torch.equal(unique_groups, expected_groups): + raise ValueError("Group indices must be consecutive integers starting from 0") # if devices aren't all the same, raise an error, in a clean way devices = { @@ -296,6 +323,7 @@ def __post_init__(self) -> None: # noqa: C901 "atomic_numbers", "pbc", "system_idx", + "group_idx", ) } if len(set(devices.values())) > 1: @@ -313,6 +341,14 @@ def __post_init__(self) -> None: # noqa: C901 f"System extra '{key}' leading dim must be " f"n_systems={n_systems}, got {val.shape[0]}" ) + for name, val in get_attrs_for_scope(self, "per-group"): + if not isinstance(val, torch.Tensor): + continue + if val.shape[0] != self.n_groups: + raise ValueError( + f"Group attribute '{name}' leading dim must be " + f"n_groups={self.n_groups}, got {val.shape[0]}" + ) for key, val in self._atom_extras.items(): if key in all_attrs or hasattr(type(self), key): raise ValueError(f"Atom extra '{key}' shadows an existing attribute") @@ -330,6 +366,7 @@ def _get_all_attributes(cls) -> set[str]: return ( cls._atom_attributes | cls._system_attributes + | cls._group_attributes | cls._global_attributes | {"_constraints", "_system_extras", "_atom_extras"} ) 
@@ -439,6 +476,11 @@ def n_systems(self) -> int: """Number of systems in the system.""" return torch.unique(self.system_idx).shape[0] + @property + def n_groups(self) -> int: + """Number of optimizer groups in the state.""" + return int(self.group_idx.max().item()) + 1 if self.group_idx.numel() else 0 + @property def volume(self) -> torch.Tensor: """Volume of the system.""" @@ -636,6 +678,7 @@ def from_state(cls, state: "SimState", **additional_attrs: Any) -> Self: all_known = cls._get_all_attributes() n_atoms = state.n_atoms n_systems = state.n_systems + n_groups = state.n_groups for key, val in additional_attrs.items(): if key in all_known: attrs[key] = val @@ -645,6 +688,8 @@ def from_state(cls, state: "SimState", **additional_attrs: Any) -> Self: attrs.setdefault("_atom_extras", {})[key] = val elif leading == n_systems: attrs.setdefault("_system_extras", {})[key] = val + elif leading == n_groups: + attrs[key] = val else: raise ValueError(f"Attribute '{key}' has invalid leading dimension") else: @@ -687,16 +732,52 @@ def to_phonopy(self) -> list["PhonopyAtoms"]: """ return ts.io.state_to_phonopy(self) - def split(self) -> list[Self]: + def split(self, *, by: Literal["system", "group"] = "system") -> list[Self]: + """Split the SimState into a list of independent SimStates. + + Divides the current state into separate states, preserving all properties + appropriately for each piece. + + Args: + by ("system" | "group"): Whether to split into one state per system + (default) or one state per optimizer group. With the default + single-system groups the two are equivalent. + + Returns: + list[SimState]: A list of SimState objects, one per system or per group. + """ + if by == "system": + return self.split_systems() + if by == "group": + return self.split_groups() + raise ValueError(f"Invalid split mode {by!r}, must be 'system' or 'group'") + + def split_systems(self) -> list[Self]: """Split the SimState into a list of single-system SimStates. 
- Divides the current state into separate states, each containing a single system, - preserving all properties appropriately for each system. + Returns: + list[SimState]: A list of SimState objects, one per system. + """ + return _split_state(self, by="system") + + def split_groups(self) -> list[Self]: + """Split the SimState into a list of single-group SimStates. + + Each returned state holds all systems belonging to one optimizer group, with + ``group_idx`` remapped to a single group. With the default single-system + groups this is equivalent to :meth:`split_systems`. + + Requires ``group_idx`` to be non-decreasing so each group is a contiguous + block of systems (the layout produced by ``concatenate_states`` and the + autobatchers). Returns: - list[SimState]: A list of SimState objects, one per system + list[SimState]: A list of SimState objects, one per group. + + Raises: + ValueError: If ``group_idx`` is not non-decreasing. """ - return _split_state(self) + return _split_state(self, by="group") def pop(self, system_indices: int | list[int] | slice | torch.Tensor) -> list[Self]: """Pop off states with the specified system indices. 
@@ -778,7 +859,7 @@ def _assert_no_tensor_attributes_can_be_none(cls) -> None: # exceptions exist because the type hint doesn't actually reflect the real type # (since we change their type in the post_init) - exceptions = {"system_idx"} + exceptions = {"system_idx", "group_idx"} type_hints = typing.get_type_hints(cls) for attr_name, attr_type_hint in type_hints.items(): @@ -803,13 +884,19 @@ def _assert_no_tensor_attributes_can_be_none(cls) -> None: @classmethod def _assert_all_attributes_have_defined_scope(cls) -> None: all_defined_attributes = ( - cls._atom_attributes | cls._system_attributes | cls._global_attributes + cls._atom_attributes + | cls._system_attributes + | cls._group_attributes + | cls._global_attributes ) # 1) assert that no attribute is defined twice in all_defined_attributes duplicates = ( (cls._atom_attributes & cls._system_attributes) + | (cls._atom_attributes & cls._group_attributes) | (cls._atom_attributes & cls._global_attributes) + | (cls._system_attributes & cls._group_attributes) | (cls._system_attributes & cls._global_attributes) + | (cls._group_attributes & cls._global_attributes) ) if duplicates: raise TypeError( @@ -1004,13 +1091,13 @@ def _state_to_device[T: SimState]( # noqa: C901 def get_attrs_for_scope( - state: SimState, scope: Literal["per-atom", "per-system", "global"] + state: SimState, scope: Literal["per-atom", "per-system", "per-group", "global"] ) -> Generator[tuple[str, Any], None, None]: """Get attributes for a given scope. 
Args: state (SimState): The state to get attributes for - scope (Literal["per-atom", "per-system", "global"]): The scope to get + scope (Literal["per-atom", "per-system", "per-group", "global"]): Scope to get attributes for Returns: @@ -1021,6 +1108,8 @@ def get_attrs_for_scope( attr_names = state._atom_attributes # noqa: SLF001 case "per-system": attr_names = state._system_attributes # noqa: SLF001 + case "per-group": + attr_names = state._group_attributes # noqa: SLF001 case "global": attr_names = state._global_attributes # noqa: SLF001 case _: @@ -1034,7 +1123,23 @@ def get_attrs_for_scope( yield from state.atom_extras.items() -def _filter_attrs_by_index( +def _remap_indices_by_first_occurrence( + indices: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Remap integer indices to consecutive values in first-occurrence order.""" + unique_indices: list[int] = [] + remapped: list[int] = [] + for idx in indices.tolist(): + if idx not in unique_indices: + unique_indices.append(idx) + remapped.append(unique_indices.index(idx)) + return ( + torch.tensor(remapped, device=indices.device, dtype=torch.int64), + torch.tensor(unique_indices, device=indices.device, dtype=torch.int64), + ) + + +def _filter_attrs_by_index( # noqa: C901 state: SimState, atom_indices: torch.Tensor, system_indices: torch.Tensor, @@ -1060,12 +1165,17 @@ def _filter_attrs_by_index( atom_remap[atom_indices] = torch.arange(len(atom_indices), device=state.device) if len(system_indices) == 0: system_remap = torch.empty(0, device=state.device, dtype=torch.long) + old_group_indices = torch.empty(0, device=state.device, dtype=torch.long) + new_group_idx = torch.empty(0, device=state.device, dtype=torch.long) else: max_idx = int(system_indices.max().item()) + 1 system_remap = torch.empty(max_idx, device=state.device, dtype=torch.long) system_remap[system_indices] = torch.arange( len(system_indices), device=state.device ) + new_group_idx, old_group_indices = _remap_indices_by_first_occurrence( + 
state.group_idx[system_indices] + ) # select_constraint uses boolean masks (which lose ordering), so we must # remap constraint atom_idx / system_idx afterward to match the actual @@ -1097,8 +1207,16 @@ def _filter_attrs_by_index( for name, val in get_attrs_for_scope(state, "per-system"): if name in state.system_extras: continue + if name == "group_idx": + filtered_attrs[name] = new_group_idx + else: + filtered_attrs[name] = ( + val[system_indices] if isinstance(val, torch.Tensor) else val + ) + + for name, val in get_attrs_for_scope(state, "per-group"): filtered_attrs[name] = ( - val[system_indices] if isinstance(val, torch.Tensor) else val + val[old_group_indices] if isinstance(val, torch.Tensor) else val ) filtered_attrs["_system_extras"] = { @@ -1111,95 +1229,200 @@ def _filter_attrs_by_index( return filtered_attrs -def _split_state[T: SimState](state: T) -> list[T]: # noqa: C901 - """Split a SimState into a list of states, each containing a single system. +def _split_state[T: SimState]( # noqa: C901 + state: T, *, by: Literal["system", "group"] = "system" +) -> list[T]: + """Split a SimState into a list of states, one per system or per optimizer group. - Divides a multi-system state into individual single-system states, preserving - appropriate properties for each system. + Both modes share one vectorized ``torch.split`` pass over the state tensors. + ``by="system"`` yields one single-system state per system. ``by="group"`` yields + one state per group and requires ``group_idx`` to be non-decreasing so each group + is a contiguous block of systems (the layout produced by ``concatenate_states`` + and the autobatchers). Args: - state (SimState): The SimState to split + state (SimState): The SimState to split. + by (Literal["system", "group"]): The granularity to split at. Returns: - list[SimState]: A list of SimState objects, each containing a single - system + list[SimState]: One SimState per system or per group. 
+ + Raises: + ValueError: For an unknown ``by`` value, or if ``by="group"`` and + ``group_idx`` is not non-decreasing. """ - system_sizes = state.n_atoms_per_system.tolist() + if by == "system": + system_to_output = torch.arange(state.n_systems, device=state.device) + n_outputs = state.n_systems + elif by == "group": + if state.n_systems > 1 and not bool( + (state.group_idx[1:] >= state.group_idx[:-1]).all() + ): + raise ValueError( + "split by group requires non-decreasing group_idx so each group is a " + "contiguous block of systems" + ) + system_to_output = state.group_idx + n_outputs = state.n_groups + else: + raise ValueError(f"Invalid split mode {by!r}, must be 'system' or 'group'") + + # Number of systems and atoms in each output, used to drive torch.split. + output_system_sizes = torch.bincount(system_to_output, minlength=n_outputs).tolist() + output_atom_sizes = ( + torch.bincount( + system_to_output, + weights=state.n_atoms_per_system.to(torch.float64), + minlength=n_outputs, + ) + .round() + .to(torch.int64) + .tolist() + ) split_per_atom = {} for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"): - if attr_name == "system_idx" or attr_name in state.atom_extras: + if attr_name in state.atom_extras: continue - split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0) + split_per_atom[attr_name] = torch.split(attr_value, output_atom_sizes, dim=0) split_per_system = {} for attr_name, attr_value in get_attrs_for_scope(state, "per-system"): - if attr_name in state.system_extras: + if attr_name == "group_idx" or attr_name in state.system_extras: continue if isinstance(attr_value, torch.Tensor): - split_per_system[attr_name] = torch.split(attr_value, 1, dim=0) + split_per_system[attr_name] = torch.split( + attr_value, output_system_sizes, dim=0 + ) else: # Non-tensor attributes are replicated for each split - split_per_system[attr_name] = [attr_value] * state.n_systems + split_per_system[attr_name] = [attr_value] * n_outputs + + 
split_per_group = {} + per_group_attrs = list(get_attrs_for_scope(state, "per-group")) + # When splitting by system, several systems can share a group, so each output maps + # to its group's row; when splitting by group, each output is exactly one group. + group_ids = state.group_idx.tolist() if (per_group_attrs and by == "system") else [] + for attr_name, attr_value in per_group_attrs: + if not isinstance(attr_value, torch.Tensor): + split_per_group[attr_name] = [attr_value] * n_outputs + elif by == "group": + split_per_group[attr_name] = list(torch.split(attr_value, 1, dim=0)) + else: + split_per_group[attr_name] = [attr_value[g : g + 1] for g in group_ids] global_attrs = dict(get_attrs_for_scope(state, "global")) - split_system_extras: dict[str, list[torch.Tensor]] = {} - for key, val in state.system_extras.items(): - split_system_extras[key] = list(torch.split(val, 1, dim=0)) + split_system_extras = { + key: torch.split(val, output_system_sizes, dim=0) + for key, val in state.system_extras.items() + } + split_atom_extras = { + key: torch.split(val, output_atom_sizes, dim=0) + for key, val in state.atom_extras.items() + } - split_atom_extras: dict[str, list[torch.Tensor]] = {} - for key, val in state.atom_extras.items(): - split_atom_extras[key] = list(torch.split(val, system_sizes, dim=0)) + atom_offsets = [0] + for size in output_atom_sizes: + atom_offsets.append(atom_offsets[-1] + size) + system_offsets = [0] + for size in output_system_sizes: + system_offsets.append(system_offsets[-1] + size) - # Create a state for each system states: list[T] = [] - n_systems = len(system_sizes) - zero_tensor = torch.tensor([0], device=state.device, dtype=torch.int64) - cumsum_atoms = torch.cat((zero_tensor, torch.cumsum(state.n_atoms_per_system, dim=0))) - for sys_idx in range(n_systems): - # Build per-system attributes (padded attributes stay padded for consistency) - per_system_dict = { - attr_name: split_per_system[attr_name][sys_idx] - for attr_name in split_per_system 
- } + for out_idx in range(n_outputs): + # Rebase per-atom system_idx to 0-based within the output. Atoms are ordered by + # system, so the first entry is the output's lowest original system index. + sys_chunk = split_per_atom["system_idx"][out_idx] + local_system_idx = sys_chunk - sys_chunk[0] if sys_chunk.numel() else sys_chunk system_attrs = { - # Create a system tensor with all zeros for this system - "system_idx": torch.zeros( - system_sizes[sys_idx], device=state.device, dtype=torch.int64 + "system_idx": local_system_idx, + "group_idx": torch.zeros( + output_system_sizes[out_idx], device=state.device, dtype=torch.int64 ), - # Add the split per-atom attributes **{ - attr_name: split_per_atom[attr_name][sys_idx] + attr_name: split_per_atom[attr_name][out_idx] for attr_name in split_per_atom + if attr_name != "system_idx" + }, + **{ + attr_name: split_per_system[attr_name][out_idx] + for attr_name in split_per_system + }, + **{ + attr_name: split_per_group[attr_name][out_idx] + for attr_name in split_per_group }, - # Add the split per-system attributes (with unpadding applied) - **per_system_dict, - # Add the global attributes **global_attrs, "_system_extras": { - key: split_system_extras[key][sys_idx] for key in split_system_extras + key: split_system_extras[key][out_idx] for key in split_system_extras }, "_atom_extras": { - key: split_atom_extras[key][sys_idx] for key in split_atom_extras + key: split_atom_extras[key][out_idx] for key in split_atom_extras }, } - - start_idx = int(cumsum_atoms[sys_idx].item()) - end_idx = int(cumsum_atoms[sys_idx + 1].item()) - atom_idx = torch.arange(start_idx, end_idx, device=state.device) - new_constraints: list[Constraint] = [] - for constraint in state.constraints: - sub = constraint.select_sub_constraint(atom_idx, sys_idx) - if sub is not None: - new_constraints.append(sub) - - system_attrs["_constraints"] = new_constraints + system_attrs["_constraints"] = _split_constraints( + state, + out_idx, + by=by, + 
def _split_constraints(
    state: SimState,
    out_idx: int,
    *,
    by: Literal["system", "group"],
    atom_offsets: list[int],
    system_offsets: list[int],
) -> list[Constraint]:
    """Select the constraints that apply to one output of :func:`_split_state`.

    Args:
        state (SimState): The state being split; its ``constraints`` are the source.
        out_idx (int): Index of the output state to build constraints for.
        by ("system" | "group"): Split granularity. Any value other than
            ``"system"`` is handled as a group split.
        atom_offsets (list[int]): Cumulative atom counts per output, so output
            ``k`` owns atoms ``atom_offsets[k]:atom_offsets[k + 1]``.
        system_offsets (list[int]): Cumulative system counts per output, so output
            ``k`` owns systems ``system_offsets[k]:system_offsets[k + 1]``.

    Returns:
        list[Constraint]: The constraints for the requested output; empty when the
            state carries no constraints.
    """
    if not state.constraints:
        return []

    if by == "system":
        # One output is exactly one system here, so out_idx doubles as the system
        # index. The atom range is only needed on this path.
        atom_idx = torch.arange(
            atom_offsets[out_idx], atom_offsets[out_idx + 1], device=state.device
        )
        return [
            sub
            for constraint in state.constraints
            if (sub := constraint.select_sub_constraint(atom_idx, out_idx)) is not None
        ]

    # A group spans several systems: select each system's sub-constraints (local to
    # that system) and merge them into the group's frame, mirroring
    # concatenate_states.
    first_sys = system_offsets[out_idx]
    stop_sys = system_offsets[out_idx + 1]
    # Read all atom counts in one transfer instead of one .item() call per system.
    atom_counts: list[int] = state.n_atoms_per_system[first_sys:stop_sys].tolist()

    per_system_constraints: list[list[Constraint]] = []
    atom_cursor = atom_offsets[out_idx]
    for sys_idx, n_atoms in enumerate(atom_counts, start=first_sys):
        sys_atom_idx = torch.arange(
            atom_cursor, atom_cursor + n_atoms, device=state.device
        )
        per_system_constraints.append(
            [
                sub
                for constraint in state.constraints
                if (sub := constraint.select_sub_constraint(sys_atom_idx, sys_idx))
                is not None
            ]
        )
        atom_cursor += n_atoms

    return merge_constraints(
        per_system_constraints,
        # Explicit dtype keeps the counts integral even when the group is empty
        # (torch.tensor([]) would otherwise default to a floating dtype).
        torch.tensor(atom_counts, device=state.device, dtype=torch.int64),
        torch.ones(len(atom_counts), device=state.device, dtype=torch.int64),
    )
list[torch.Tensor]] = defaultdict(list) atom_extras_tensors: dict[str, list[torch.Tensor]] = defaultdict(list) new_system_indices = [] + new_group_indices = [] system_offset = 0 + group_offset = 0 num_atoms_per_state = [] # Process all states in a single pass @@ -1350,10 +1576,13 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 # Collect per-system properties for prop, val in get_attrs_for_scope(state, "per-system"): - if prop in state.system_extras: + if prop == "group_idx" or prop in state.system_extras: continue per_system_tensors[prop].append(val) + for prop, val in get_attrs_for_scope(state, "per-group"): + per_group_tensors[prop].append(val) + # Collect extras for key, val in state.system_extras.items(): system_extras_tensors[key].append(val) @@ -1364,9 +1593,11 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 num_systems = state.n_systems new_indices = state.system_idx + system_offset new_system_indices.append(new_indices) + new_group_indices.append(state.group_idx + group_offset) num_atoms_per_state.append(state.n_atoms) system_offset += num_systems + group_offset += state.n_groups # Concatenate collected tensors for prop, tensors in per_atom_tensors.items(): @@ -1427,8 +1658,16 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 else: # Non-tensor attributes, take first one (they should all be identical) concatenated[prop] = tensors[0] + for prop, tensors in per_group_tensors.items(): + concatenated[prop] = ( + torch.cat(tensors, dim=0) + if isinstance(tensors[0], torch.Tensor) + else tensors[0] + ) + # Concatenate system indices concatenated["system_idx"] = torch.cat(new_system_indices) + concatenated["group_idx"] = torch.cat(new_group_indices) # Concatenate extras concatenated["_system_extras"] = { diff --git a/torch_sim/workflows/neb.py b/torch_sim/workflows/neb.py new file mode 100644 index 000000000..f762a238f --- /dev/null +++ b/torch_sim/workflows/neb.py @@ -0,0 +1,442 @@ +"""Nudged Elastic Band (NEB) workflow.""" 
+ +import inspect +import logging +from collections.abc import Callable +from dataclasses import dataclass, field +from functools import partial +from typing import Any, Literal + +import torch + +from torch_sim.models.interface import ModelInterface +from torch_sim.optimizers import ( + OptimState, + fire_init, + fire_step, + gradient_descent_init, + gradient_descent_step, +) +from torch_sim.runners import optimize +from torch_sim.state import SimState, concatenate_states, initialize_state +from torch_sim.transforms import minimum_image_displacement +from torch_sim.typing import StateLike + + +logger = logging.getLogger(__name__) + +_EPS = torch.finfo(torch.float64).eps + +OptimizerType = Literal["fire", "gd", "ase_fire"] + + +def _extract_kwargs_from_params( + params: dict[str, Any], func: Callable[..., Any], exclude: set[str] | None = None +) -> dict[str, Any]: + """Return the entries in ``params`` accepted by ``func``.""" + exclude = exclude or {"state", "model"} + sig = inspect.signature(func) + return {k: v for k, v in params.items() if k in sig.parameters and k not in exclude} + + +@dataclass(frozen=True) +class _OptimizerConfig: + """Functional optimizer pair and argument modifiers.""" + + init_fn: Callable[..., OptimState] + step_fn: Callable[..., OptimState] + init_kwargs_modifier: Callable[[dict[str, Any]], dict[str, Any]] | None = None + step_kwargs_modifier: Callable[[dict[str, Any]], dict[str, Any]] | None = None + + +_OPTIMIZER_REGISTRY: dict[OptimizerType, _OptimizerConfig] = { + "fire": _OptimizerConfig(init_fn=fire_init, step_fn=fire_step), + "gd": _OptimizerConfig( + init_fn=gradient_descent_init, + step_fn=gradient_descent_step, + step_kwargs_modifier=lambda kwargs: ( + kwargs if "pos_lr" in kwargs else {**kwargs, "pos_lr": kwargs.get("lr", 0.01)} + ), + ), + "ase_fire": _OptimizerConfig( + init_fn=fire_init, + step_fn=fire_step, + init_kwargs_modifier=lambda kwargs: ( + kwargs if "fire_flavor" in kwargs else {**kwargs, "fire_flavor": 
def validate_endpoints(initial_state: SimState, final_state: SimState) -> None:
    """Check that two endpoint states can anchor a fixed-cell, single-chain NEB path.

    The checks run in this order and stop at the first failure: each state holds
    exactly one system; both have the same number of atoms; identical
    ``atomic_numbers``; identical ``pbc``; and cells that agree within
    ``torch.allclose`` tolerance.

    Args:
        initial_state (SimState): First endpoint of the path.
        final_state (SimState): Last endpoint of the path.

    Raises:
        ValueError: If any of the checks above fails.
    """
    both_single_system = initial_state.n_systems == 1 and final_state.n_systems == 1
    if not both_single_system:
        raise ValueError("Initial and final states must each contain one system.")

    same_atom_count = initial_state.n_atoms == final_state.n_atoms
    if not same_atom_count:
        raise ValueError(
            f"Initial ({initial_state.n_atoms}) and final ({final_state.n_atoms}) "
            "states must have the same number of atoms."
        )

    same_species = torch.equal(initial_state.atomic_numbers, final_state.atomic_numbers)
    if not same_species:
        raise ValueError("Initial and final states must have the same atom types.")

    same_pbc = torch.equal(initial_state.pbc, final_state.pbc)
    if not same_pbc:
        raise ValueError("Initial and final states must have the same PBC setting.")

    same_cell = torch.allclose(initial_state.cell, final_state.cell)
    if not same_cell:
        raise ValueError("Fixed-cell NEB requires matching endpoint cells.")
def compute_tangents(
    all_positions: torch.Tensor,
    all_energies: torch.Tensor,
    cell: torch.Tensor,
    *,
    pbc: torch.Tensor,
) -> torch.Tensor:
    """Compute unit tangents at every intermediate image of a NEB path.

    For image ``i`` the tangent is built from the minimum-image displacement to the
    next image (``dR_plus``) and from the previous image (``dR_minus``), selected by
    the energy differences to those neighbours:

    * energy rises on both sides of the image: ``dR_plus``
    * energy falls on both sides: ``dR_minus``
    * otherwise: a mix of both displacements weighted by the larger and smaller
      absolute energy difference, with the larger weight on ``dR_plus`` when the
      next image lies higher in energy than the previous one and on ``dR_minus``
      when it does not.

    Args:
        all_positions (torch.Tensor): Positions of every image including both
            endpoints, shape ``(n_total_images, n_atoms, 3)``.
        all_energies (torch.Tensor): One energy per image, in path order.
        cell (torch.Tensor): Cell passed to ``minimum_image_displacement``.
        pbc (torch.Tensor): Periodic boundary flags passed to
            ``minimum_image_displacement``.

    Returns:
        torch.Tensor: Tangents of shape ``(n_total_images - 2, n_atoms, 3)``, each
            scaled to unit norm over all atoms and components. A tangent whose norm
            does not exceed ``_EPS`` is left as zeros.
    """
    n_total_images, n_atoms, _ = all_positions.shape
    n_movable = n_total_images - 2
    tangents = torch.zeros(
        (n_movable, n_atoms, 3),
        device=all_positions.device,
        dtype=all_positions.dtype,
    )
    # Segment k joins image k to image k + 1, so movable image i (path index i + 1)
    # has segment i behind it and segment i + 1 ahead of it.
    segments = minimum_image_displacement(
        dr=all_positions[1:] - all_positions[:-1],
        cell=cell,
        pbc=pbc,
    ).reshape(n_total_images - 1, n_atoms, 3)
    energy_steps = all_energies[1:] - all_energies[:-1]

    for i in range(n_movable):
        dR_minus, dR_plus = segments[i], segments[i + 1]
        dE_minus, dE_plus = energy_steps[i], energy_steps[i + 1]

        if dE_plus > 0 and dE_minus > 0:
            tangent = dR_plus
        elif dE_plus < 0 and dE_minus < 0:
            tangent = dR_minus
        else:
            magnitude_plus = torch.abs(dE_plus)
            magnitude_minus = torch.abs(dE_minus)
            larger = torch.maximum(magnitude_plus, magnitude_minus)
            smaller = torch.minimum(magnitude_plus, magnitude_minus)
            # dE_plus + dE_minus is the energy of the next image minus the previous.
            if (dE_plus + dE_minus) > 0:
                weight_plus, weight_minus = larger, smaller
            else:
                weight_plus, weight_minus = smaller, larger
            tangent = dR_plus * weight_plus + dR_minus * weight_minus

        length = torch.linalg.norm(tangent)
        if length > _EPS:
            tangents[i] = tangent / length

    return tangents
def neb_init(
    state: SimState,
    model: ModelInterface,
    *,
    initial_state: SimState,
    final_state: SimState,
    initial_energy: torch.Tensor,
    final_energy: torch.Tensor,
    base_init_fn: Callable[..., OptimState],
    base_init_kwargs: dict[str, Any] | None = None,
    spring_constant: float = 0.1,
    use_climbing_image: bool = False,
) -> OptimState:
    """Initialize the base optimizer state and replace true forces with NEB forces.

    Args:
        state (SimState): The movable images of the path.
        model (ModelInterface): Model forwarded to ``base_init_fn``.
        initial_state (SimState): Fixed first endpoint of the path.
        final_state (SimState): Fixed last endpoint of the path.
        initial_energy (torch.Tensor): Energy of ``initial_state``.
        final_energy (torch.Tensor): Energy of ``final_state``.
        base_init_fn (Callable[..., OptimState]): Init function of the wrapped
            optimizer, called as ``base_init_fn(state, model, **base_init_kwargs)``.
        base_init_kwargs (dict[str, Any] | None): Extra keyword arguments for
            ``base_init_fn``; ``None`` means none.
        spring_constant (float): Spring constant forwarded to
            ``calculate_neb_forces``.
        use_climbing_image (bool): Forwarded to ``calculate_neb_forces``.

    Returns:
        OptimState: The base optimizer state with ``forces`` and ``neb_forces`` set
            to the NEB forces, ``neb_max_force`` set to their largest per-atom norm,
            and ``true_forces`` holding the forces produced by ``base_init_fn``.
    """
    opt_state = base_init_fn(state, model, **(base_init_kwargs or {}))
    # Snapshot the model forces before ``forces`` is rebound to the NEB forces.
    # ``neb_step`` records ``true_forces`` the same way; storing it here too means
    # the attribute exists even when the path converges before any step runs.
    true_forces = opt_state.forces.clone()
    full_path = assemble_path(initial_state, opt_state, final_state)
    neb_forces = calculate_neb_forces(
        full_path,
        true_forces,
        opt_state.energy,
        initial_energy,
        final_energy,
        spring_constant=spring_constant,
        use_climbing_image=use_climbing_image,
    )
    opt_state.true_forces = true_forces
    _store_neb_force_metadata(opt_state, neb_forces)
    return opt_state
def neb_convergence_fn(
    state: OptimState, last_energy: torch.Tensor, *, fmax: float
) -> torch.Tensor:
    """Report whether the whole NEB path has converged, as one flag per system.

    The path counts as converged when the largest per-atom force norm in
    ``state.forces`` is strictly below ``fmax``. That single verdict is broadcast
    to every system, so all images converge together or not at all.

    Args:
        state (OptimState): Optimizer state; ``forces`` and ``n_systems`` are read.
        last_energy (torch.Tensor): Ignored.
        fmax (float): Force threshold.

    Returns:
        torch.Tensor: Boolean tensor of shape ``(state.n_systems,)`` whose entries
            all carry the same value.
    """
    del last_energy  # accepted for the callback signature, not used
    per_atom_force_norm = torch.linalg.norm(state.forces, dim=-1)
    path_converged = per_atom_force_norm.max() < fmax
    return path_converged.expand(state.n_systems)
def run(
    self,
    initial_system: StateLike,
    final_system: StateLike,
    max_steps: int = 100,
    fmax: float = 0.05,
) -> SimState:
    """Run a single-chain NEB optimization through ``ts.optimize``.

    Both endpoints are converted with ``initialize_state`` on ``self.device`` and
    ``self.dtype``, validated, and evaluated once for their energies. The
    ``self.n_images`` movable images are then interpolated between them and
    optimized with ``neb_init`` / ``neb_step`` wrapping the configured optimizer.

    Args:
        initial_system (StateLike): First endpoint of the path.
        final_system (StateLike): Last endpoint of the path.
        max_steps (int): Maximum number of optimizer steps.
        fmax (float): Convergence threshold on the largest per-atom NEB force norm.

    Returns:
        SimState: The full path (initial endpoint, movable images, final endpoint)
            as assembled by ``assemble_path``.
    """
    logger.info("Starting NEB optimization")
    initial_state = initialize_state(initial_system, self.device, self.dtype)
    final_state = initialize_state(final_system, self.device, self.dtype)
    validate_endpoints(initial_state, final_state)
    initial_energy, final_energy = _endpoint_energies(
        initial_state, final_state, self.model
    )
    movable_images = interpolate_path(initial_state, final_state, self.n_images)
    logger.info(
        "Running NEB for max %d steps or fmax < %.4f eV/Ang.",
        max_steps,
        fmax,
    )

    # Shared by neb_init (through init_kwargs) and neb_step (as plain kwargs).
    endpoint_kwargs: dict[str, Any] = {
        "initial_state": as_sim_state(initial_state),
        "final_state": as_sim_state(final_state),
        "initial_energy": initial_energy,
        "final_energy": final_energy,
        "spring_constant": self.spring_constant,
        "use_climbing_image": self.use_climbing_image,
    }
    trajectory_reporter = (
        {"filenames": self.trajectory_filename}
        if self.trajectory_filename is not None
        else None
    )
    opt_state = optimize(
        movable_images,
        self.model,
        optimizer=(neb_init, neb_step),
        convergence_fn=partial(neb_convergence_fn, fmax=fmax),
        max_steps=max_steps,
        steps_between_swaps=1,
        trajectory_reporter=trajectory_reporter,
        autobatcher=False,
        init_kwargs={
            **endpoint_kwargs,
            "base_init_fn": self._init_fn,
            "base_init_kwargs": self._init_kwargs,
        },
        **endpoint_kwargs,
        base_step_fn=self._step_fn,
        base_step_kwargs=self._step_kwargs,
    )
    final_neb_max_force = torch.linalg.norm(opt_state.forces, dim=-1).max()
    # Use the same ``< fmax`` predicate as neb_convergence_fn. Testing ``>= fmax``
    # for the failure branch would report NaN forces as converged, because every
    # comparison with NaN is False.
    if final_neb_max_force < fmax:
        logger.info("NEB optimization converged.")
    else:
        logger.warning("NEB optimization did not converge within max_steps.")
    return assemble_path(initial_state, opt_state, final_state)