From 5e5140ae5e7234e238dea363d20bd979408c6064 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 28 Feb 2026 17:39:49 -0500 Subject: [PATCH 1/7] wip --- .gitignore | 3 + examples/readme.md | 5 +- examples/scripts/9_neb.py | 934 +++++++++++++++++++++++++++++++++++++ torch_sim/workflows/neb.py | 887 +++++++++++++++++++++++++++++++++++ 4 files changed, 1826 insertions(+), 3 deletions(-) create mode 100644 examples/scripts/9_neb.py create mode 100644 torch_sim/workflows/neb.py diff --git a/.gitignore b/.gitignore index 01cde4042..6b62d10b6 100644 --- a/.gitignore +++ b/.gitignore @@ -43,3 +43,6 @@ uv.lock # duecredit .duecredit.p + +# ignore local users potential agent files/plans +.agents/ diff --git a/examples/readme.md b/examples/readme.md index 7e78979eb..06f778b33 100644 --- a/examples/readme.md +++ b/examples/readme.md @@ -35,7 +35,6 @@ If you'd like to execute the scripts or examples locally, you can run them with: curl -LsSf https://astral.sh/uv/install.sh | sh # pick any of the examples -uv run --with . examples/2_Structural_optimization/2.3_MACE_FIRE.py -uv run --with . examples/3_Dynamics/3.3_MACE_NVE_cueq.py -uv run --with . examples/4_High_level_api/4.1_high_level_api.py +uv run --with . examples/scripts/1_introduction.py +uv run --with . examples/scripts/2_structural_optimization.py ``` diff --git a/examples/scripts/9_neb.py b/examples/scripts/9_neb.py new file mode 100644 index 000000000..aa26eff72 --- /dev/null +++ b/examples/scripts/9_neb.py @@ -0,0 +1,934 @@ +"""Nudged Elastic Band (NEB) workflow. + +This script demonstrates the Nudged Elastic Band method for finding minimum energy +paths between two given atomic configurations. 
+""" +# %% +# /// script +# dependencies = [ +# "mace-torch>=0.3.12", +# "ase", +# ] +# /// + +import json # Import json for output + +# Configure logging to DEBUG level first +import logging +import pickle # Import pickle + +import ase.geometry # Import the geometry module +import h5py +import matplotlib.pyplot as plt +import numpy as np +import torch +from ase.build import bulk +from ase.io import read +from ase.mep import NEB as ASENEB +from ase.mep.neb import ImprovedTangentMethod, NEBState +from ase.optimize import FIRE +from mace.calculators.foundations_models import mace_mp +from mace.calculators.mace import MACECalculator +from monty.json import MontyDecoder, MontyEncoder # Import Monty + +import torch_sim as ts +from torch_sim.models.mace import MaceModel, MaceUrls +from torch_sim.state import SimState +from torch_sim.workflows.neb import NEB as TorchNEB + + +# Redirect logging to a file instead of stdout +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s - %(name)s - %(message)s", + filename="neb_debug.log", # Specify the log file name + filemode="w", +) # Overwrite the log file each time +logging.getLogger("torch_sim.workflows.neb").setLevel(logging.DEBUG) + + +torch_sim_device = "cuda" if torch.cuda.is_available() else "cpu" +torch_sim_dtype = torch.float64 # Use float64 for higher precision + +# Load MACE model using mace_mp like other tutorials +print("Loading MACE model...") +mace_potential = mace_mp( + model=MaceUrls.mace_mpa_medium, + return_raw_model=True, + default_dtype=str(torch_sim_dtype).removeprefix("torch."), + device=str(torch_sim_device), +) + + +def compare_initial_paths( + ase_start_atoms, + ase_end_atoms, + torch_sim_initial_state: SimState, + torch_sim_final_state: SimState, + neb_workflow: TorchNEB, +): + """Compares initial paths and the MIC displacement vector.""" + print("Comparing initial interpolated paths and MIC vectors...") + n_images = neb_workflow.n_images + n_total_images = n_images + 2 + device = 
neb_workflow.device + dtype = neb_workflow.dtype + + # --- Endpoint Check --- + print("\nChecking consistency of starting endpoint positions:") + ase_start_pos_direct = ase_start_atoms.get_positions() + ts_start_pos_direct = torch_sim_initial_state.positions.cpu().numpy() + start_close = np.allclose( + ase_start_pos_direct, ts_start_pos_direct, rtol=1e-5, atol=1e-6 + ) + print(f" Direct Start positions close: {start_close}") + if not start_close: + max_diff_start = np.max(np.abs(ase_start_pos_direct - ts_start_pos_direct)) + print(f" Max absolute difference (Start): {max_diff_start:.6f}") + print("------------------------------------") + + # --- MIC Vector Comparison --- + print("\nComparing Minimum Image Convention (MIC) displacement vectors:") + # Use the torch-sim states as the source of truth for positions/cell + raw_dr_ts = torch_sim_final_state.positions - torch_sim_initial_state.positions + cell_ts = torch_sim_initial_state.cell[0] # Assuming single batch cell + pbc_ts = torch_sim_initial_state.pbc + + # ASE MIC calculation + try: + ase_cell_np = cell_ts.cpu().numpy() + ase_pbc_np = np.array([pbc_ts] * 3) # ASE expects 3 bools usually + ase_mic_dr_np, _ = ase.geometry.find_mic( + raw_dr_ts.cpu().numpy(), ase_cell_np, pbc=ase_pbc_np + ) + print(f" ASE MIC vector calculated (shape: {ase_mic_dr_np.shape})") + except Exception as e: + print(f" Error calculating ASE MIC: {e}") + ase_mic_dr_np = None + + # torch-sim MIC calculation + try: + ts_mic_dr = ts.transforms.minimum_image_displacement( + dr=raw_dr_ts, cell=cell_ts, pbc=pbc_ts + ) + ts_mic_dr_np = ts_mic_dr.cpu().numpy() + print(f" torch-sim MIC vector calculated (shape: {ts_mic_dr_np.shape})") + except Exception as e: + print(f" Error calculating torch-sim MIC: {e}") + ts_mic_dr_np = None + + # Compare the MIC vectors + if ase_mic_dr_np is not None and ts_mic_dr_np is not None: + if ase_mic_dr_np.shape != ts_mic_dr_np.shape: + print(" Error: Shapes of MIC vectors do not match.") + else: + mic_vectors_close 
= np.allclose( + ase_mic_dr_np, ts_mic_dr_np, rtol=1e-5, atol=1e-6 + ) + print(f" MIC displacement vectors close: {mic_vectors_close}") + if not mic_vectors_close: + max_diff_mic = np.max(np.abs(ase_mic_dr_np - ts_mic_dr_np)) + norm_diff = np.linalg.norm(ase_mic_dr_np - ts_mic_dr_np) + print(f" Max absolute difference (MIC vectors): {max_diff_mic:.6f}") + print(f" Norm of difference vector (MIC): {norm_diff:.6f}") + print(" This difference likely causes the interpolation discrepancy.") + print("------------------------------------") + + # --- Get ASE interpolated path --- + ase_images = [ase_start_atoms.copy() for _ in range(n_images + 1)] + ase_images.append(ase_end_atoms.copy()) + ase_neb_calc = ASENEB(ase_images, climb=False) + ase_neb_calc.interpolate(mic=True) + ase_positions = np.stack([img.get_positions() for img in ase_neb_calc.images]) + print(f"\n ASE interpolated path shape: {ase_positions.shape}") + + # --- Get torch-sim interpolated path --- + try: + interpolated_state = neb_workflow._interpolate_path( + torch_sim_initial_state, torch_sim_final_state + ) + ts_interp_pos = interpolated_state.positions + ts_start_pos = torch_sim_initial_state.positions + ts_end_pos = torch_sim_final_state.positions + n_atoms = ts_start_pos.shape[0] + ts_interp_pos_reshaped = ts_interp_pos.reshape(n_images, n_atoms, 3) + ts_positions = torch.cat( + [ + torch_sim_initial_state.positions.unsqueeze(0).to(device, dtype), + ts_interp_pos_reshaped.to(device, dtype), + torch_sim_final_state.positions.unsqueeze(0).to(device, dtype), + ], + dim=0, + ) + ts_positions_np = ts_positions.cpu().numpy() + print(f" torch-sim interpolated path shape (direct): {ts_positions_np.shape}") + except Exception as e: + print(f" Error during torch-sim interpolation: {e}") + import traceback + + traceback.print_exc() + return + + # --- Compare Interpolated Paths --- + print( + "\n Per-image comparison of interpolated paths (Max Abs Error | Mean Abs Error):" + ) + overall_max_diff_interp = 0.0 + if 
ase_positions.shape != ts_positions_np.shape: + print(" Error: Shapes of ASE and torch-sim interpolated paths do not match.") + return + + for i in range(n_total_images): + diff_image_i = np.abs(ase_positions[i] - ts_positions_np[i]) + max_ae_i = np.max(diff_image_i) + mae_i = np.mean(diff_image_i) + print(f" Image {i}: MaxAE = {max_ae_i:.6f} | MAE = {mae_i:.6f}") + overall_max_diff_interp = max(overall_max_diff_interp, max_ae_i) + + are_close_interp = np.allclose(ase_positions, ts_positions_np, rtol=1e-5, atol=1e-6) + + if are_close_interp: + print(" Overall: Interpolated paths are numerically close.") + else: + print(" Overall: Interpolated paths differ numerically.") + print( + f" Overall Maximum absolute difference (Interpolated): {overall_max_diff_interp:.6f}" + ) + + +def ase_neb(start_atoms, end_atoms, nimages=5): + device = "cuda" if torch.cuda.is_available() else "cpu" + images = [start_atoms.copy() for _ in range(nimages + 1)] + images.append(end_atoms.copy()) + + neb_calc = ASENEB(images, climb=True, method="improvedtangent") + neb_calc.interpolate(mic=True) + + # Attach calculator to all images using mace_mp + ase_dtype_str = "float64" if torch_sim_dtype == torch.float64 else "float32" + print(f"Attaching ASE calculator with dtype: {ase_dtype_str} to all images") + ase_calc = mace_mp( + model=MaceUrls.mace_mpa_medium, + device=device, + default_dtype=ase_dtype_str, + dispersion=False, + ) + for image in neb_calc.images: + image.calc = ase_calc + + # Set up trajectory logging for the reference ASE run (Commented out as not used for plot) + # ase_traj_filename = "ase_ref_neb.traj" + opt = FIRE(neb_calc) + # opt.attach(traj) # Attach the trajectory logger + + # Run the ASE optimization (essential) + print("Running ASE NEB optimization...") + opt.run(fmax=0.05, steps=1000) + print("Finished ASE NEB optimization.") + + return neb_calc # Only return the final NEB object + + +def relax_atoms( + atoms, + fmax=0.05, + steps=1000, + device=torch_sim_device, + 
dtype=torch_sim_dtype, +): + new_atoms = atoms.copy() + ase_dtype_str = "float64" if dtype == torch.float64 else "float32" + new_atoms.calc = mace_mp( + model=MaceUrls.mace_mpa_medium, + device=str(device), + default_dtype=ase_dtype_str, + dispersion=False, + ) + opt = FIRE(new_atoms) + opt.run(fmax=fmax, steps=steps) + return new_atoms + + +# Create the torch_sim wrapper +ts_mace_model = MaceModel( + model=mace_potential, + device=torch_sim_device, + dtype=torch_sim_dtype, + compute_forces=True, # Default, but good to be explicit + compute_stress=True, # Needed by interface if we want stress later + enable_cueq=False, +) + +# initial_trajectory = read('/home/myless/Packages/forge/scratch/data/neb_workflow_data/Cr7Ti8V104W8Zr_Cr_to_V_site102_to_69_initial.xyz', index=':') +# print(len(initial_trajectory)) + +# Create simple test structures for demonstration +# Using bulk structures instead of file paths +# Create simple test structures (can be replaced with file reads if needed) +start_atoms = bulk("Al", "fcc", a=4.05, cubic=True).repeat((2, 2, 2)) +end_atoms = bulk("Al", "fcc", a=4.05, cubic=True).repeat((2, 2, 2)) +# Add a small displacement to create a path +end_atoms.positions[0] += [0.1, 0.1, 0.1] + +relaxed_start_atoms = relax_atoms(start_atoms) +relaxed_end_atoms = relax_atoms(end_atoms) + +traj_file_name = "neb_path_torchsim_fire_5im.hdf5" + +# --- Setup ASE NEB for comparison --- +n_intermediate_images_ase = 5 +ase_images_compare = [relaxed_start_atoms.copy()] +ase_images_compare.extend( + [relaxed_start_atoms.copy() for _ in range(n_intermediate_images_ase)] +) +ase_images_compare.append(relaxed_end_atoms.copy()) + +ase_neb_compare = ASENEB( + ase_images_compare, + k=0.1, # Match torch-sim spring constant + climb=True, # Match torch-sim setting + method="improvedtangent", # Match torch-sim tangent method +) +ase_neb_compare.interpolate(mic=True) # Initial interpolation + +device = "cuda" if torch.cuda.is_available() else "cpu" +# Attach calculator to ALL 
ASE images using mace_mp +ase_dtype_str_compare = "float64" if torch_sim_dtype == torch.float64 else "float32" +print(f"Using ASE comparison calculator dtype: {ase_dtype_str_compare}") +ase_calculator = mace_mp( + model=MaceUrls.mace_mpa_medium, + device=device, + default_dtype=ase_dtype_str_compare, + dispersion=False, +) +for img in ase_neb_compare.images: + img.calc = ase_calculator +# ---------------------------------- + +initial_system = ts.io.atoms_to_state( + relaxed_start_atoms.copy(), device=torch_sim_device, dtype=torch_sim_dtype +) +final_system = ts.io.atoms_to_state( + relaxed_end_atoms.copy(), device=torch_sim_device, dtype=torch_sim_dtype +) + +neb_workflow = TorchNEB( + model=ts_mace_model, + device=torch_sim_device, + dtype=torch_sim_dtype, + spring_constant=0.1, + n_images=5, + use_climbing_image=True, # Set as desired for the actual run + optimizer_type="ase_fire", # Set as desired for the actual run + optimizer_params={}, + trajectory_filename=traj_file_name, +) + +compare_initial_paths( + relaxed_start_atoms, relaxed_end_atoms, initial_system, final_system, neb_workflow +) + + +# --- Add Function for Manual ASE Force Calculation --- +def calculate_ase_neb_force_step0( + ase_neb_calc: ASENEB, + image_index: int, + neb_workflow: TorchNEB, + output_filename="ase_step0_debug.json", +): + """Manually calculates the ASE NEB force components for a specific + intermediate image at step 0 (after initial interpolation) and saves + the results to a JSON file. + Uses the ImprovedTangent method for consistency with torch-sim default. 
+ """ + print(f"--- Calculating ASE NEB Debug Info (Step 0, Image Index {image_index}) ---") + debug_data = { + "step": 0, + "image_index_intermediate": image_index - 1, # 0-based index among intermediates + "image_index_absolute": image_index, # 0-based index in full list + "inputs": {}, + "outputs": {}, + "error": None, + } + + n_images = ase_neb_calc.nimages # Total number of images including endpoints + if not (0 < image_index < n_images - 1): + error_msg = f"Error: image_index {image_index} is not an intermediate image." + print(error_msg) + debug_data["error"] = error_msg + with open(output_filename, "w") as f: + json.dump(debug_data, f, indent=2, cls=MontyEncoder) # Use MontyEncoder + return + + # 1. Get initial energies and forces after interpolation + calculator attachment + try: + initial_energies_np = np.array( + [img.get_potential_energy() for img in ase_neb_calc.images] + ) + initial_forces_np = np.stack([img.get_forces() for img in ase_neb_calc.images]) + + # No need for .tolist() with MontyEncoder + debug_data["inputs"]["energies_all"] = initial_energies_np + debug_data["inputs"]["true_forces_image"] = initial_forces_np[image_index] + debug_data["inputs"]["positions_image_minus_1"] = ase_neb_calc.images[ + image_index - 1 + ].get_positions() + debug_data["inputs"]["positions_image"] = ase_neb_calc.images[ + image_index + ].get_positions() + debug_data["inputs"]["positions_image_plus_1"] = ase_neb_calc.images[ + image_index + 1 + ].get_positions() + debug_data["inputs"]["cell"] = ( + ase_neb_calc.images[image_index].get_cell().tolist() + ) + # No need for bool() conversion with MontyEncoder + debug_data["inputs"]["pbc"] = ase_neb_calc.images[image_index].get_pbc() + + except Exception as e: + error_msg = f"Error getting initial energies/forces from ASE images: {e}" + print(error_msg) + debug_data["error"] = error_msg + import traceback + + debug_data["traceback"] = traceback.format_exc() + with open(output_filename, "w") as f: + json.dump(debug_data, 
f, indent=2, cls=MontyEncoder) # Use MontyEncoder + return + + # 2. Setup NEB state and method objects + ase_neb_obj_for_state = ASENEB( + ase_neb_calc.images, + k=neb_workflow.spring_constant, + climb=neb_workflow.use_climbing_image, + method="improvedtangent", + ) + neb_state = NEBState(ase_neb_obj_for_state, ase_neb_calc.images, initial_energies_np) + tangent_method = ImprovedTangentMethod(ase_neb_obj_for_state) + + # 3. Calculate components for the target image_index + try: + spring1 = neb_state.spring(image_index - 1) + spring2 = neb_state.spring(image_index) + # No .tolist() needed + debug_data["outputs"]["mic_displacement_1"] = spring1.t + debug_data["outputs"]["mic_displacement_2"] = spring2.t + + # Calculate tangent + tangent_ase = tangent_method.get_tangent(neb_state, spring1, spring2, image_index) + tangent_norm_ase = np.linalg.norm(tangent_ase) + if tangent_norm_ase > 1e-15: + tangent_ase_normalized = tangent_ase / tangent_norm_ase + else: + tangent_ase_normalized = tangent_ase # Keep as zero vector + tangent_norm_final = np.linalg.norm(tangent_ase_normalized) + + # No .tolist() needed + debug_data["outputs"]["tangent_vector"] = tangent_ase_normalized + debug_data["outputs"]["tangent_norm"] = tangent_norm_final + + # Calculate perpendicular force + true_force_img = initial_forces_np[image_index] + f_true_dot_tau_ase = np.vdot(true_force_img, tangent_ase_normalized) + f_perp_ase = true_force_img - f_true_dot_tau_ase * tangent_ase_normalized + f_perp_norm = np.linalg.norm(f_perp_ase) + + # No .tolist() needed + debug_data["outputs"]["f_true_dot_tau"] = f_true_dot_tau_ase + debug_data["outputs"]["f_perp_vector"] = f_perp_ase + debug_data["outputs"]["f_perp_norm"] = f_perp_norm + + # Calculate parallel spring force + segment_lengths_all = [neb_state.spring(i).nt for i in range(n_images - 1)] + spring_mag_term = spring2.nt * spring2.k - spring1.nt * spring1.k + f_spring_par_ase = spring_mag_term * tangent_ase_normalized + f_spring_par_norm = 
np.linalg.norm(f_spring_par_ase) + + # No .tolist() needed + debug_data["outputs"]["segment_lengths"] = segment_lengths_all + debug_data["outputs"]["spring_force_magnitude_term"] = spring_mag_term + debug_data["outputs"]["f_spring_par_vector"] = f_spring_par_ase + debug_data["outputs"]["f_spring_par_norm"] = f_spring_par_norm + + # Calculate total NEB force (before potential climbing modification) + neb_force_ase = f_perp_ase + f_spring_par_ase + # Explicitly convert to numpy array before saving, remove .tolist() + debug_data["outputs"]["neb_force_before_climb_vector"] = np.array(neb_force_ase) + debug_data["outputs"]["neb_force_before_climb_norm"] = np.linalg.norm( + neb_force_ase + ) + + # --- Direct Debug Prints for Step 0 --- + print("\n --- DIRECT DEBUG PRINT (ASE STEP 0) ---") + print(f" f_perp_norm: {f_perp_norm}") + print(f" f_perp_vec[0]: {f_perp_ase[0]}") + print(f" spring1_length (R[{image_index}]-R[{image_index - 1}]): {spring1.nt}") + print(f" spring2_length (R[{image_index + 1}]-R[{image_index}]): {spring2.nt}") + print(f" Length Diff (spring2.nt - spring1.nt): {spring2.nt - spring1.nt}") + print(f" f_spring_par_norm: {f_spring_par_norm}") + print(f" f_spring_par_vec[0]: {f_spring_par_ase[0]}") + print(f" neb_force_before_climb_norm: {np.linalg.norm(neb_force_ase)}") + print(" ------------------------------------") + # -------------------------------------- + + # Handle climbing image modification + is_climbing = ase_neb_obj_for_state.climb and image_index == neb_state.imax + debug_data["outputs"]["is_climbing_image"] = is_climbing + debug_data["outputs"]["imax"] = int( + neb_state.imax + ) # Ensure imax is JSON serializable + + if is_climbing: + climbing_force_ase = ( + true_force_img - 2 * f_true_dot_tau_ase * tangent_ase_normalized + ) + climbing_force_norm = np.linalg.norm(climbing_force_ase) + # No .tolist() needed + debug_data["outputs"]["climbing_force_vector"] = climbing_force_ase + debug_data["outputs"]["climbing_force_norm"] = 
climbing_force_norm + final_force_ase = climbing_force_ase + else: + final_force_ase = neb_force_ase + + # No .tolist() needed + debug_data["outputs"]["final_neb_force_vector"] = final_force_ase + debug_data["outputs"]["final_neb_force_norm"] = np.linalg.norm(final_force_ase) + + except Exception as e: + error_msg = ( + f"Error during manual ASE force calculation for image {image_index}: {e}" + ) + print(error_msg) + debug_data["error"] = error_msg + import traceback + + debug_data["traceback"] = traceback.format_exc() + + # Write data to JSON + try: + with open(output_filename, "w") as f: + json.dump(debug_data, f, indent=2, cls=MontyEncoder) # Use MontyEncoder + print(f"--- ASE NEB Debug Info saved to {output_filename} ---") + except Exception as e: + print(f"Error writing ASE debug info to JSON: {e}") + + +# --- Add Function for Comparing JSON/Pickle Outputs to debug the tangent force calculation --- +def compare_step0_outputs( + file_ase="ase_step0_debug.json", + file_ts="torchsim_step0_debug.pkl", + rtol=1e-5, + atol=1e-6, +): + print("\n--- Comparing Step 0 Debug Outputs (ASE JSON vs TorchSim Pickle) --- ") + try: + # Load ASE data from JSON + with open(file_ase) as f: + data_ase = json.load(f, cls=MontyDecoder) + # Load TorchSim data from Pickle + with open(file_ts, "rb") as f: # Use 'rb' for pickle + data_ts = pickle.load(f) + except FileNotFoundError as e: + print(f"Error: Could not find file {e.filename}") + return + except Exception as e: + print(f"Error loading JSON/Pickle files: {e}") + return + + # Basic checks + if data_ase.get("error") or data_ts.get("error"): + print("Comparison aborted due to error during data generation.") + print(f" ASE Error: {data_ase.get('error')}") + print(f" TS Error: {data_ts.get('error')}") + return + + if data_ase.get("step") != 0 or data_ts.get("step") != 0: + print("Warning: One or both files do not contain step 0 data.") + # Continue comparison anyway + + if data_ase.get("image_index_intermediate") != data_ts.get( + 
"image_index_intermediate" + ): + print("Warning: JSON files are for different intermediate image indices.") + # Continue comparison anyway + + outputs_ase = data_ase.get("outputs", {}) + outputs_ts = data_ts.get("outputs", {}) + + all_keys = set(outputs_ase.keys()) | set(outputs_ts.keys()) + mismatches = 0 + + print( + f"Comparing fields for intermediate image index: {data_ase.get('image_index_intermediate', 'N/A')}" + ) + + for key in sorted(list(all_keys)): + val_ase = outputs_ase.get(key) + val_ts = outputs_ts.get(key) + + if key not in outputs_ts: + print(f" - Key '{key}': Present in ASE, Missing in TorchSim") + mismatches += 1 + continue + if key not in outputs_ase: + print(f" - Key '{key}': Missing in ASE, Present in TorchSim") + mismatches += 1 + continue + + # --- Handle Type Conversion for Comparison --- + val_ase_comp = val_ase + val_ts_comp = val_ts + + # Convert torch tensor from pickle to numpy/scalar for comparison + if isinstance(val_ts_comp, torch.Tensor): + if val_ts_comp.ndim == 0: # Scalar tensor + val_ts_comp = val_ts_comp.item() + else: + val_ts_comp = val_ts_comp.detach().cpu().numpy() # Use detach() + # -------------------------------------------- + + # --- Debug Print for Specific Key --- + if key == "neb_force_before_climb_vector": + print( + f" DEBUG compare [{key}]: ASE[0]={np.array(val_ase_comp)[0]}, TS[0]={np.array(val_ts_comp)[0]}" + ) + # ------------------------------------ + + # --- Special Handling for imax index --- + if key == "imax": + # ASE imax is index in full list (1 to n_images-1) + # TS imax is index in intermediates (0 to n_images-2) + # Compare ASE imax with TS imax + 1 + ase_imax = int(val_ase_comp) + ts_imax_plus_1 = int(val_ts_comp) + 1 + match = ase_imax == ts_imax_plus_1 + if not match: + difference_info = f"ASE imax={ase_imax}, TS imax(adj)={ts_imax_plus_1}" + status = "Match" if match else "DIFFER" + print(f" - Key '{key:<30}': {status} {difference_info}") + if not match: + mismatches += 1 + continue # Skip rest 
of comparison for imax + # ------------------------------------- + + # Try numerical comparison first + match = False + difference_info = "" + try: + # Ensure they are numpy arrays for consistent comparison + # ASE data might already be numpy or list, TS data was converted above + arr_ase = np.array(val_ase_comp) + arr_ts = np.array(val_ts_comp) + + if arr_ase.shape != arr_ts.shape: + match = False + difference_info = f"Shapes differ: ASE={arr_ase.shape}, TS={arr_ts.shape}" + elif np.issubdtype(arr_ase.dtype, np.number) and np.issubdtype( + arr_ts.dtype, np.number + ): + match = np.allclose(arr_ase, arr_ts, rtol=rtol, atol=atol) + if not match: + max_abs_diff = np.max(np.abs(arr_ase - arr_ts)) + difference_info = f"Max abs diff: {max_abs_diff:.6e}" + elif arr_ase.dtype == np.bool_ and arr_ts.dtype == np.bool_: + match = np.array_equal(arr_ase, arr_ts) + if not match: + difference_info = f"Boolean values differ: ASE={arr_ase}, TS={arr_ts}" + else: # Fallback for other types (e.g., strings if they were arrays) + match = np.array_equal(arr_ase, arr_ts) + if not match: + difference_info = "Non-numerical array values differ" + + except (TypeError, ValueError): + # Fallback to direct comparison for non-array types or incompatible arrays + try: + if isinstance(val_ase_comp, (float, int)) and isinstance( + val_ts_comp, (float, int) + ): + match = np.isclose(val_ase_comp, val_ts_comp, rtol=rtol, atol=atol) + if not match: + difference_info = f"Diff: {abs(val_ase_comp - val_ts_comp):.6e}" + elif type(val_ase_comp) == type(val_ts_comp): + match = val_ase_comp == val_ts_comp + if not match: + difference_info = ( + f"Values differ: ASE='{val_ase_comp}', TS='{val_ts_comp}'" + ) + else: + # Types should ideally match after conversion, but check just in case + match = False + difference_info = f"Types differ after conversion: ASE={type(val_ase_comp)}, TS={type(val_ts_comp)}" + except Exception: + match = False + + status = "Match" if match else "DIFFER" # Pad DIFFER for alignment 
+ print(f" - Key '{key:<30}': {status} {difference_info}") + if not match: + mismatches += 1 + + if mismatches == 0: + print("\nAll compared output fields match.") + else: + print(f"\nFound {mismatches} mismatch(es) in output fields.") + print("--- End Comparison --- ") + + +# ------------------------------------------------- + + +# --- Add Function to Print Pickle Structure --- +def print_pickle_structure(filename="torchsim_step0_debug.pkl"): + print(f"\n--- Structure of Pickle File: {filename} --- ") + try: + with open(filename, "rb") as f: + data = pickle.load(f) + except FileNotFoundError: + print(f"Error: File not found: {filename}") + return + except Exception as e: + print(f"Error loading pickle file: {e}") + return + + if not isinstance(data, dict): + print(f"Loaded data is not a dictionary (Type: {type(data)})") + return + + print(f"Keys: {list(data.keys())}") + for key, value in data.items(): + if isinstance(value, dict): + print(f" {key}:") + for subkey, subvalue in value.items(): + val_type = type(subvalue) + val_shape = getattr(subvalue, "shape", "N/A") + # Add dtype for tensors + val_dtype = getattr(subvalue, "dtype", "N/A") + print( + f" - {subkey:<30}: Type={val_type}, Shape={val_shape}, Dtype={val_dtype}" + ) + else: + val_type = type(value) + val_shape = getattr(value, "shape", "N/A") + val_dtype = getattr(value, "dtype", "N/A") + print(f" {key:<32}: Type={val_type}, Shape={val_shape}, Dtype={val_dtype}") + print("--- End Pickle Structure --- ") + + +# -------------------------------------------- + +# --- Perform manual ASE force calculation for step 0 --- +debug_ase_img_index = ( + n_intermediate_images_ase // 2 + 1 +) # Index in the full list (0 to n_images+1) +calculate_ase_neb_force_step0(ase_neb_compare, debug_ase_img_index, neb_workflow) +# ------------------------------------------------------ + +print("\nStarting torch-sim NEB optimization...") +final_path_gd = neb_workflow.run( + initial_system=initial_system, + final_system=final_system, 
+ max_steps=100, # Keep increased steps for now + fmax=0.05, +) +print("Finished torch-sim NEB optimization.") + +# Check if it converged and plot results +results = ts_mace_model( + dict( + positions=final_path_gd.positions, + cell=final_path_gd.cell, + atomic_numbers=final_path_gd.atomic_numbers, + system_idx=final_path_gd.system_idx, + pbc=True, + ) +) + +energies = results["energy"].tolist() + +# Including the energies from the ASE NEB calculation for comparison +# ase_energies = [0.0, 0.154541015625, 0.6151123046875, 0.8592529296875, 0.8148193359375, 0.5965576171875, 0.47705078125] + +ase_neb_calc = ase_neb(relaxed_start_atoms, relaxed_end_atoms, nimages=5) +ase_energies = [image.get_potential_energy() for image in ase_neb_calc.images] +scaled_ase_energies = [e - ase_energies[0] for e in ase_energies] + + +scaled_energies = [e - energies[0] for e in energies] + +print(scaled_energies) +torch_sim_barrier = max(scaled_energies) - scaled_energies[0] +ase_barrier = max(scaled_ase_energies) - scaled_ase_energies[0] + +# Create normalized reaction coordinates (0 to 1) for both datasets +torch_sim_coords = np.linspace(0, 1, len(scaled_energies)) +ase_coords = np.linspace(0, 1, len(scaled_ase_energies)) + +# Create a common x-axis with 100 points for smoother plotting +common_coords = np.linspace(0, 1, 100) + +# Interpolate both energy profiles to the common coordinate system +torch_sim_interp = np.interp(common_coords, torch_sim_coords, scaled_energies) +ase_interp = np.interp(common_coords, ase_coords, scaled_ase_energies) + +# --- Print Pickle Structure to Verify --- +# print_pickle_structure() +# ------------------------------------- + +# --- Compare Step 0 Debug Outputs for compute_tangent at step 0--- +# compare_step0_outputs() # Use the updated function name +# ------------------------------------ + + +# --- Plot the energy profiles --- +plt.plot(common_coords, torch_sim_interp, label="torch-sim") +plt.plot(common_coords, ase_interp, label="ASE") 
def analyze_convergence(ts_traj_file, ase_fmax_csv_file):
    """Plot the per-step maximum force of a torch-sim NEB run against ASE.

    The torch-sim curve is read from the HDF5 trajectory datasets
    ``/data/neb_forces`` and ``/data/image_indices``; for every stored step the
    largest absolute NEB force *component* over the intermediate images is
    taken. The ASE curve is read from the second column of a tab-separated
    file with one header row.

    Both readers are best-effort: a failure is printed and the corresponding
    curve is simply omitted from the plot.

    Args:
        ts_traj_file: Path to the torch-sim HDF5 trajectory.
        ase_fmax_csv_file: Path to the tab-separated ASE fmax log.
    """
    print("\n--- Analyzing Optimizer Convergence ---")
    max_force_ts = []
    max_force_ase = []

    # Analyze torch-sim trajectory
    try:
        with h5py.File(ts_traj_file, "r") as f:
            if "data/neb_forces" not in f or "data/image_indices" not in f:
                raise ValueError(
                    "HDF5 file missing '/data/neb_forces' or '/data/image_indices' datasets."
                )

            # Data is under /data group, steps are the first dimension
            neb_forces_dset = f["/data/neb_forces"]
            image_indices_dset = f["/data/image_indices"]

            n_steps = neb_forces_dset.shape[0]
            # Image indices are static, so the first stored slice is enough.
            image_indices = image_indices_dset[0, :]

            # Infer dimensions
            n_images_total = len(np.unique(image_indices))
            n_atoms_total = len(image_indices)
            if neb_forces_dset.shape[1] != n_atoms_total:
                raise ValueError(
                    f"Mismatch between image_indices length ({n_atoms_total}) and neb_forces second dimension ({neb_forces_dset.shape[1]})"
                )

            n_atoms_per_image = n_atoms_total // n_images_total
            print(
                f"TorchSim Traj: {n_steps} steps, {n_images_total} total images, {n_atoms_per_image} atoms/image."
            )

            # The mask does not depend on the step, so build it once:
            # intermediate images are indices 1 .. n_images_total - 2.
            intermediate_mask = (image_indices > 0) & (
                image_indices < n_images_total - 1
            )

            for step in range(n_steps):
                forces_intermediate = neb_forces_dset[step, :, :][intermediate_mask]
                if forces_intermediate.size > 0:
                    max_force_ts.append(float(np.abs(forces_intermediate).max()))
                else:
                    # No intermediate atoms selected; keep the curve aligned.
                    max_force_ts.append(0.0)

    except Exception as e:
        print(f"Error reading torch-sim trajectory {ts_traj_file}: {e}")

    # Read ASE fmax data from CSV
    try:
        # 2nd column (index 1), tab-delimited, one header row. atleast_1d keeps
        # a single-row file from collapsing to a 0-d array, whose .tolist()
        # would be a bare float and break len()/range() below.
        ase_data = np.atleast_1d(
            np.loadtxt(ase_fmax_csv_file, delimiter="\t", skiprows=1, usecols=(1,))
        )
        max_force_ase = ase_data.tolist()
        print(f"Read {len(max_force_ase)} fmax values from {ase_fmax_csv_file}")
    except Exception as e:
        print(f"Error reading ASE fmax CSV file {ase_fmax_csv_file}: {e}")

    # Plotting
    if max_force_ts or max_force_ase:
        plt.figure()
        if max_force_ts:
            plt.plot(
                range(len(max_force_ts)),
                max_force_ts,
                label="torch-sim (ase_fire)",
                marker=".",
            )
        if max_force_ase:
            plt.plot(
                range(len(max_force_ase)), max_force_ase, label="ASE (FIRE)", marker="."
            )
        plt.xlabel("Optimization Step")
        plt.ylabel("Max Abs Force Component (eV/Ang)")
        plt.title("Optimizer Convergence Comparison")
        plt.legend()
        plt.grid(True)
        plt.yscale("log")  # Log scale often helpful for forces
        plt.show()
    else:
        print("No force data extracted to plot convergence.")
+""" + +import inspect +import logging +import os # Import os for fsync +import pickle # Import pickle +from collections.abc import Callable +from contextlib import nullcontext +from dataclasses import dataclass, field +from typing import Any, Literal + +import torch + +from torch_sim.models.interface import ModelInterface +from torch_sim.optimizers import ( + CellFireState, + FireState, + OptimState, + fire_init, + fire_step, + gradient_descent_init, + gradient_descent_step, +) +from torch_sim.optimizers.cell_filters import CellFilter +from torch_sim.state import ( + SimState, + concatenate_states, + initialize_state, +) +from torch_sim.trajectory import TorchSimTrajectory +from torch_sim.transforms import minimum_image_displacement +from torch_sim.typing import StateLike + + +logger = logging.getLogger(__name__) + +# Add epsilon for numerical stability +_EPS = torch.finfo(torch.float64).eps + + +def _extract_kwargs_from_params( + params: dict[str, Any], func: Callable[..., Any], exclude: set[str] | None = None +) -> dict[str, Any]: + """Extract kwargs from params dict that match function signature. 
+ + Args: + params: Dictionary of parameters to filter + func: Function to extract parameters for + exclude: Set of parameter names to exclude (e.g., 'state', 'model') + + Returns: + Dictionary of parameters that match the function signature + """ + exclude = exclude or {"state", "model"} + sig = inspect.signature(func) + return { + k: v + for k, v in params.items() + if k in sig.parameters and k not in exclude + } + + +@dataclass +class _OptimizerConfig: + """Configuration for an optimizer type.""" + + init_fn: Callable[..., Any] + step_fn: Callable[..., Any] + state_type: type + init_kwargs_modifier: Callable[[dict[str, Any]], dict[str, Any]] | None = None + step_kwargs_modifier: Callable[[dict[str, Any]], dict[str, Any]] | None = None + + +# Registry of optimizer configurations +_OPTIMIZER_REGISTRY: dict[str, _OptimizerConfig] = { + "fire": _OptimizerConfig( + init_fn=fire_init, + step_fn=fire_step, + state_type=FireState, + ), + "frechet_cell_fire": _OptimizerConfig( + init_fn=fire_init, + step_fn=fire_step, + state_type=CellFireState, + init_kwargs_modifier=lambda kwargs: {**kwargs, "cell_filter": CellFilter.frechet}, + ), + "gd": _OptimizerConfig( + init_fn=gradient_descent_init, + step_fn=gradient_descent_step, + state_type=OptimState, + step_kwargs_modifier=lambda kwargs: ( + kwargs if "pos_lr" in kwargs else {**kwargs, "pos_lr": kwargs.get("lr", 0.01)} + ), + ), + "ase_fire": _OptimizerConfig( + init_fn=fire_init, + step_fn=fire_step, + state_type=FireState, + init_kwargs_modifier=lambda kwargs: ( + kwargs if "fire_flavor" in kwargs else {**kwargs, "fire_flavor": "ase_fire"} + ), + step_kwargs_modifier=lambda kwargs: ( + kwargs if "fire_flavor" in kwargs else {**kwargs, "fire_flavor": "ase_fire"} + ), + ), +} + + +@dataclass +class NEB: + """Nudged Elastic Band (NEB) optimizer. + + Finds the minimum energy path (MEP) between an initial and final state using + the NEB algorithm. 
def __post_init__(self) -> None:
    """Resolve device/dtype defaults and bind the optimizer functions.

    Looks up ``optimizer_type`` in ``_OPTIMIZER_REGISTRY``, stores its init
    and step functions, and splits ``optimizer_params`` into the keyword
    arguments each of them accepts (by signature), applying the registry's
    optional kwargs modifiers afterwards.

    Raises:
        ValueError: If ``n_images`` is smaller than 1 or ``optimizer_type``
            is not a key of the registry.
    """
    # A band without intermediate images cannot be optimized; the force
    # routine indexes the neighbours of the middle image unconditionally.
    if self.n_images < 1:
        raise ValueError(f"n_images must be at least 1, got {self.n_images}.")

    if self.device is None:
        self.device = self.model.device
    if self.dtype is None:
        self.dtype = self.model.dtype

    # Initialize variable to store step 0 debug output
    self._step0_debug_output = None

    # Get optimizer configuration from registry
    if self.optimizer_type not in _OPTIMIZER_REGISTRY:
        raise ValueError(
            f"Unsupported optimizer_type: {self.optimizer_type}. "
            f"Supported types: {list(_OPTIMIZER_REGISTRY.keys())}"
        )

    config = _OPTIMIZER_REGISTRY[self.optimizer_type]
    self._init_fn = config.init_fn
    self._step_fn = config.step_fn
    self._OptimizerStateType = config.state_type

    # Keep only the parameters each function declares; 'state' and 'model'
    # are passed positionally by the workflow and must not be duplicated.
    init_kwargs = _extract_kwargs_from_params(
        self.optimizer_params, config.init_fn, exclude={"state", "model"}
    )
    step_kwargs = _extract_kwargs_from_params(
        self.optimizer_params, config.step_fn, exclude={"state", "model"}
    )

    # Surface parameters that reach neither function instead of dropping
    # them silently (typically a misspelled or unsupported option).
    ignored = sorted(set(self.optimizer_params) - init_kwargs.keys() - step_kwargs.keys())
    if ignored:
        logger.warning(
            "Ignoring optimizer_params not accepted by the %r optimizer: %s",
            self.optimizer_type,
            ignored,
        )

    # Apply modifiers if provided (for special cases like cell_filter, defaults, etc.)
    if config.init_kwargs_modifier:
        init_kwargs = config.init_kwargs_modifier(init_kwargs)
    if config.step_kwargs_modifier:
        step_kwargs = config.step_kwargs_modifier(step_kwargs)

    self._init_kwargs = init_kwargs
    self._step_kwargs = step_kwargs
+ """ + # --- Input Validation --- + if initial_state.n_systems != 1 or final_state.n_systems != 1: + raise ValueError("Initial and final states must be single-system SimStates.") + if initial_state.n_atoms != final_state.n_atoms: + raise ValueError( + f"Initial ({initial_state.n_atoms}) and final ({final_state.n_atoms}) " + "states must have the same number of atoms." + ) + if not torch.equal(initial_state.atomic_numbers, final_state.atomic_numbers): + # Comparing floats might be tricky, but atomic numbers should be exact + raise ValueError("Initial and final states must have the same atom types.") + # Compare PBC values properly (can be bool, list, or tensor) + pbc_match = False + if isinstance(initial_state.pbc, torch.Tensor) and isinstance(final_state.pbc, torch.Tensor): + pbc_match = torch.equal(initial_state.pbc, final_state.pbc) + elif isinstance(initial_state.pbc, torch.Tensor) or isinstance(final_state.pbc, torch.Tensor): + # One is tensor, one is not - convert both to tensors for comparison + initial_pbc_tensor = ( + initial_state.pbc + if isinstance(initial_state.pbc, torch.Tensor) + else torch.tensor(initial_state.pbc, device=initial_state.device) + ) + final_pbc_tensor = ( + final_state.pbc + if isinstance(final_state.pbc, torch.Tensor) + else torch.tensor(final_state.pbc, device=final_state.device) + ) + pbc_match = torch.equal(initial_pbc_tensor, final_pbc_tensor) + else: + # Both are bools or lists + pbc_match = initial_state.pbc == final_state.pbc + if not pbc_match: + # TODO: Could potentially support different PBCs, but complex for NEB. + raise ValueError("Initial and final states must have the same PBC setting.") + # For fixed-cell NEB, cells should ideally be identical. Warn if not? 
+ # if not torch.allclose(initial_state.cell, final_state.cell): + + n_atoms_per_image = initial_state.n_atoms + + # --- Interpolation --- + initial_pos = initial_state.positions + final_pos = final_state.positions + + # Calculate displacement using Minimum Image Convention + displacement = minimum_image_displacement( + dr=final_pos - initial_pos, + cell=initial_state.cell[0], # Use cell from initial state + pbc=initial_state.pbc, + ) + # Ensure shape is correct [n_atoms, 3] + displacement = displacement.reshape(n_atoms_per_image, 3) + + # Generate interpolation factors (e.g., for n_images=3: 0.25, 0.5, 0.75) + factors = torch.linspace( + 0.0, 1.0, steps=self.n_images + 2, device=self.device, dtype=self.dtype + )[1:-1] # Exclude 0.0 and 1.0 # Ensure dtype + factors = factors.view(-1, 1, 1) # Shape: [n_images, 1, 1] + + # Calculate interpolated positions: initial + factor * displacement + # Broadcasting: [N_atoms, 3] + [N_images, 1, 1] * [N_atoms, 3] -> [N_images, N_atoms, 3] + interpolated_pos = initial_pos.unsqueeze(0) + factors * displacement.unsqueeze(0) + + # Reshape to [n_images * n_atoms_per_image, 3] + all_positions = interpolated_pos.reshape(-1, 3) + + # --- Create Batched State --- + # Repeat other attributes for each image + all_atomic_numbers = initial_state.atomic_numbers.repeat(self.n_images) + all_masses = initial_state.masses.repeat(self.n_images) + # Use initial state's cell, repeated for each image + all_cells = initial_state.cell.repeat( + self.n_images, 1, 1 + ) # Shape: [n_images, 3, 3] + + # Create system_idx tensor: [0, 0, ..., 1, 1, ..., n_images-1, ...] 
def _compute_tangents(
    self,
    all_pos: torch.Tensor,  # Shape: [n_total_images, n_atoms, 3]
    all_energies: torch.Tensor,  # Shape: [n_total_images]
    cell: torch.Tensor,  # Shape: [3, 3]
    *,  # Make pbc keyword-only
    pbc: bool,
) -> torch.Tensor:
    """Compute normalized tangent vectors for intermediate NEB images.

    Implements the improved tangent estimate of Henkelman and Jónsson (2000):
    the tangent at image ``i`` is chosen from the displacements to its two
    neighbours according to the energies of the three images.

    Args:
        all_pos (torch.Tensor): Atomic configurations for all images in the path
            (initial + intermediate + final), shape [n_total_images, n_atoms, 3].
        all_energies (torch.Tensor): Potential energy of each image, shape
            [n_total_images].
        cell (torch.Tensor): Unit cell vectors (shape [3, 3]), assumed constant
            for the path.
        pbc (bool): Flag indicating if periodic boundary conditions are active.

    Returns:
        torch.Tensor: Normalized local tangent vectors for the intermediate
            images only, shape [n_images, n_atoms, 3], with the device and
            dtype of ``all_pos``. A tangent is left at zero when its norm
            does not exceed ``_EPS`` (numerically identical neighbours).
    """
    n_total_images, n_atoms_per_image, _ = all_pos.shape
    n_intermediate_images = n_total_images - 2

    # Allocate with the dtype of the positions so normalized tangents are
    # stored without an implicit cast.
    tangents = torch.zeros(
        (n_intermediate_images, n_atoms_per_image, 3),
        device=all_pos.device,
        dtype=all_pos.dtype,
    )

    # displacements[k] = R_{k+1} - R_k under the minimum image convention.
    displacements = minimum_image_displacement(
        dr=all_pos[1:] - all_pos[:-1], cell=cell, pbc=pbc
    )
    displacements = displacements.reshape(n_total_images - 1, n_atoms_per_image, 3)

    # dE_forward[k] = V_{k+1} - V_k, shape [n_total_images - 1]
    dE_forward = all_energies[1:] - all_energies[:-1]

    for i in range(n_intermediate_images):
        img_idx = i + 1  # Index of this image in all_pos / all_energies

        dR_plus = displacements[img_idx]  # R_{i+1} - R_i
        dR_minus = displacements[img_idx - 1]  # R_i - R_{i-1}
        dE_plus = dE_forward[img_idx]  # V_{i+1} - V_i
        dE_minus = dE_forward[img_idx - 1]  # V_i - V_{i-1}

        if dE_plus > 0 and dE_minus > 0:
            # Energy rises monotonically (V_{i+1} > V_i > V_{i-1}):
            # forward difference.
            tangent_i = dR_plus
        elif dE_plus < 0 and dE_minus < 0:
            # Energy falls monotonically (V_{i+1} < V_i < V_{i-1}):
            # backward difference.
            tangent_i = dR_minus
        else:
            # Image is a local extremum (or a neighbour has equal energy):
            # blend both differences, weighting by the energy changes so the
            # tangent varies smoothly between the two monotonic cases.
            abs_dE_plus = torch.abs(dE_plus)
            abs_dE_minus = torch.abs(dE_minus)
            deltavmax = torch.maximum(abs_dE_plus, abs_dE_minus)
            deltavmin = torch.minimum(abs_dE_plus, abs_dE_minus)

            # dE_plus + dE_minus = V_{i+1} - V_{i-1}
            if (dE_plus + dE_minus) > 0:
                tangent_i = dR_plus * deltavmax + dR_minus * deltavmin
            else:
                tangent_i = dR_plus * deltavmin + dR_minus * deltavmax

        # Normalize per image; a (near-)zero tangent stays zero.
        norm_i = torch.linalg.norm(tangent_i)
        if norm_i > _EPS:
            tangents[i] = tangent_i / norm_i

    return tangents
+ initial_energy (torch.Tensor): Potential energy of the initial state + (scalar tensor). + final_energy (torch.Tensor): Potential energy of the final state + (scalar tensor). + step (int): Current optimization step number (used for climbing image delay). + + Returns: + torch.Tensor: Calculated NEB forces for the intermediate images, ready to + be passed to the optimizer, shape [n_movable_atoms, 3]. + """ + n_total_images = path_state.n_systems + n_intermediate_images = n_total_images - 2 + assert n_intermediate_images == self.n_images + n_atoms_per_image = path_state.n_atoms // n_total_images + + # --- Reshape inputs --- + # Positions for all images: [n_total_images, n_atoms, 3] + all_pos = path_state.positions.reshape(n_total_images, n_atoms_per_image, 3) + # True forces for intermediate images: [n_images, n_atoms, 3] + true_forces_reshaped = true_forces.reshape( + n_intermediate_images, n_atoms_per_image, 3 + ) + # Cell vectors (assuming fixed cell for now, take from first batch) + cell = path_state.cell[0] # Shape [3, 3] + # Convert pbc to bool if it's a tensor (for _compute_tangents) + if isinstance(path_state.pbc, torch.Tensor): + pbc_bool: bool = bool(path_state.pbc.any().item()) # True if any dimension has PBC + elif isinstance(path_state.pbc, bool): + pbc_bool = path_state.pbc + elif isinstance(path_state.pbc, list): + pbc_bool = bool(any(path_state.pbc)) + else: + pbc_bool = True + pbc = path_state.pbc # Keep original for minimum_image_displacement + + # --- Get Energies for Tangent Calculation --- + all_energies = torch.cat( + [ + initial_energy.unsqueeze(0), + true_energies, + final_energy.unsqueeze(0), + ] + ) + + # --- Setup for Debugging Step 0 --- + log_step_0 = step == 0 + debug_img_idx = ( + n_intermediate_images // 2 + ) # Index within intermediates (0 to n_images-1) + debug_img_idx_all = debug_img_idx + 1 # Index within all_pos (0 to n_images+1) + debug_data_ts = {} # Initialize debug dict + + if log_step_0: + debug_data_ts = { + "step": 0, + 
"image_index_intermediate": debug_img_idx, + "image_index_absolute": debug_img_idx_all, + "inputs": {}, + "outputs": {}, + "error": None, + } + debug_data_ts["inputs"]["energies_all"] = all_energies # Monty handles tensor + debug_data_ts["inputs"]["cell"] = cell + debug_data_ts["inputs"]["pbc"] = pbc_bool # Store Python bool + debug_data_ts["inputs"]["positions_image_minus_1"] = all_pos[ + debug_img_idx_all - 1 + ] + debug_data_ts["inputs"]["positions_image"] = all_pos[debug_img_idx_all] + debug_data_ts["inputs"]["positions_image_plus_1"] = all_pos[ + debug_img_idx_all + 1 + ] + debug_data_ts["inputs"]["true_forces_image"] = true_forces_reshaped[ + debug_img_idx + ] + + # --- Calculate Tangents (tau) using the improved method --- + # tangents shape: [n_images, n_atoms, 3] + tangents = self._compute_tangents(all_pos, all_energies, cell, pbc=pbc_bool) + logger.debug( + f" Step {step}: Tangent norms per image: {torch.linalg.norm(tangents, dim=(-1, -2))}" + ) + if log_step_0: + # Note: ASE tangent might not be normalized if norm is ~0, TS tangent should be. + tangent_img = tangents[debug_img_idx] + tangent_norm_img = torch.linalg.norm(tangent_img) + debug_data_ts["outputs"]["tangent_vector"] = tangent_img + debug_data_ts["outputs"]["tangent_norm"] = tangent_norm_img + + # --- Calculate Displacements for Spring Force --- + # Recalculate here or reuse from _compute_tangents if efficient + displacements = minimum_image_displacement( + dr=all_pos[1:] - all_pos[:-1], cell=cell, pbc=pbc + ) + displacements = displacements.reshape(n_total_images - 1, n_atoms_per_image, 3) + if log_step_0: + # Save displacements relevant to the middle image's tangent/spring calculation + debug_data_ts["outputs"]["mic_displacement_1"] = displacements[ + debug_img_idx_all - 1 + ] # R(i) - R(i-1) + debug_data_ts["outputs"]["mic_displacement_2"] = displacements[ + debug_img_idx_all + ] # R(i+1) - R(i) + + # --- Calculate NEB Force Components --- + + # 1. 
Perpendicular component of true force + # F_perp = F_true - (F_true . tau) * tau + # Dot product (sum over atoms and dims): [n_images] + F_true_dot_tau = (true_forces_reshaped * tangents).sum(dim=(-1, -2), keepdim=True) + F_perp = true_forces_reshaped - F_true_dot_tau * tangents + logger.debug( + f" Step {step}: F_perp norms per image: {torch.linalg.norm(F_perp, dim=(-1, -2))}" + ) + if log_step_0: + f_perp_img = F_perp[debug_img_idx] + f_perp_norm_img = torch.linalg.norm(f_perp_img) + debug_data_ts["outputs"]["f_true_dot_tau"] = F_true_dot_tau[ + debug_img_idx + ].item() # scalar + debug_data_ts["outputs"]["f_perp_vector"] = f_perp_img + debug_data_ts["outputs"]["f_perp_norm"] = f_perp_norm_img + + # 2. Parallel component of spring force + # F_spring_par = k * (|R_{i+1}-R_i| - |R_i-R_{i-1}|) * tau_i + # Segment lengths (scalar magnitude per segment): [n_images+1] + segment_lengths = torch.linalg.norm( + displacements, dim=(-1, -2) + ) # Cleaner way [n_total_images-1] + # Spring force magnitude (scalar per intermediate image): [n_images] + F_spring_mag = self.spring_constant * (segment_lengths[1:] - segment_lengths[:-1]) + # Project onto tangent: [n_images, 1, 1] -> [n_images, n_atoms, 3] + F_spring_par = F_spring_mag.view(-1, 1, 1) * tangents + logger.debug( + f" Step {step}: F_spring_par norms per image: {torch.linalg.norm(F_spring_par, dim=(-1, -2))}" + ) + if log_step_0: + f_spring_par_img = F_spring_par[debug_img_idx] + f_spring_par_norm_img = torch.linalg.norm(f_spring_par_img) + debug_data_ts["outputs"]["segment_lengths"] = segment_lengths # Full list + debug_data_ts["outputs"]["spring_force_magnitude_term"] = F_spring_mag[ + debug_img_idx + ].item() # scalar + debug_data_ts["outputs"]["f_spring_par_vector"] = f_spring_par_img + debug_data_ts["outputs"]["f_spring_par_norm"] = f_spring_par_norm_img + + # --- Combine Components for NEB Force --- + neb_forces = F_perp + F_spring_par + if log_step_0: + # --- Direct Debug Logs for Step 0 --- + f_perp_img = 
F_perp[debug_img_idx] + f_spring_par_img = F_spring_par[debug_img_idx] + neb_force_img = neb_forces[debug_img_idx] + logger.debug(" --- DIRECT DEBUG LOG (TORCH-SIM STEP 0) ---") + logger.debug(f" f_perp_norm: {torch.linalg.norm(f_perp_img)}") + logger.debug(f" f_perp_vec[0]: {f_perp_img[0]}") + # segment_lengths shape: [n_total_images - 1] + # segment_lengths[debug_img_idx] corresponds to spring2 length + # segment_lengths[debug_img_idx-1] corresponds to spring1 length + len1 = segment_lengths[debug_img_idx - 1] + len2 = segment_lengths[debug_img_idx] + len_diff = len2 - len1 + logger.debug( + f" spring1_length (R[{debug_img_idx_all}]-R[{debug_img_idx_all - 1}]): {len1}" + ) + logger.debug( + f" spring2_length (R[{debug_img_idx_all + 1}]-R[{debug_img_idx_all}]): {len2}" + ) + logger.debug(f" Length Diff (len2 - len1): {len_diff}") + logger.debug(f" f_spring_par_norm: {torch.linalg.norm(f_spring_par_img)}") + logger.debug(f" f_spring_par_vec[0]: {f_spring_par_img[0]}") + logger.debug( + f" neb_force_before_climb_norm: {torch.linalg.norm(neb_force_img)}" + ) + # -------------------------------------- + # Store a *copy* detached from the graph to prevent modification by climbing image logic + debug_data_ts["outputs"]["neb_force_before_climb_vector"] = ( + neb_forces[debug_img_idx].clone().detach() + ) + debug_data_ts["outputs"]["neb_force_before_climb_norm"] = torch.linalg.norm( + neb_forces[debug_img_idx] + ) # Norm calculation is fine + + # --- Log the vector right before it would be saved --- + logger.debug( + f" Value assigned to debug_data[neb_force_before_climb_vector][0]: {neb_forces[debug_img_idx][0]}" + ) + # ----------------------------------------------------- + + # --- Handle Climbing Image --- + climbing_delay_steps = 10 # Example value + if ( + self.use_climbing_image and n_intermediate_images > 0 + ): # and step >= climbing_delay_steps: # Check step number - REMOVED DELAY + # Find index of highest energy image among intermediates + climbing_image_idx = 
torch.argmax( + true_energies + ).item() # Index from 0 to n_images-1 + # Calculate the climbing force: F_climb = F_true - 2 * (F_true . tau) * tau + F_climb = true_forces_reshaped[climbing_image_idx] - ( + 2 * F_true_dot_tau[climbing_image_idx] * tangents[climbing_image_idx] + ) + # Replace the NEB force for the climbing image with F_climb + # This overwrites the spring force component for this image, as required. + neb_forces[climbing_image_idx] = F_climb + logger.debug( + f" Step {step}: Climbing image index: {climbing_image_idx}, " + f"Climbing Force Norm: {torch.linalg.norm(F_climb)}" + ) + if log_step_0: + debug_data_ts["outputs"]["is_climbing_image"] = ( + climbing_image_idx == debug_img_idx + ) + debug_data_ts["outputs"]["imax"] = climbing_image_idx + debug_data_ts["outputs"]["climbing_force_vector"] = neb_forces[ + climbing_image_idx + ] + debug_data_ts["outputs"]["climbing_force_norm"] = torch.linalg.norm( + neb_forces[climbing_image_idx] + ) + + # --- Logging (Optional) --- + # logger.debug( + # " Max True Force Mag: " + # f"{torch.linalg.norm(true_forces_reshaped, dim=(-1,-2)).max().item():.4f}" + # ) + # logger.debug( + # " Max F_perp Mag: " + # f"{torch.linalg.norm(F_perp, dim=(-1,-2)).max().item():.4f}" + # ) + # logger.debug( + # " Max F_spring_par Mag: " + # f"{torch.linalg.norm(F_spring_par, dim=(-1,-2)).max().item():.4f}" + # ) + # logger.debug( + # " Max NEB Force Mag: " + # f"{torch.linalg.norm(neb_forces, dim=(-1,-2)).max().item():.4f}" + # ) + logger.debug( + f" Step {step}: NEB force norms per image: {torch.linalg.norm(neb_forces, dim=(-1, -2))}" + ) + logger.debug(f" Step {step}: Intermediate energies: {true_energies}") + if log_step_0 and not ( + self.use_climbing_image and climbing_image_idx == debug_img_idx + ): # Avoid logging twice if climbing image was logged + # If not the climbing image, the final force is the one before modification + pass # Already stored neb_force_before_climb + + if log_step_0: + 
def run(
    self,
    initial_system: StateLike,
    final_system: StateLike,
    max_steps: int = 100,
    fmax: float = 0.05,
    # TODO: add convergence criteria, batching options, output frequency etc.
) -> SimState:
    """Run the Nudged Elastic Band optimization.

    Optimizes the path between the initial and final systems to find the
    Minimum Energy Path (MEP). After the loop, the step-0 debug dictionary
    (if one was produced) is pickled to ``torchsim_step0_debug.pkl`` in the
    current working directory.

    Args:
        initial_system (StateLike): The starting configuration, converted with
            ``initialize_state``.
        final_system (StateLike): The ending configuration.
        max_steps (int): Maximum number of optimization steps allowed.
        fmax (float): Convergence threshold on the largest per-atom NEB force
            norm over all intermediate images (in eV/Ang).

    Returns:
        SimState: The final optimized NEB path, including the initial,
            intermediate, and final images, concatenated into a single SimState.

    Raises:
        ValueError: If either endpoint is not a single-system state.
    """
    logger.info("Starting NEB optimization")

    # Reset step 0 debug output storage for this run
    self._step0_debug_output = None

    # 1. Initialize initial and final states
    initial_state = initialize_state(initial_system, self.device, self.dtype)
    final_state = initialize_state(final_system, self.device, self.dtype)
    if initial_state.n_systems != 1:
        raise ValueError("Initial state must be a single-system SimState")
    if final_state.n_systems != 1:
        raise ValueError("Final state must be a single-system SimState")

    # 1b. Endpoint energies are needed for the tangent estimate; the forces
    # returned alongside are only used when writing the trajectory.
    logger.info("Calculating endpoint energies...")
    endpoint_states = concatenate_states([initial_state, final_state])
    endpoint_output = self.model(endpoint_states)
    initial_energy = endpoint_output["energy"][0]
    final_energy = endpoint_output["energy"][1]
    logger.info(
        f"Initial Energy: {initial_energy:.4f}, Final Energy: {final_energy:.4f}"
    )

    # 2. Create initial interpolated path (movable images only)
    interpolated_images = self._interpolate_path(initial_state, final_state)

    # 3. Initialize optimizer state for the movable images
    opt_state = self._init_fn(interpolated_images, self.model, **self._init_kwargs)

    # 4. Optimization loop
    logger.info(f"Running NEB for max {max_steps} steps or fmax < {fmax} eV/Ang.")

    traj_context = (
        TorchSimTrajectory(self.trajectory_filename, mode="w")
        if self.trajectory_filename
        else nullcontext()  # yields None when no trajectory is requested
    )

    with traj_context as traj:
        # Decide on the object actually opened, so an empty filename (which
        # selects the null context above) can never reach traj.write_arrays.
        write_traj = traj is not None

        for step in range(max_steps):
            # a. Get current true forces and energies
            true_forces = opt_state.forces
            true_energies = opt_state.energy

            # b. Calculate NEB forces on the full band (endpoints included)
            full_path_state_calc = concatenate_states(
                [initial_state, opt_state, final_state]
            )
            # Snapshot the true forces before they are replaced below.
            true_forces_for_traj = true_forces.clone() if write_traj else None

            neb_forces, step0_debug_data = self._calculate_neb_forces(
                full_path_state_calc,
                true_forces,
                true_energies,
                initial_energy,
                final_energy,
                step=step,
            )

            # c. Hand the NEB forces to the optimizer through its state
            opt_state.forces = neb_forces
            neb_forces_for_traj = neb_forces.clone() if write_traj else None

            # d. Perform optimization step
            opt_state = self._step_fn(opt_state, self.model, **self._step_kwargs)

            if step == 0 and step0_debug_data:
                logger.info("Storing Step 0 TorchSim debug data.")
                self._step0_debug_output = step0_debug_data

            # e. Write to trajectory (if enabled)
            if write_traj:
                # Positions and energies are those *after* the step; both
                # force arrays are the ones that drove the step.
                current_full_path = concatenate_states(
                    [initial_state, opt_state, final_state]
                )
                data_to_write = {
                    "positions": current_full_path.positions,
                    # Endpoints carry no NEB force: pad with zeros.
                    "neb_forces": torch.cat(
                        [
                            torch.zeros_like(initial_state.positions),
                            neb_forces_for_traj,
                            torch.zeros_like(final_state.positions),
                        ],
                        dim=0,
                    ),
                    # Endpoint true forces come from the initial evaluation.
                    "true_forces": torch.cat(
                        [
                            endpoint_output["forces"][: initial_state.n_atoms],
                            true_forces_for_traj,
                            endpoint_output["forces"][initial_state.n_atoms :],
                        ],
                        dim=0,
                    ),
                    "energies": torch.cat(
                        [
                            initial_energy.unsqueeze(0),
                            opt_state.energy,
                            final_energy.unsqueeze(0),
                        ],
                        dim=0,
                    ),
                }
                if step == 0:  # Write static data only on the first step
                    data_to_write["cell"] = current_full_path.cell
                    data_to_write["atomic_numbers"] = current_full_path.atomic_numbers
                    data_to_write["masses"] = current_full_path.masses
                    # as_tensor accepts bool, list or an existing tensor
                    # without the copy-construct warning of torch.tensor.
                    data_to_write["pbc"] = torch.as_tensor(current_full_path.pbc)
                    # Maps every atom to its image
                    data_to_write["image_indices"] = current_full_path.system_idx

                traj.write_arrays(data_to_write, steps=step)

            # f. Check convergence on the forces that drove this step
            max_force_magnitude = torch.sqrt((neb_forces**2).sum(dim=-1)).max()
            max_intermediate_energy = opt_state.energy.max()
            logger.info(
                f"Step {step + 1:4d}: Max Force = {max_force_magnitude:.4f} Max Energy = {max_intermediate_energy:.4f}"
            )
            if max_force_magnitude < fmax:
                logger.info("NEB optimization converged.")
                break
        else:  # Loop finished without break
            logger.warning("NEB optimization did not converge within max_steps.")

    # 5. Persist the step-0 debug dictionary, then return the final path
    if self._step0_debug_output:
        output_filename_ts = "torchsim_step0_debug.pkl"
        logger.info(
            f"Attempting to write final Step 0 TorchSim debug data to {output_filename_ts}"
        )
        try:
            with open(output_filename_ts, "wb") as f:
                pickle.dump(self._step0_debug_output, f)
                # Force the data to disk so a later reader sees a complete file.
                f.flush()
                os.fsync(f.fileno())
            logger.info(
                f"--- TorchSim NEB Debug Info (Step 0) saved to {output_filename_ts} ---"
            )
        except Exception as e:
            # Debug output is best-effort; never fail the optimization over it.
            logger.error(
                f"ERROR WRITING FINAL TORCHSIM STEP 0 DEBUG PICKLE: {e}",
                exc_info=True,
            )
    else:
        logger.warning("No Step 0 TorchSim debug data was stored to write.")

    return concatenate_states([initial_state, opt_state, final_state])
ase.build import bulk -from ase.io import read from ase.mep import NEB as ASENEB from ase.mep.neb import ImprovedTangentMethod, NEBState from ase.optimize import FIRE from mace.calculators.foundations_models import mace_mp -from mace.calculators.mace import MACECalculator from monty.json import MontyDecoder, MontyEncoder # Import Monty import torch_sim as ts diff --git a/torch_sim/workflows/neb.py b/torch_sim/workflows/neb.py index 8be7027be..a85973656 100644 --- a/torch_sim/workflows/neb.py +++ b/torch_sim/workflows/neb.py @@ -26,11 +26,7 @@ gradient_descent_step, ) from torch_sim.optimizers.cell_filters import CellFilter -from torch_sim.state import ( - SimState, - concatenate_states, - initialize_state, -) +from torch_sim.state import SimState, concatenate_states, initialize_state from torch_sim.trajectory import TorchSimTrajectory from torch_sim.transforms import minimum_image_displacement from torch_sim.typing import StateLike @@ -57,11 +53,7 @@ def _extract_kwargs_from_params( """ exclude = exclude or {"state", "model"} sig = inspect.signature(func) - return { - k: v - for k, v in params.items() - if k in sig.parameters and k not in exclude - } + return {k: v for k, v in params.items() if k in sig.parameters and k not in exclude} @dataclass @@ -216,9 +208,13 @@ def _interpolate_path( raise ValueError("Initial and final states must have the same atom types.") # Compare PBC values properly (can be bool, list, or tensor) pbc_match = False - if isinstance(initial_state.pbc, torch.Tensor) and isinstance(final_state.pbc, torch.Tensor): + if isinstance(initial_state.pbc, torch.Tensor) and isinstance( + final_state.pbc, torch.Tensor + ): pbc_match = torch.equal(initial_state.pbc, final_state.pbc) - elif isinstance(initial_state.pbc, torch.Tensor) or isinstance(final_state.pbc, torch.Tensor): + elif isinstance(initial_state.pbc, torch.Tensor) or isinstance( + final_state.pbc, torch.Tensor + ): # One is tensor, one is not - convert both to tensors for comparison 
initial_pbc_tensor = ( initial_state.pbc @@ -278,8 +274,12 @@ def _interpolate_path( ) # Shape: [n_images, 3, 3] # Create system_idx tensor: [0, 0, ..., 1, 1, ..., n_images-1, ...] - system_indices = torch.arange(self.n_images, device=self.device, dtype=torch.int64) - all_system_idx = torch.repeat_interleave(system_indices, repeats=n_atoms_per_image) + system_indices = torch.arange( + self.n_images, device=self.device, dtype=torch.int64 + ) + all_system_idx = torch.repeat_interleave( + system_indices, repeats=n_atoms_per_image + ) return SimState( positions=all_positions, @@ -447,7 +447,9 @@ def _calculate_neb_forces( cell = path_state.cell[0] # Shape [3, 3] # Convert pbc to bool if it's a tensor (for _compute_tangents) if isinstance(path_state.pbc, torch.Tensor): - pbc_bool: bool = bool(path_state.pbc.any().item()) # True if any dimension has PBC + pbc_bool: bool = bool( + path_state.pbc.any().item() + ) # True if any dimension has PBC elif isinstance(path_state.pbc, bool): pbc_bool = path_state.pbc elif isinstance(path_state.pbc, list): @@ -740,9 +742,7 @@ def run( # 3. Initialize optimizer state for the movable images # Use the generic initializer with model parameter - opt_state = self._init_fn( - interpolated_images, self.model, **self._init_kwargs - ) + opt_state = self._init_fn(interpolated_images, self.model, **self._init_kwargs) # 4. 
Optimization loop logger.info(f"Running NEB for max {max_steps} steps or fmax < {fmax} eV/Ang.") From 99fffdd72f2cd62a942df37273c11546369d5970 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Thu, 23 Apr 2026 22:06:03 -0400 Subject: [PATCH 3/7] wip --- examples/scripts/9_neb.py | 26 ++++++----------- torch_sim/workflows/neb.py | 59 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 65 insertions(+), 20 deletions(-) diff --git a/examples/scripts/9_neb.py b/examples/scripts/9_neb.py index b5de994cf..66fc9666b 100644 --- a/examples/scripts/9_neb.py +++ b/examples/scripts/9_neb.py @@ -202,17 +202,17 @@ def ase_neb(start_atoms, end_atoms, nimages=5): neb_calc = ASENEB(images, climb=True, method="improvedtangent") neb_calc.interpolate(mic=True) - # Attach calculator to all images using mace_mp ase_dtype_str = "float64" if torch_sim_dtype == torch.float64 else "float32" - print(f"Attaching ASE calculator with dtype: {ase_dtype_str} to all images") - ase_calc = mace_mp( - model=MaceUrls.mace_mpa_medium, - device=device, - default_dtype=ase_dtype_str, - dispersion=False, + print( + f"Attaching independent ASE calculators with dtype: {ase_dtype_str} to all images" ) for image in neb_calc.images: - image.calc = ase_calc + image.calc = mace_mp( + model=MaceUrls.mace_mpa_medium, + device=device, + default_dtype=ase_dtype_str, + dispersion=False, + ) # Set up trajectory logging for the reference ASE run (Commented out as not used for plot) # ase_traj_filename = "ase_ref_neb.traj" @@ -730,15 +730,7 @@ def print_pickle_structure(filename="torchsim_step0_debug.pkl"): print("Finished torch-sim NEB optimization.") # Check if it converged and plot results -results = ts_mace_model( - dict( - positions=final_path_gd.positions, - cell=final_path_gd.cell, - atomic_numbers=final_path_gd.atomic_numbers, - system_idx=final_path_gd.system_idx, - pbc=True, - ) -) +results = ts_mace_model(final_path_gd) energies = results["energy"].tolist() diff --git a/torch_sim/workflows/neb.py 
b/torch_sim/workflows/neb.py index a85973656..f40f9dc03 100644 --- a/torch_sim/workflows/neb.py +++ b/torch_sim/workflows/neb.py @@ -733,6 +733,24 @@ def run( endpoint_output = self.model(endpoint_states) initial_energy = endpoint_output["energy"][0] final_energy = endpoint_output["energy"][1] + # Distribute model extras (e.g. interaction_energy) back onto the + # endpoint states so that subsequent concatenate_states calls with + # opt_state (which carries those extras) produce consistent leading dims + n_init_atoms = initial_state.n_atoms + n_final_atoms = final_state.n_atoms + init_extras: dict[str, torch.Tensor] = {} + final_extras: dict[str, torch.Tensor] = {} + for key, val in endpoint_output.items(): + if key in {"energy", "forces", "stress"} or not isinstance(val, torch.Tensor): + continue + if val.shape[0] == 2: + init_extras[key] = val[:1] + final_extras[key] = val[1:] + elif val.shape[0] == n_init_atoms + n_final_atoms: + init_extras[key] = val[:n_init_atoms] + final_extras[key] = val[n_init_atoms:] + initial_state.store_model_extras(init_extras) + final_state.store_model_extras(final_extras) logger.info( f"Initial Energy: {initial_energy:.4f}, Final Energy: {final_energy:.4f}" ) @@ -754,6 +772,35 @@ def run( else nullcontext() # Use a dummy context if no filename ) + def _opt_state_as_simstate(state: SimState) -> SimState: + """Project an OptimState/FireState down to a plain SimState. + + Concatenating an OptimState/FireState with plain SimState endpoints + collapses to the first state's class (SimState), causing optimizer- + specific fields like velocities/forces/energy to be misrouted into + extras with mismatched leading dims. We strip those here and + preserve only model-derived extras (interaction_energy, etc.) that + were also populated on the endpoints. 
+ """ + optim_only_atom = {"forces"} + optim_only_system = {"energy", "stress", "dt", "alpha", "n_pos"} + sys_extras = { + k: v for k, v in state.system_extras.items() if k not in optim_only_system + } + atom_extras = { + k: v for k, v in state.atom_extras.items() if k not in optim_only_atom + } + return SimState( + positions=state.positions, + masses=state.masses, + cell=state.cell, + pbc=state.pbc, + atomic_numbers=state.atomic_numbers, + system_idx=state.system_idx, + _system_extras=sys_extras, + _atom_extras=atom_extras, + ) + with traj_context as traj: for step in range(max_steps): # a. Get current true forces and energies @@ -763,7 +810,7 @@ def run( # b. Calculate NEB forces # Concatenate states - ensures consistent group ID (0 for single NEB) full_path_state_calc = concatenate_states( - [initial_state, opt_state, final_state] + [initial_state, _opt_state_as_simstate(opt_state), final_state] ) # Store true forces *before* calculating NEB forces true_forces_for_traj = opt_state.forces.clone() @@ -796,7 +843,11 @@ def run( if self.trajectory_filename is not None: # Use explicit check # Create the full path state for writing (including endpoints) current_full_path = concatenate_states( - [initial_state, opt_state, final_state] + [ + initial_state, + _opt_state_as_simstate(opt_state), + final_state, + ] ) # Write arrays directly using traj.write_arrays data_to_write = { @@ -884,4 +935,6 @@ def run( logger.warning("No Step 0 TorchSim debug data was stored to write.") # ---------------------------------------------------------- - return concatenate_states([initial_state, opt_state, final_state]) + return concatenate_states( + [initial_state, _opt_state_as_simstate(opt_state), final_state] + ) From 0f6329ec468d059dc34fda3ed40b7f0501c07199 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Sat, 9 May 2026 22:03:04 -0400 Subject: [PATCH 4/7] messy with group_idx. Need to reconsider design. 
--- examples/scripts/10_fire.py | 207 +++++ examples/scripts/9_neb.py | 1169 ++++++------------------- tests/test_optimizers.py | 30 + tests/test_state.py | 27 +- tests/workflows/test_neb.py | 161 ++++ torch_sim/autobatching.py | 79 +- torch_sim/optimizers/cell_filters.py | 5 +- torch_sim/optimizers/fire.py | 152 +++- torch_sim/optimizers/state.py | 2 +- torch_sim/state.py | 136 ++- torch_sim/workflows/neb.py | 1177 ++++++++------------------ 11 files changed, 1347 insertions(+), 1798 deletions(-) create mode 100644 examples/scripts/10_fire.py create mode 100644 tests/workflows/test_neb.py diff --git a/examples/scripts/10_fire.py b/examples/scripts/10_fire.py new file mode 100644 index 000000000..4c0fa5061 --- /dev/null +++ b/examples/scripts/10_fire.py @@ -0,0 +1,207 @@ +"""Compare plain ASE FIRE and torch-sim ase_fire on one analytic system.""" +# ruff: noqa: D101, D102, D103, D107 + +# %% +# /// script +# dependencies = [ +# "ase", +# "matplotlib", +# ] +# /// + +from dataclasses import dataclass +from typing import ClassVar + +import matplotlib.pyplot as plt +import numpy as np +import torch +from ase import Atoms +from ase.calculators.calculator import Calculator, all_changes +from ase.optimize import FIRE + +import torch_sim as ts +from torch_sim.models.interface import ModelInterface + + +@dataclass(frozen=True) +class PotentialParams: + valley_scale: float = 5.0 + valley_curve: float = 0.5 + + +def energy_forces( + positions: torch.Tensor, params: PotentialParams +) -> tuple[torch.Tensor, torch.Tensor]: + """Return per-atom energies and forces for a curved double well.""" + x = positions[:, 0] + y = positions[:, 1] + z = positions[:, 2] + u = x**2 - 1.0 + v = y - params.valley_curve * u + energy = u**2 + params.valley_scale * v**2 + z**2 + dE_dx = 4.0 * x * u - 4.0 * params.valley_scale * params.valley_curve * x * v + dE_dy = 2.0 * params.valley_scale * v + dE_dz = 2.0 * z + forces = -torch.stack([dE_dx, dE_dy, dE_dz], dim=1) + return energy, forces + + 
+class TorchModel(ModelInterface): + def __init__(self, params: PotentialParams) -> None: + super().__init__() + self._device = torch.device("cpu") + self._dtype = torch.float64 + self._compute_forces = True + self._compute_stress = True + self.params = params + + def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]: + del kwargs + per_atom_energy, forces = energy_forces(state.positions, self.params) + energy = torch.zeros(state.n_systems, device=state.device, dtype=state.dtype) + energy.scatter_add_(0, state.system_idx, per_atom_energy) + return { + "energy": energy, + "forces": forces, + "stress": torch.zeros( + state.n_systems, 3, 3, device=state.device, dtype=state.dtype + ), + } + + +class ASECalculator(Calculator): + implemented_properties: ClassVar[list[str]] = ["energy", "forces"] + + def __init__(self, params: PotentialParams) -> None: + super().__init__() + self.params = params + + def calculate( + self, + atoms: Atoms | None = None, + properties: list[str] | None = None, + system_changes: list[str] = all_changes, + ) -> None: + super().calculate(atoms, properties, system_changes) + positions = torch.tensor(self.atoms.positions, dtype=torch.float64) + per_atom_energy, forces = energy_forces(positions, self.params) + self.results["energy"] = float(per_atom_energy.sum()) + self.results["forces"] = forces.detach().cpu().numpy() + + +def make_state(position: tuple[float, float, float]) -> ts.SimState: + return ts.SimState( + positions=torch.tensor([position], dtype=torch.float64), + masses=torch.ones(1, dtype=torch.float64), + cell=torch.eye(3, dtype=torch.float64).unsqueeze(0) * 10.0, + pbc=False, + atomic_numbers=torch.tensor([18]), + system_idx=torch.zeros(1, dtype=torch.long), + ) + + +def run_torch_fire( + state: ts.SimState, model: ModelInterface, *, steps: int, fmax: float +) -> tuple[ts.SimState, list[float], list[float]]: + energy_history: list[float] = [] + fmax_history: list[float] = [] + + def record(state: 
ts.OptimState) -> None: + energy_history.append(float(state.energy[0])) + fmax_history.append(float(torch.linalg.norm(state.forces, dim=1).max())) + + initial_opt_state = ts.fire_init(state, model, fire_flavor="ase_fire") + record(initial_opt_state) + + def convergence_fn(state: ts.OptimState, last_energy: torch.Tensor) -> torch.Tensor: + del last_energy + record(state) + return ts.generate_force_convergence_fn(force_tol=fmax)(state, state.energy) + + result = ts.optimize( + state, + model, + optimizer=ts.Optimizer.fire, + convergence_fn=convergence_fn, + max_steps=steps, + steps_between_swaps=1, + autobatcher=False, + fire_flavor="ase_fire", + ) + return result, energy_history, fmax_history + + +def run_ase_fire( + atoms: Atoms, *, params: PotentialParams, steps: int, fmax: float +) -> tuple[Atoms, list[float], list[float]]: + atoms = atoms.copy() + atoms.calc = ASECalculator(params) + optimizer = FIRE(atoms, logfile=None) + energy_history: list[float] = [] + fmax_history: list[float] = [] + + def record() -> None: + energy_history.append(float(atoms.get_potential_energy())) + fmax_history.append(float(np.linalg.norm(atoms.get_forces(), axis=1).max())) + + optimizer.attach(record, interval=1) + optimizer.run(fmax=fmax, steps=steps) + return atoms, energy_history, fmax_history + + +params = PotentialParams() +steps = 80 +fmax = 0.03 +initial_position = (-0.2, 0.9, 0.0) +model = TorchModel(params) +state = make_state(initial_position) +atoms = Atoms("Ar", positions=[initial_position], cell=np.eye(3) * 10.0, pbc=False) + +ts_final, ts_energy, ts_force = run_torch_fire(state, model, steps=steps, fmax=fmax) +ase_final, ase_energy, ase_force = run_ase_fire( + atoms, params=params, steps=steps, fmax=fmax +) + +ts_position = ts_final.positions.detach().cpu().numpy()[0] +ase_position = ase_final.positions[0] +print(f"torch-sim steps: {len(ts_force)}") +print(f"ASE steps: {len(ase_force)}") +print(f"final position abs diff: {np.max(np.abs(ts_position - ase_position)):.3e}") 
+print(f"final energy abs diff: {abs(ts_energy[-1] - ase_energy[-1]):.3e}") +print(f"final fmax ts/ase: {ts_force[-1]:.3e} / {ase_force[-1]:.3e}") + +common_steps = min(len(ts_energy), len(ase_energy)) +step_axis = np.arange(common_steps) +energy_residual = np.array(ts_energy[:common_steps]) - np.array(ase_energy[:common_steps]) +force_residual = np.array(ts_force[:common_steps]) - np.array(ase_force[:common_steps]) + +fig, axes = plt.subplots(2, 2, figsize=(10, 7), sharex="col") +axes[0, 0].plot(ts_energy, label="torch-sim") +axes[0, 0].plot(ase_energy, "--", label="ASE") +axes[0, 0].set_ylabel("Energy") +axes[0, 0].set_title("Plain FIRE energy") +axes[0, 0].legend() + +axes[0, 1].plot(ts_force, label="torch-sim") +axes[0, 1].plot(ase_force, "--", label="ASE") +axes[0, 1].axhline(fmax, color="k", linestyle=":", label="fmax") +axes[0, 1].set_ylabel("Max force") +axes[0, 1].set_yscale("log") +axes[0, 1].set_title("Plain FIRE convergence") +axes[0, 1].legend() + +axes[1, 0].axhline(0.0, color="k", linewidth=0.8) +axes[1, 0].plot(step_axis, energy_residual) +axes[1, 0].set_xlabel("Optimization step") +axes[1, 0].set_ylabel("TS - ASE") +axes[1, 0].set_title("Energy residual") + +axes[1, 1].axhline(0.0, color="k", linewidth=0.8) +axes[1, 1].plot(step_axis, force_residual) +axes[1, 1].set_xlabel("Optimization step") +axes[1, 1].set_ylabel("TS - ASE") +axes[1, 1].set_title("Max-force residual") + +fig.tight_layout() +fig.savefig("fire_ase_torchsim_comparison.png", dpi=200) +print("Saved comparison plot to fire_ase_torchsim_comparison.png") diff --git a/examples/scripts/9_neb.py b/examples/scripts/9_neb.py index 66fc9666b..5462e3acb 100644 --- a/examples/scripts/9_neb.py +++ b/examples/scripts/9_neb.py @@ -1,924 +1,305 @@ -"""Nudged Elastic Band (NEB) workflow. 
+"""Compare torch-sim and ASE Nudged Elastic Band trajectories.""" +# ruff: noqa: D101, D102, D103, D107 -This script demonstrates the Nudged Elastic Band method for finding minimum energy -paths between two given atomic configurations. -""" # %% # /// script # dependencies = [ -# "mace-torch>=0.3.12", # "ase", +# "matplotlib", # ] # /// -import json # Import json for output +from dataclasses import dataclass +from typing import ClassVar -# Configure logging to DEBUG level first -import logging -import pickle # Import pickle - -import ase.geometry # Import the geometry module -import h5py import matplotlib.pyplot as plt import numpy as np import torch -from ase.build import bulk +from ase import Atoms +from ase.calculators.calculator import Calculator, all_changes from ase.mep import NEB as ASENEB -from ase.mep.neb import ImprovedTangentMethod, NEBState from ase.optimize import FIRE -from mace.calculators.foundations_models import mace_mp -from monty.json import MontyDecoder, MontyEncoder # Import Monty import torch_sim as ts -from torch_sim.models.mace import MaceModel, MaceUrls -from torch_sim.state import SimState -from torch_sim.workflows.neb import NEB as TorchNEB - - -# Redirect logging to a file instead of stdout -logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s - %(name)s - %(message)s", - filename="neb_debug.log", # Specify the log file name - filemode="w", -) # Overwrite the log file each time -logging.getLogger("torch_sim.workflows.neb").setLevel(logging.DEBUG) - - -torch_sim_device = "cuda" if torch.cuda.is_available() else "cpu" -torch_sim_dtype = torch.float64 # Use float64 for higher precision - -# Load MACE model using mace_mp like other tutorials -print("Loading MACE model...") -mace_potential = mace_mp( - model=MaceUrls.mace_mpa_medium, - return_raw_model=True, - default_dtype=str(torch_sim_dtype).removeprefix("torch."), - device=str(torch_sim_device), +from torch_sim.models.interface import ModelInterface +from 
torch_sim.optimizers import fire_init, fire_step +from torch_sim.workflows.neb import ( + as_sim_state, + assemble_path, + interpolate_path, + neb_convergence_fn, + neb_init, + neb_step, ) -def compare_initial_paths( - ase_start_atoms, - ase_end_atoms, - torch_sim_initial_state: SimState, - torch_sim_final_state: SimState, - neb_workflow: TorchNEB, -): - """Compares initial paths and the MIC displacement vector.""" - print("Comparing initial interpolated paths and MIC vectors...") - n_images = neb_workflow.n_images - n_total_images = n_images + 2 - device = neb_workflow.device - dtype = neb_workflow.dtype - - # --- Endpoint Check --- - print("\nChecking consistency of starting endpoint positions:") - ase_start_pos_direct = ase_start_atoms.get_positions() - ts_start_pos_direct = torch_sim_initial_state.positions.cpu().numpy() - start_close = np.allclose( - ase_start_pos_direct, ts_start_pos_direct, rtol=1e-5, atol=1e-6 +@dataclass(frozen=True) +class CurvedDoubleWellParams: + valley_scale: float = 5.0 + valley_curve: float = 0.5 + + +def curved_double_well( + positions: torch.Tensor, params: CurvedDoubleWellParams +) -> tuple[torch.Tensor, torch.Tensor]: + """Return per-atom energies and forces for a curved double-well surface.""" + x = positions[:, 0] + y = positions[:, 1] + z = positions[:, 2] + u = x**2 - 1.0 + v = y - params.valley_curve * u + energy = u**2 + params.valley_scale * v**2 + z**2 + dE_dx = 4.0 * x * u - 4.0 * params.valley_scale * params.valley_curve * x * v + dE_dy = 2.0 * params.valley_scale * v + dE_dz = 2.0 * z + forces = -torch.stack([dE_dx, dE_dy, dE_dz], dim=1) + return energy, forces + + +class TorchCurvedDoubleWellModel(ModelInterface): + def __init__( + self, + *, + device: torch.device, + dtype: torch.dtype, + params: CurvedDoubleWellParams, + ) -> None: + super().__init__() + self._device = device + self._dtype = dtype + self._compute_forces = True + self._compute_stress = True + self.params = params + + def forward(self, state: 
ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]: + del kwargs + per_atom_energy, forces = curved_double_well(state.positions, self.params) + energy = torch.zeros(state.n_systems, device=state.device, dtype=state.dtype) + energy.scatter_add_(0, state.system_idx, per_atom_energy) + return { + "energy": energy, + "forces": forces, + "stress": torch.zeros( + state.n_systems, 3, 3, device=state.device, dtype=state.dtype + ), + } + + +class ASECurvedDoubleWellCalculator(Calculator): + implemented_properties: ClassVar[list[str]] = ["energy", "forces"] + + def __init__(self, params: CurvedDoubleWellParams) -> None: + super().__init__() + self.params = params + + def calculate( + self, + atoms: Atoms | None = None, + properties: list[str] | None = None, + system_changes: list[str] = all_changes, + ) -> None: + super().calculate(atoms, properties, system_changes) + positions = torch.tensor(self.atoms.positions, dtype=torch.float64) + per_atom_energy, forces = curved_double_well(positions, self.params) + self.results["energy"] = float(per_atom_energy.sum().item()) + self.results["forces"] = forces.detach().cpu().numpy() + + +def make_state( + position: tuple[float, float, float], *, device: torch.device +) -> ts.SimState: + return ts.SimState( + positions=torch.tensor([position], device=device, dtype=torch.float64), + masses=torch.ones(1, device=device, dtype=torch.float64), + cell=torch.eye(3, device=device, dtype=torch.float64).unsqueeze(0) * 10.0, + pbc=False, + atomic_numbers=torch.tensor([18], device=device), + system_idx=torch.zeros(1, device=device, dtype=torch.long), ) - print(f" Direct Start positions close: {start_close}") - if not start_close: - max_diff_start = np.max(np.abs(ase_start_pos_direct - ts_start_pos_direct)) - print(f" Max absolute difference (Start): {max_diff_start:.6f}") - print("------------------------------------") - - # --- MIC Vector Comparison --- - print("\nComparing Minimum Image Convention (MIC) displacement vectors:") - # Use the 
torch-sim states as the source of truth for positions/cell - raw_dr_ts = torch_sim_final_state.positions - torch_sim_initial_state.positions - cell_ts = torch_sim_initial_state.cell[0] # Assuming single batch cell - pbc_ts = torch_sim_initial_state.pbc - - # ASE MIC calculation - try: - ase_cell_np = cell_ts.cpu().numpy() - ase_pbc_np = np.array([pbc_ts] * 3) # ASE expects 3 bools usually - ase_mic_dr_np, _ = ase.geometry.find_mic( - raw_dr_ts.cpu().numpy(), ase_cell_np, pbc=ase_pbc_np - ) - print(f" ASE MIC vector calculated (shape: {ase_mic_dr_np.shape})") - except Exception as e: - print(f" Error calculating ASE MIC: {e}") - ase_mic_dr_np = None - - # torch-sim MIC calculation - try: - ts_mic_dr = ts.transforms.minimum_image_displacement( - dr=raw_dr_ts, cell=cell_ts, pbc=pbc_ts - ) - ts_mic_dr_np = ts_mic_dr.cpu().numpy() - print(f" torch-sim MIC vector calculated (shape: {ts_mic_dr_np.shape})") - except Exception as e: - print(f" Error calculating torch-sim MIC: {e}") - ts_mic_dr_np = None - - # Compare the MIC vectors - if ase_mic_dr_np is not None and ts_mic_dr_np is not None: - if ase_mic_dr_np.shape != ts_mic_dr_np.shape: - print(" Error: Shapes of MIC vectors do not match.") - else: - mic_vectors_close = np.allclose( - ase_mic_dr_np, ts_mic_dr_np, rtol=1e-5, atol=1e-6 - ) - print(f" MIC displacement vectors close: {mic_vectors_close}") - if not mic_vectors_close: - max_diff_mic = np.max(np.abs(ase_mic_dr_np - ts_mic_dr_np)) - norm_diff = np.linalg.norm(ase_mic_dr_np - ts_mic_dr_np) - print(f" Max absolute difference (MIC vectors): {max_diff_mic:.6f}") - print(f" Norm of difference vector (MIC): {norm_diff:.6f}") - print(" This difference likely causes the interpolation discrepancy.") - print("------------------------------------") - - # --- Get ASE interpolated path --- - ase_images = [ase_start_atoms.copy() for _ in range(n_images + 1)] - ase_images.append(ase_end_atoms.copy()) - ase_neb_calc = ASENEB(ase_images, climb=False) - 
ase_neb_calc.interpolate(mic=True) - ase_positions = np.stack([img.get_positions() for img in ase_neb_calc.images]) - print(f"\n ASE interpolated path shape: {ase_positions.shape}") - - # --- Get torch-sim interpolated path --- - try: - interpolated_state = neb_workflow._interpolate_path( - torch_sim_initial_state, torch_sim_final_state - ) - ts_interp_pos = interpolated_state.positions - ts_start_pos = torch_sim_initial_state.positions - ts_end_pos = torch_sim_final_state.positions - n_atoms = ts_start_pos.shape[0] - ts_interp_pos_reshaped = ts_interp_pos.reshape(n_images, n_atoms, 3) - ts_positions = torch.cat( - [ - torch_sim_initial_state.positions.unsqueeze(0).to(device, dtype), - ts_interp_pos_reshaped.to(device, dtype), - torch_sim_final_state.positions.unsqueeze(0).to(device, dtype), - ], - dim=0, - ) - ts_positions_np = ts_positions.cpu().numpy() - print(f" torch-sim interpolated path shape (direct): {ts_positions_np.shape}") - except Exception as e: - print(f" Error during torch-sim interpolation: {e}") - import traceback - - traceback.print_exc() - return - - # --- Compare Interpolated Paths --- - print( - "\n Per-image comparison of interpolated paths (Max Abs Error | Mean Abs Error):" + + +def relative_energies_torch(state: ts.SimState, model: ModelInterface) -> np.ndarray: + energies = model(state)["energy"].detach().cpu().numpy() + return energies - energies[0] + + +def run_torch_sim_neb( + initial_state: ts.SimState, + final_state: ts.SimState, + model: ModelInterface, + *, + n_images: int, + spring_constant: float, + max_steps: int, + fmax: float, +) -> tuple[ts.SimState, list[np.ndarray], list[float]]: + movable_images = interpolate_path(initial_state, final_state, n_images) + endpoint_output = model( + ts.concatenate_states([as_sim_state(initial_state), as_sim_state(final_state)]) ) - overall_max_diff_interp = 0.0 - if ase_positions.shape != ts_positions_np.shape: - print(" Error: Shapes of ASE and torch-sim interpolated paths do not match.") - 
return - - for i in range(n_total_images): - diff_image_i = np.abs(ase_positions[i] - ts_positions_np[i]) - max_ae_i = np.max(diff_image_i) - mae_i = np.mean(diff_image_i) - print(f" Image {i}: MaxAE = {max_ae_i:.6f} | MAE = {mae_i:.6f}") - overall_max_diff_interp = max(overall_max_diff_interp, max_ae_i) - - are_close_interp = np.allclose(ase_positions, ts_positions_np, rtol=1e-5, atol=1e-6) - - if are_close_interp: - print(" Overall: Interpolated paths are numerically close.") - else: - print(" Overall: Interpolated paths differ numerically.") - print( - f" Overall Maximum absolute difference (Interpolated): {overall_max_diff_interp:.6f}" - ) - - -def ase_neb(start_atoms, end_atoms, nimages=5): - device = "cuda" if torch.cuda.is_available() else "cpu" - images = [start_atoms.copy() for _ in range(nimages + 1)] - images.append(end_atoms.copy()) - - neb_calc = ASENEB(images, climb=True, method="improvedtangent") - neb_calc.interpolate(mic=True) - - ase_dtype_str = "float64" if torch_sim_dtype == torch.float64 else "float32" - print( - f"Attaching independent ASE calculators with dtype: {ase_dtype_str} to all images" + endpoint_kwargs = { + "initial_state": as_sim_state(initial_state), + "final_state": as_sim_state(final_state), + "initial_energy": endpoint_output["energy"][0], + "final_energy": endpoint_output["energy"][1], + "spring_constant": spring_constant, + "use_climbing_image": True, + } + energy_history: list[np.ndarray] = [] + max_force_history: list[float] = [] + + def record(state: ts.SimState) -> None: + full_path = assemble_path(initial_state, state, final_state) + energy_history.append(relative_energies_torch(full_path, model)) + max_force_history.append(float(torch.linalg.norm(state.forces, dim=1).max())) + + def convergence(state: ts.OptimState, last_energy: torch.Tensor) -> torch.Tensor: + record(state) + return neb_convergence_fn(state, last_energy, fmax=fmax) + + initial_opt_state = neb_init( + movable_images, + model, + **endpoint_kwargs, + 
base_init_fn=fire_init, + base_init_kwargs={"fire_flavor": "ase_fire"}, ) - for image in neb_calc.images: - image.calc = mace_mp( - model=MaceUrls.mace_mpa_medium, - device=device, - default_dtype=ase_dtype_str, - dispersion=False, - ) - - # Set up trajectory logging for the reference ASE run (Commented out as not used for plot) - # ase_traj_filename = "ase_ref_neb.traj" - opt = FIRE(neb_calc) - # opt.attach(traj) # Attach the trajectory logger - - # Run the ASE optimization (essential) - print("Running ASE NEB optimization...") - opt.run(fmax=0.05, steps=1000) - print("Finished ASE NEB optimization.") - - return neb_calc # Only return the final NEB object - - -def relax_atoms( - atoms, - fmax=0.05, - steps=1000, - device=torch_sim_device, - dtype=torch_sim_dtype, -): - new_atoms = atoms.copy() - ase_dtype_str = "float64" if dtype == torch.float64 else "float32" - new_atoms.calc = mace_mp( - model=MaceUrls.mace_mpa_medium, - device=str(device), - default_dtype=ase_dtype_str, - dispersion=False, + record(initial_opt_state) + + final_movable = ts.optimize( + movable_images, + model, + optimizer=(neb_init, neb_step), + convergence_fn=convergence, + max_steps=max_steps, + steps_between_swaps=1, + autobatcher=False, + init_kwargs={ + **endpoint_kwargs, + "base_init_fn": fire_init, + "base_init_kwargs": {"fire_flavor": "ase_fire"}, + }, + **endpoint_kwargs, + base_step_fn=fire_step, + base_step_kwargs={"fire_flavor": "ase_fire"}, ) - opt = FIRE(new_atoms) - opt.run(fmax=fmax, steps=steps) - return new_atoms - - -# Create the torch_sim wrapper -ts_mace_model = MaceModel( - model=mace_potential, - device=torch_sim_device, - dtype=torch_sim_dtype, - compute_forces=True, # Default, but good to be explicit - compute_stress=True, # Needed by interface if we want stress later - enable_cueq=False, + final_path = assemble_path(initial_state, final_movable, final_state) + return final_path, energy_history, max_force_history + + +def run_ase_neb( + initial_atoms: Atoms, + 
final_atoms: Atoms, + *, + params: CurvedDoubleWellParams, + n_images: int, + spring_constant: float, + max_steps: int, + fmax: float, +) -> tuple[list[Atoms], list[np.ndarray], list[float]]: + images = [initial_atoms.copy()] + images.extend(initial_atoms.copy() for _ in range(n_images)) + images.append(final_atoms.copy()) + for image in images: + image.calc = ASECurvedDoubleWellCalculator(params) + + neb = ASENEB(images, k=spring_constant, climb=True, method="improvedtangent") + neb.interpolate(mic=True) + optimizer = FIRE(neb, logfile=None) + energy_history: list[np.ndarray] = [] + max_force_history: list[float] = [] + + def record() -> None: + energies = np.array([image.get_potential_energy() for image in images]) + energy_history.append(energies - energies[0]) + forces = neb.get_forces().reshape(-1, 3) + max_force_history.append(float(np.linalg.norm(forces, axis=1).max())) + + optimizer.attach(record, interval=1) + optimizer.run(fmax=fmax, steps=max_steps) + return images, energy_history, max_force_history + + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +params = CurvedDoubleWellParams() +n_images = 7 +spring_constant = 0.1 +max_steps = 200 +fmax = 0.03 + +initial_state = make_state((-1.0, 0.0, 0.0), device=device) +final_state = make_state((1.0, 0.0, 0.0), device=device) +model = TorchCurvedDoubleWellModel(device=device, dtype=torch.float64, params=params) + +initial_atoms = Atoms("Ar", positions=[[-1.0, 0.0, 0.0]], cell=np.eye(3) * 10.0) +final_atoms = Atoms("Ar", positions=[[1.0, 0.0, 0.0]], cell=np.eye(3) * 10.0) + +torch_path, torch_energy_history, torch_fmax = run_torch_sim_neb( + initial_state, + final_state, + model, + n_images=n_images, + spring_constant=spring_constant, + max_steps=max_steps, + fmax=fmax, ) - -# initial_trajectory = read('/home/myless/Packages/forge/scratch/data/neb_workflow_data/Cr7Ti8V104W8Zr_Cr_to_V_site102_to_69_initial.xyz', index=':') -# print(len(initial_trajectory)) - -# Create simple test structures 
for demonstration -# Using bulk structures instead of file paths -# Create simple test structures (can be replaced with file reads if needed) -start_atoms = bulk("Al", "fcc", a=4.05, cubic=True).repeat((2, 2, 2)) -end_atoms = bulk("Al", "fcc", a=4.05, cubic=True).repeat((2, 2, 2)) -# Add a small displacement to create a path -end_atoms.positions[0] += [0.1, 0.1, 0.1] - -relaxed_start_atoms = relax_atoms(start_atoms) -relaxed_end_atoms = relax_atoms(end_atoms) - -traj_file_name = "neb_path_torchsim_fire_5im.hdf5" - -# --- Setup ASE NEB for comparison --- -n_intermediate_images_ase = 5 -ase_images_compare = [relaxed_start_atoms.copy()] -ase_images_compare.extend( - [relaxed_start_atoms.copy() for _ in range(n_intermediate_images_ase)] -) -ase_images_compare.append(relaxed_end_atoms.copy()) - -ase_neb_compare = ASENEB( - ase_images_compare, - k=0.1, # Match torch-sim spring constant - climb=True, # Match torch-sim setting - method="improvedtangent", # Match torch-sim tangent method -) -ase_neb_compare.interpolate(mic=True) # Initial interpolation - -device = "cuda" if torch.cuda.is_available() else "cpu" -# Attach calculator to ALL ASE images using mace_mp -ase_dtype_str_compare = "float64" if torch_sim_dtype == torch.float64 else "float32" -print(f"Using ASE comparison calculator dtype: {ase_dtype_str_compare}") -ase_calculator = mace_mp( - model=MaceUrls.mace_mpa_medium, - device=device, - default_dtype=ase_dtype_str_compare, - dispersion=False, -) -for img in ase_neb_compare.images: - img.calc = ase_calculator -# ---------------------------------- - -initial_system = ts.io.atoms_to_state( - relaxed_start_atoms.copy(), device=torch_sim_device, dtype=torch_sim_dtype -) -final_system = ts.io.atoms_to_state( - relaxed_end_atoms.copy(), device=torch_sim_device, dtype=torch_sim_dtype -) - -neb_workflow = TorchNEB( - model=ts_mace_model, - device=torch_sim_device, - dtype=torch_sim_dtype, - spring_constant=0.1, - n_images=5, - use_climbing_image=True, # Set as desired for 
the actual run - optimizer_type="ase_fire", # Set as desired for the actual run - optimizer_params={}, - trajectory_filename=traj_file_name, -) - -compare_initial_paths( - relaxed_start_atoms, relaxed_end_atoms, initial_system, final_system, neb_workflow +ase_images, ase_energy_history, ase_fmax = run_ase_neb( + initial_atoms, + final_atoms, + params=params, + n_images=n_images, + spring_constant=spring_constant, + max_steps=max_steps, + fmax=fmax, ) +torch_final = relative_energies_torch(torch_path, model) +ase_final = ase_energy_history[-1] +reaction_coordinate = np.linspace(0.0, 1.0, n_images + 2) -# --- Add Function for Manual ASE Force Calculation --- -def calculate_ase_neb_force_step0( - ase_neb_calc: ASENEB, - image_index: int, - neb_workflow: TorchNEB, - output_filename="ase_step0_debug.json", +print("Final relative energies (eV)") +print("image torch-sim ASE abs diff") +for idx, (torch_energy, ase_energy) in enumerate( + zip(torch_final, ase_final, strict=True) ): - """Manually calculates the ASE NEB force components for a specific - intermediate image at step 0 (after initial interpolation) and saves - the results to a JSON file. - Uses the ImprovedTangent method for consistency with torch-sim default. - """ - print(f"--- Calculating ASE NEB Debug Info (Step 0, Image Index {image_index}) ---") - debug_data = { - "step": 0, - "image_index_intermediate": image_index - 1, # 0-based index among intermediates - "image_index_absolute": image_index, # 0-based index in full list - "inputs": {}, - "outputs": {}, - "error": None, - } - - n_images = ase_neb_calc.nimages # Total number of images including endpoints - if not (0 < image_index < n_images - 1): - error_msg = f"Error: image_index {image_index} is not an intermediate image." - print(error_msg) - debug_data["error"] = error_msg - with open(output_filename, "w") as f: - json.dump(debug_data, f, indent=2, cls=MontyEncoder) # Use MontyEncoder - return - - # 1. 
Get initial energies and forces after interpolation + calculator attachment - try: - initial_energies_np = np.array( - [img.get_potential_energy() for img in ase_neb_calc.images] - ) - initial_forces_np = np.stack([img.get_forces() for img in ase_neb_calc.images]) - - # No need for .tolist() with MontyEncoder - debug_data["inputs"]["energies_all"] = initial_energies_np - debug_data["inputs"]["true_forces_image"] = initial_forces_np[image_index] - debug_data["inputs"]["positions_image_minus_1"] = ase_neb_calc.images[ - image_index - 1 - ].get_positions() - debug_data["inputs"]["positions_image"] = ase_neb_calc.images[ - image_index - ].get_positions() - debug_data["inputs"]["positions_image_plus_1"] = ase_neb_calc.images[ - image_index + 1 - ].get_positions() - debug_data["inputs"]["cell"] = ( - ase_neb_calc.images[image_index].get_cell().tolist() - ) - # No need for bool() conversion with MontyEncoder - debug_data["inputs"]["pbc"] = ase_neb_calc.images[image_index].get_pbc() - - except Exception as e: - error_msg = f"Error getting initial energies/forces from ASE images: {e}" - print(error_msg) - debug_data["error"] = error_msg - import traceback - - debug_data["traceback"] = traceback.format_exc() - with open(output_filename, "w") as f: - json.dump(debug_data, f, indent=2, cls=MontyEncoder) # Use MontyEncoder - return - - # 2. Setup NEB state and method objects - ase_neb_obj_for_state = ASENEB( - ase_neb_calc.images, - k=neb_workflow.spring_constant, - climb=neb_workflow.use_climbing_image, - method="improvedtangent", - ) - neb_state = NEBState(ase_neb_obj_for_state, ase_neb_calc.images, initial_energies_np) - tangent_method = ImprovedTangentMethod(ase_neb_obj_for_state) - - # 3. 
Calculate components for the target image_index - try: - spring1 = neb_state.spring(image_index - 1) - spring2 = neb_state.spring(image_index) - # No .tolist() needed - debug_data["outputs"]["mic_displacement_1"] = spring1.t - debug_data["outputs"]["mic_displacement_2"] = spring2.t - - # Calculate tangent - tangent_ase = tangent_method.get_tangent(neb_state, spring1, spring2, image_index) - tangent_norm_ase = np.linalg.norm(tangent_ase) - if tangent_norm_ase > 1e-15: - tangent_ase_normalized = tangent_ase / tangent_norm_ase - else: - tangent_ase_normalized = tangent_ase # Keep as zero vector - tangent_norm_final = np.linalg.norm(tangent_ase_normalized) - - # No .tolist() needed - debug_data["outputs"]["tangent_vector"] = tangent_ase_normalized - debug_data["outputs"]["tangent_norm"] = tangent_norm_final - - # Calculate perpendicular force - true_force_img = initial_forces_np[image_index] - f_true_dot_tau_ase = np.vdot(true_force_img, tangent_ase_normalized) - f_perp_ase = true_force_img - f_true_dot_tau_ase * tangent_ase_normalized - f_perp_norm = np.linalg.norm(f_perp_ase) - - # No .tolist() needed - debug_data["outputs"]["f_true_dot_tau"] = f_true_dot_tau_ase - debug_data["outputs"]["f_perp_vector"] = f_perp_ase - debug_data["outputs"]["f_perp_norm"] = f_perp_norm - - # Calculate parallel spring force - segment_lengths_all = [neb_state.spring(i).nt for i in range(n_images - 1)] - spring_mag_term = spring2.nt * spring2.k - spring1.nt * spring1.k - f_spring_par_ase = spring_mag_term * tangent_ase_normalized - f_spring_par_norm = np.linalg.norm(f_spring_par_ase) - - # No .tolist() needed - debug_data["outputs"]["segment_lengths"] = segment_lengths_all - debug_data["outputs"]["spring_force_magnitude_term"] = spring_mag_term - debug_data["outputs"]["f_spring_par_vector"] = f_spring_par_ase - debug_data["outputs"]["f_spring_par_norm"] = f_spring_par_norm - - # Calculate total NEB force (before potential climbing modification) - neb_force_ase = f_perp_ase + 
f_spring_par_ase - # Explicitly convert to numpy array before saving, remove .tolist() - debug_data["outputs"]["neb_force_before_climb_vector"] = np.array(neb_force_ase) - debug_data["outputs"]["neb_force_before_climb_norm"] = np.linalg.norm( - neb_force_ase - ) - - # --- Direct Debug Prints for Step 0 --- - print("\n --- DIRECT DEBUG PRINT (ASE STEP 0) ---") - print(f" f_perp_norm: {f_perp_norm}") - print(f" f_perp_vec[0]: {f_perp_ase[0]}") - print(f" spring1_length (R[{image_index}]-R[{image_index - 1}]): {spring1.nt}") - print(f" spring2_length (R[{image_index + 1}]-R[{image_index}]): {spring2.nt}") - print(f" Length Diff (spring2.nt - spring1.nt): {spring2.nt - spring1.nt}") - print(f" f_spring_par_norm: {f_spring_par_norm}") - print(f" f_spring_par_vec[0]: {f_spring_par_ase[0]}") - print(f" neb_force_before_climb_norm: {np.linalg.norm(neb_force_ase)}") - print(" ------------------------------------") - # -------------------------------------- - - # Handle climbing image modification - is_climbing = ase_neb_obj_for_state.climb and image_index == neb_state.imax - debug_data["outputs"]["is_climbing_image"] = is_climbing - debug_data["outputs"]["imax"] = int( - neb_state.imax - ) # Ensure imax is JSON serializable - - if is_climbing: - climbing_force_ase = ( - true_force_img - 2 * f_true_dot_tau_ase * tangent_ase_normalized - ) - climbing_force_norm = np.linalg.norm(climbing_force_ase) - # No .tolist() needed - debug_data["outputs"]["climbing_force_vector"] = climbing_force_ase - debug_data["outputs"]["climbing_force_norm"] = climbing_force_norm - final_force_ase = climbing_force_ase - else: - final_force_ase = neb_force_ase - - # No .tolist() needed - debug_data["outputs"]["final_neb_force_vector"] = final_force_ase - debug_data["outputs"]["final_neb_force_norm"] = np.linalg.norm(final_force_ase) - - except Exception as e: - error_msg = ( - f"Error during manual ASE force calculation for image {image_index}: {e}" - ) - print(error_msg) - debug_data["error"] = 
error_msg - import traceback - - debug_data["traceback"] = traceback.format_exc() - - # Write data to JSON - try: - with open(output_filename, "w") as f: - json.dump(debug_data, f, indent=2, cls=MontyEncoder) # Use MontyEncoder - print(f"--- ASE NEB Debug Info saved to {output_filename} ---") - except Exception as e: - print(f"Error writing ASE debug info to JSON: {e}") - - -# --- Add Function for Comparing JSON/Pickle Outputs to debug the tangent force calculation --- -def compare_step0_outputs( - file_ase="ase_step0_debug.json", - file_ts="torchsim_step0_debug.pkl", - rtol=1e-5, - atol=1e-6, -): - print("\n--- Comparing Step 0 Debug Outputs (ASE JSON vs TorchSim Pickle) --- ") - try: - # Load ASE data from JSON - with open(file_ase) as f: - data_ase = json.load(f, cls=MontyDecoder) - # Load TorchSim data from Pickle - with open(file_ts, "rb") as f: # Use 'rb' for pickle - data_ts = pickle.load(f) - except FileNotFoundError as e: - print(f"Error: Could not find file {e.filename}") - return - except Exception as e: - print(f"Error loading JSON/Pickle files: {e}") - return - - # Basic checks - if data_ase.get("error") or data_ts.get("error"): - print("Comparison aborted due to error during data generation.") - print(f" ASE Error: {data_ase.get('error')}") - print(f" TS Error: {data_ts.get('error')}") - return - - if data_ase.get("step") != 0 or data_ts.get("step") != 0: - print("Warning: One or both files do not contain step 0 data.") - # Continue comparison anyway - - if data_ase.get("image_index_intermediate") != data_ts.get( - "image_index_intermediate" - ): - print("Warning: JSON files are for different intermediate image indices.") - # Continue comparison anyway - - outputs_ase = data_ase.get("outputs", {}) - outputs_ts = data_ts.get("outputs", {}) - - all_keys = set(outputs_ase.keys()) | set(outputs_ts.keys()) - mismatches = 0 - print( - f"Comparing fields for intermediate image index: {data_ase.get('image_index_intermediate', 'N/A')}" + f"{idx:5d} 
{torch_energy: .8f} {ase_energy: .8f} " + f"{abs(torch_energy - ase_energy):.3e}" ) - - for key in sorted(list(all_keys)): - val_ase = outputs_ase.get(key) - val_ts = outputs_ts.get(key) - - if key not in outputs_ts: - print(f" - Key '{key}': Present in ASE, Missing in TorchSim") - mismatches += 1 - continue - if key not in outputs_ase: - print(f" - Key '{key}': Missing in ASE, Present in TorchSim") - mismatches += 1 - continue - - # --- Handle Type Conversion for Comparison --- - val_ase_comp = val_ase - val_ts_comp = val_ts - - # Convert torch tensor from pickle to numpy/scalar for comparison - if isinstance(val_ts_comp, torch.Tensor): - if val_ts_comp.ndim == 0: # Scalar tensor - val_ts_comp = val_ts_comp.item() - else: - val_ts_comp = val_ts_comp.detach().cpu().numpy() # Use detach() - # -------------------------------------------- - - # --- Debug Print for Specific Key --- - if key == "neb_force_before_climb_vector": - print( - f" DEBUG compare [{key}]: ASE[0]={np.array(val_ase_comp)[0]}, TS[0]={np.array(val_ts_comp)[0]}" - ) - # ------------------------------------ - - # --- Special Handling for imax index --- - if key == "imax": - # ASE imax is index in full list (1 to n_images-1) - # TS imax is index in intermediates (0 to n_images-2) - # Compare ASE imax with TS imax + 1 - ase_imax = int(val_ase_comp) - ts_imax_plus_1 = int(val_ts_comp) + 1 - match = ase_imax == ts_imax_plus_1 - if not match: - difference_info = f"ASE imax={ase_imax}, TS imax(adj)={ts_imax_plus_1}" - status = "Match" if match else "DIFFER" - print(f" - Key '{key:<30}': {status} {difference_info}") - if not match: - mismatches += 1 - continue # Skip rest of comparison for imax - # ------------------------------------- - - # Try numerical comparison first - match = False - difference_info = "" - try: - # Ensure they are numpy arrays for consistent comparison - # ASE data might already be numpy or list, TS data was converted above - arr_ase = np.array(val_ase_comp) - arr_ts = 
np.array(val_ts_comp) - - if arr_ase.shape != arr_ts.shape: - match = False - difference_info = f"Shapes differ: ASE={arr_ase.shape}, TS={arr_ts.shape}" - elif np.issubdtype(arr_ase.dtype, np.number) and np.issubdtype( - arr_ts.dtype, np.number - ): - match = np.allclose(arr_ase, arr_ts, rtol=rtol, atol=atol) - if not match: - max_abs_diff = np.max(np.abs(arr_ase - arr_ts)) - difference_info = f"Max abs diff: {max_abs_diff:.6e}" - elif arr_ase.dtype == np.bool_ and arr_ts.dtype == np.bool_: - match = np.array_equal(arr_ase, arr_ts) - if not match: - difference_info = f"Boolean values differ: ASE={arr_ase}, TS={arr_ts}" - else: # Fallback for other types (e.g., strings if they were arrays) - match = np.array_equal(arr_ase, arr_ts) - if not match: - difference_info = "Non-numerical array values differ" - - except (TypeError, ValueError): - # Fallback to direct comparison for non-array types or incompatible arrays - try: - if isinstance(val_ase_comp, (float, int)) and isinstance( - val_ts_comp, (float, int) - ): - match = np.isclose(val_ase_comp, val_ts_comp, rtol=rtol, atol=atol) - if not match: - difference_info = f"Diff: {abs(val_ase_comp - val_ts_comp):.6e}" - elif type(val_ase_comp) == type(val_ts_comp): - match = val_ase_comp == val_ts_comp - if not match: - difference_info = ( - f"Values differ: ASE='{val_ase_comp}', TS='{val_ts_comp}'" - ) - else: - # Types should ideally match after conversion, but check just in case - match = False - difference_info = f"Types differ after conversion: ASE={type(val_ase_comp)}, TS={type(val_ts_comp)}" - except Exception: - match = False - - status = "Match" if match else "DIFFER" # Pad DIFFER for alignment - print(f" - Key '{key:<30}': {status} {difference_info}") - if not match: - mismatches += 1 - - if mismatches == 0: - print("\nAll compared output fields match.") - else: - print(f"\nFound {mismatches} mismatch(es) in output fields.") - print("--- End Comparison --- ") - - -# 
------------------------------------------------- - - -# --- Add Function to Print Pickle Structure --- -def print_pickle_structure(filename="torchsim_step0_debug.pkl"): - print(f"\n--- Structure of Pickle File: {filename} --- ") - try: - with open(filename, "rb") as f: - data = pickle.load(f) - except FileNotFoundError: - print(f"Error: File not found: {filename}") - return - except Exception as e: - print(f"Error loading pickle file: {e}") - return - - if not isinstance(data, dict): - print(f"Loaded data is not a dictionary (Type: {type(data)})") - return - - print(f"Keys: {list(data.keys())}") - for key, value in data.items(): - if isinstance(value, dict): - print(f" {key}:") - for subkey, subvalue in value.items(): - val_type = type(subvalue) - val_shape = getattr(subvalue, "shape", "N/A") - # Add dtype for tensors - val_dtype = getattr(subvalue, "dtype", "N/A") - print( - f" - {subkey:<30}: Type={val_type}, Shape={val_shape}, Dtype={val_dtype}" - ) - else: - val_type = type(value) - val_shape = getattr(value, "shape", "N/A") - val_dtype = getattr(value, "dtype", "N/A") - print(f" {key:<32}: Type={val_type}, Shape={val_shape}, Dtype={val_dtype}") - print("--- End Pickle Structure --- ") - - -# -------------------------------------------- - -# --- Perform manual ASE force calculation for step 0 --- -debug_ase_img_index = ( - n_intermediate_images_ase // 2 + 1 -) # Index in the full list (0 to n_images+1) -calculate_ase_neb_force_step0(ase_neb_compare, debug_ase_img_index, neb_workflow) -# ------------------------------------------------------ - -print("\nStarting torch-sim NEB optimization...") -final_path_gd = neb_workflow.run( - initial_system=initial_system, - final_system=final_system, - max_steps=100, # Keep increased steps for now - fmax=0.05, -) -print("Finished torch-sim NEB optimization.") - -# Check if it converged and plot results -results = ts_mace_model(final_path_gd) - -energies = results["energy"].tolist() - -# Including the energies from the ASE 
NEB calculation for comparison -# ase_energies = [0.0, 0.154541015625, 0.6151123046875, 0.8592529296875, 0.8148193359375, 0.5965576171875, 0.47705078125] - -ase_neb_calc = ase_neb(relaxed_start_atoms, relaxed_end_atoms, nimages=5) -ase_energies = [image.get_potential_energy() for image in ase_neb_calc.images] -scaled_ase_energies = [e - ase_energies[0] for e in ase_energies] - - -scaled_energies = [e - energies[0] for e in energies] - -print(scaled_energies) -torch_sim_barrier = max(scaled_energies) - scaled_energies[0] -ase_barrier = max(scaled_ase_energies) - scaled_ase_energies[0] - -# Create normalized reaction coordinates (0 to 1) for both datasets -torch_sim_coords = np.linspace(0, 1, len(scaled_energies)) -ase_coords = np.linspace(0, 1, len(scaled_ase_energies)) - -# Create a common x-axis with 100 points for smoother plotting -common_coords = np.linspace(0, 1, 100) - -# Interpolate both energy profiles to the common coordinate system -torch_sim_interp = np.interp(common_coords, torch_sim_coords, scaled_energies) -ase_interp = np.interp(common_coords, ase_coords, scaled_ase_energies) - -# --- Print Pickle Structure to Verify --- -# print_pickle_structure() -# ------------------------------------- - -# --- Compare Step 0 Debug Outputs for compute_tangent at step 0--- -# compare_step0_outputs() # Use the updated function name -# ------------------------------------ - - -# --- Plot the energy profiles --- -plt.plot(common_coords, torch_sim_interp, label="torch-sim") -plt.plot(common_coords, ase_interp, label="ASE") -plt.xlabel("Reaction Coordinate") -plt.ylabel("Energy (eV)") -plt.title( - f"ASE Barrier = {ase_barrier:.4f} eV, torch-sim Barrier = {torch_sim_barrier:.4f} eV, Difference = {torch_sim_barrier - ase_barrier:.4f} eV" -) -plt.legend() -plt.show() -# ------------------------------------ - - -# --- Function to Inspect HDF5 File Structure --- -def inspect_hdf5(filename): - print(f"\n--- Inspecting HDF5 File: {filename} ---") - try: - with 
h5py.File(filename, "r") as f: - - def print_attrs(name, obj): - print(f" Path: /{name}") - if isinstance(obj, h5py.Dataset): - print(" Type: Dataset") - print(f" Shape: {obj.shape}") - print(f" Dtype: {obj.dtype}") - # Optionally print a small slice of data - # try: - # print(f" Data sample: {obj[0:min(2, obj.shape[0])]}") - # except Exception as e: - # print(f" Could not read data sample: {e}") - elif isinstance(obj, h5py.Group): - print(" Type: Group") - print(f" Attributes: {dict(obj.attrs)}") - - f.visititems(print_attrs) - except FileNotFoundError: - print(f"Error: File not found: {filename}") - except Exception as e: - print(f"Error inspecting HDF5 file: {e}") - print("--- End HDF5 Inspection ---") - - -# ---------------------------------------------- - - -# --- Analyze Optimizer Convergence --- -def analyze_convergence(ts_traj_file, ase_fmax_csv_file): - print("\n--- Analyzing Optimizer Convergence ---") - max_force_ts = [] - max_force_ase = [] - - # Analyze torch-sim trajectory - try: - with h5py.File(ts_traj_file, "r") as f: - if "data/neb_forces" not in f or "data/image_indices" not in f: - raise ValueError( - "HDF5 file missing '/data/neb_forces' or '/data/image_indices' datasets." - ) - - # Data is under /data group, steps are the first dimension - neb_forces_dset = f["/data/neb_forces"] - image_indices_dset = f["/data/image_indices"] - - n_steps = neb_forces_dset.shape[0] - # Read static image indices (take the first slice) - image_indices = image_indices_dset[0, :] - - # Infer dimensions - n_images_total = len(np.unique(image_indices)) - n_atoms_total = len(image_indices) - if neb_forces_dset.shape[1] != n_atoms_total: - raise ValueError( - f"Mismatch between image_indices length ({n_atoms_total}) and neb_forces second dimension ({neb_forces_dset.shape[1]})" - ) - - n_atoms_per_image = n_atoms_total // n_images_total - print( - f"TorchSim Traj: {n_steps} steps, {n_images_total} total images, {n_atoms_per_image} atoms/image." 
- ) - - for step in range(n_steps): - # Access forces for the current step from the first dimension - neb_forces = torch.from_numpy(neb_forces_dset[step, :, :]) - - # Select forces only for intermediate images (index 1 to n_images_total - 2) - intermediate_mask = (image_indices > 0) & ( - image_indices < n_images_total - 1 - ) - forces_intermediate = neb_forces[intermediate_mask] - if forces_intermediate.numel() > 0: - max_comp = torch.max(torch.abs(forces_intermediate)).item() - max_force_ts.append(max_comp) - else: - max_force_ts.append(0.0) # Or handle error/empty case - - except Exception as e: - print(f"Error reading torch-sim trajectory {ts_traj_file}: {e}") - - # Read ASE fmax data from CSV - try: - # Use numpy.loadtxt to read the 2nd column (index 1) from the CSV - # Assuming tab delimiter, skipping header row - ase_data = np.loadtxt(ase_fmax_csv_file, delimiter="\t", skiprows=1, usecols=(1,)) - max_force_ase = ase_data.tolist() # Convert numpy array to list - print(f"Read {len(max_force_ase)} fmax values from {ase_fmax_csv_file}") - except Exception as e: - print(f"Error reading ASE fmax CSV file {ase_fmax_csv_file}: {e}") - - # Plotting - if max_force_ts or max_force_ase: - plt.figure() - if max_force_ts: - plt.plot( - range(len(max_force_ts)), - max_force_ts, - label="torch-sim (ase_fire)", - marker=".", - ) - if max_force_ase: - plt.plot( - range(len(max_force_ase)), max_force_ase, label="ASE (FIRE)", marker="." - ) - plt.xlabel("Optimization Step") - plt.ylabel("Max Abs Force Component (eV/Ang)") - plt.title("Optimizer Convergence Comparison") - plt.legend() - plt.grid(True) - plt.yscale("log") # Log scale often helpful for forces - plt.show() - else: - print("No force data extracted to plot convergence.") - - -# inspect_hdf5(traj_file_name) -analyze_convergence(traj_file_name, "ase_fmax_convergence.csv") -# --------------------------------- - -# --- Debugging Functions (Keep for reference) --- -# def calculate_ase_neb_force_step0(...): ... 
-# def compare_step0_outputs(...): ... -# def print_pickle_structure(...): ... - - -# --- Call Step 0 Debug Functions (Commented out) --- -# # Perform manual ASE force calculation for step 0 -debug_ase_img_index = n_intermediate_images_ase // 2 + 1 -calculate_ase_neb_force_step0(ase_neb_compare, debug_ase_img_index, neb_workflow) - -# # Print Pickle Structure to Verify -print_pickle_structure() - -# # Compare Step 0 Debug Outputs -compare_step0_outputs() -# -------------------------------------------------- +print(f"Barrier difference: {abs(torch_final.max() - ase_final.max()):.3e} eV") + +common_steps = min(len(torch_fmax), len(ase_fmax)) +step_axis = np.arange(common_steps) +final_energy_residual = torch_final - ase_final +force_residual = np.array(torch_fmax[:common_steps]) - np.array(ase_fmax[:common_steps]) + +fig, axes = plt.subplots(2, 2, figsize=(10, 7)) +axes[0, 0].plot(reaction_coordinate, torch_final, "o-", label="torch-sim") +axes[0, 0].plot(reaction_coordinate, ase_final, "s--", label="ASE") +axes[0, 0].set_ylabel("Relative energy") +axes[0, 0].set_title("Final NEB profile") +axes[0, 0].legend() + +axes[0, 1].plot(torch_fmax, label="torch-sim") +axes[0, 1].plot(ase_fmax, label="ASE") +axes[0, 1].axhline(fmax, color="k", linestyle=":", label="fmax") +axes[0, 1].set_ylabel("Max NEB force") +axes[0, 1].set_yscale("log") +axes[0, 1].set_title("Convergence") +axes[0, 1].legend() + +axes[1, 0].axhline(0.0, color="k", linewidth=0.8) +axes[1, 0].plot(reaction_coordinate, final_energy_residual, "o-") +axes[1, 0].set_xlabel("Reaction coordinate") +axes[1, 0].set_ylabel("TS - ASE") +axes[1, 0].set_title("Final energy residual") + +axes[1, 1].axhline(0.0, color="k", linewidth=0.8) +axes[1, 1].plot(step_axis, force_residual) +axes[1, 1].set_xlabel("Optimization step") +axes[1, 1].set_ylabel("TS - ASE") +axes[1, 1].set_title("Max-force residual") + +fig.tight_layout() +fig.savefig("neb_ase_torchsim_comparison.png", dpi=200) +print("Saved comparison plot to 
neb_ase_torchsim_comparison.png") diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 3f419f688..78b37e151 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -162,6 +162,36 @@ def test_fire_optimization( ) +@pytest.mark.parametrize("fire_flavor", get_args(FireFlavor)) +def test_fire_uses_group_scoped_adaptive_state( + ar_double_sim_state: SimState, lj_model: ModelInterface, fire_flavor: FireFlavor +) -> None: + ar_double_sim_state.group_idx = torch.zeros( + ar_double_sim_state.n_systems, + device=ar_double_sim_state.device, + dtype=torch.int64, + ) + + state = ts.fire_init( + ar_double_sim_state, + lj_model, + fire_flavor=fire_flavor, + dt_start=0.1, + alpha_start=0.1, + ) + + assert state.n_groups == 1 + assert state.dt.shape == (1,) + assert state.alpha.shape == (1,) + assert state.n_pos.shape == (1,) + + updated = ts.fire_step(state=state, model=lj_model, dt_max=0.3) + + assert updated.dt.shape == (1,) + assert updated.alpha.shape == (1,) + assert updated.n_pos.shape == (1,) + + def test_bfgs_optimization( ar_supercell_sim_state: SimState, lj_model: ModelInterface ) -> None: diff --git a/tests/test_state.py b/tests/test_state.py index 7c73a6498..787892441 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -37,11 +37,36 @@ def test_get_attrs_for_scope(si_sim_state: SimState) -> None: per_atom_attrs = dict(get_attrs_for_scope(si_sim_state, "per-atom")) assert set(per_atom_attrs) == {"positions", "masses", "atomic_numbers", "system_idx"} per_system_attrs = dict(get_attrs_for_scope(si_sim_state, "per-system")) - assert set(per_system_attrs) == {"cell"} + assert set(per_system_attrs) == {"cell", "group_idx"} + per_group_attrs = dict(get_attrs_for_scope(si_sim_state, "per-group")) + assert set(per_group_attrs) == set() global_attrs = dict(get_attrs_for_scope(si_sim_state, "global")) assert set(global_attrs) == {"pbc", "_rng"} +def test_group_idx_defaults_to_one_group_per_system( + si_double_sim_state: SimState, +) -> 
None: + assert torch.equal( + si_double_sim_state.group_idx, + torch.arange(si_double_sim_state.n_systems, device=si_double_sim_state.device), + ) + assert si_double_sim_state.n_groups == si_double_sim_state.n_systems + + +def test_slice_remaps_group_idx(si_double_sim_state: SimState) -> None: + state = si_double_sim_state.clone() + state.group_idx = torch.zeros(state.n_systems, device=state.device, dtype=torch.int64) + + sliced = _slice_state(state, [1, 0]) + + assert torch.equal( + sliced.group_idx, + torch.zeros(sliced.n_systems, device=sliced.device, dtype=torch.int64), + ) + assert sliced.n_groups == 1 + + def test_all_attributes_must_be_specified_in_scopes() -> None: """Test that an error is raised when we forget to specify the scope for an attribute in a child SimState class.""" diff --git a/tests/workflows/test_neb.py b/tests/workflows/test_neb.py new file mode 100644 index 000000000..e7db3e24e --- /dev/null +++ b/tests/workflows/test_neb.py @@ -0,0 +1,161 @@ +import numpy as np +import torch +from ase import Atoms +from ase.mep import NEB as ASENEB +from ase.mep.neb import ImprovedTangentMethod, NEBState + +import torch_sim as ts +from tests.conftest import DEVICE, DTYPE +from torch_sim.models.interface import ModelInterface +from torch_sim.workflows.neb import NEB, calculate_neb_forces, interpolate_path + + +class HarmonicModel(ModelInterface): + def __init__(self) -> None: + super().__init__() + self._device = DEVICE + self._dtype = DTYPE + self._compute_forces = True + self._compute_stress = True + + def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]: + del kwargs + per_atom_energy = 0.5 * (state.positions**2).sum(dim=1) + energy = torch.zeros(state.n_systems, device=state.device, dtype=state.dtype) + energy.scatter_add_(0, state.system_idx, per_atom_energy) + return { + "energy": energy, + "forces": -state.positions, + "stress": torch.zeros( + state.n_systems, 3, 3, device=state.device, dtype=state.dtype + ), + } + + 
+def _single_atom_state(position: float) -> ts.SimState: + return ts.SimState( + positions=torch.tensor([[position, 0.0, 0.0]], device=DEVICE, dtype=DTYPE), + masses=torch.ones(1, device=DEVICE, dtype=DTYPE), + cell=torch.eye(3, device=DEVICE, dtype=DTYPE).unsqueeze(0) * 10.0, + pbc=False, + atomic_numbers=torch.tensor([18], device=DEVICE), + system_idx=torch.zeros(1, device=DEVICE, dtype=torch.long), + ) + + +def test_interpolate_path_uses_movable_images_only() -> None: + initial = _single_atom_state(0.0) + final = _single_atom_state(1.0) + + path = interpolate_path(initial, final, n_images=3) + + assert path.n_systems == 3 + assert path.n_groups == 1 + assert torch.equal(path.group_idx, torch.zeros(3, device=DEVICE, dtype=torch.long)) + assert torch.allclose( + path.positions[:, 0], + torch.tensor([0.25, 0.5, 0.75], device=DEVICE, dtype=DTYPE), + ) + + +def test_calculate_neb_forces_matches_ase_step0_components() -> None: + n_images = 5 + n_atoms = 2 + spring_constant = 0.1 + positions = torch.tensor( + [ + [[0.0, 0.0, 0.0], [0.0, 0.8, 0.0]], + [[0.2, 0.1, 0.0], [0.1, 0.9, 0.0]], + [[0.5, 0.25, 0.0], [0.2, 1.0, 0.1]], + [[0.8, 0.35, 0.0], [0.25, 1.05, 0.2]], + [[1.0, 0.5, 0.0], [0.4, 1.2, 0.3]], + ], + device=DEVICE, + dtype=DTYPE, + ) + energies = torch.tensor([0.0, 0.2, 0.7, 0.4, 0.1], device=DEVICE, dtype=DTYPE) + true_forces = torch.tensor( + [ + [[0.1, -0.2, 0.0], [0.0, 0.3, -0.1]], + [[-0.2, 0.1, 0.2], [0.2, -0.1, 0.0]], + [[0.3, 0.0, -0.1], [-0.1, 0.2, 0.1]], + ], + device=DEVICE, + dtype=DTYPE, + ) + path_state = ts.SimState( + positions=positions.reshape(-1, 3), + masses=torch.ones(n_images * n_atoms, device=DEVICE, dtype=DTYPE), + cell=torch.eye(3, device=DEVICE, dtype=DTYPE).unsqueeze(0).repeat(n_images, 1, 1) + * 10.0, + pbc=False, + atomic_numbers=torch.tensor([18, 18], device=DEVICE).repeat(n_images), + system_idx=torch.repeat_interleave( + torch.arange(n_images, device=DEVICE), repeats=n_atoms + ), + ) + + torch_forces = calculate_neb_forces( + 
path_state, + true_forces.reshape(-1, 3), + energies[1:-1], + energies[0], + energies[-1], + spring_constant=spring_constant, + use_climbing_image=True, + ).reshape(n_images - 2, n_atoms, 3) + + ase_images = [ + Atoms( + "Ar2", + positions=image_positions.detach().cpu().numpy(), + cell=np.eye(3) * 10.0, + pbc=False, + ) + for image_positions in positions + ] + ase_neb = ASENEB(ase_images, k=spring_constant, climb=True, method="improvedtangent") + ase_state = NEBState(ase_neb, ase_neb.images, energies.detach().cpu().numpy()) + tangent_method = ImprovedTangentMethod(ase_neb) + ase_forces = [] + true_forces_np = true_forces.detach().cpu().numpy() + for image_index in range(1, n_images - 1): + spring1 = ase_state.spring(image_index - 1) + spring2 = ase_state.spring(image_index) + tangent = tangent_method.get_tangent(ase_state, spring1, spring2, image_index) + tangent_norm = np.linalg.norm(tangent) + if tangent_norm > 1e-15: + tangent = tangent / tangent_norm + force = true_forces_np[image_index - 1] + force_dot_tangent = np.vdot(force, tangent) + if ase_neb.climb and image_index == ase_state.imax: + ase_forces.append(force - 2 * force_dot_tangent * tangent) + else: + spring_force = (spring2.nt * spring2.k - spring1.nt * spring1.k) * tangent + ase_forces.append(force - force_dot_tangent * tangent + spring_force) + + assert torch.allclose( + torch_forces, + torch.tensor(np.array(ase_forces), device=DEVICE, dtype=DTYPE), + atol=1e-12, + rtol=1e-12, + ) + + +def test_neb_run_uses_single_chain_optimize_without_moving_endpoints() -> None: + initial = _single_atom_state(0.0) + final = _single_atom_state(1.0) + neb = NEB( + model=HarmonicModel(), + n_images=1, + optimizer_type="gd", + optimizer_params={"pos_lr": 0.1}, + ) + + result = neb.run(initial, final, max_steps=3, fmax=1e-12) + + assert result.n_systems == 3 + assert torch.allclose( + result.positions[:, 0], + torch.tensor([0.0, 0.5, 1.0], device=DEVICE, dtype=DTYPE), + ) diff --git a/torch_sim/autobatching.py 
b/torch_sim/autobatching.py index 9671ed973..2033fb509 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -23,7 +23,7 @@ import logging from collections.abc import Callable, Iterator, Sequence from itertools import chain -from typing import Any, get_args +from typing import Any, cast, get_args import torch @@ -110,6 +110,7 @@ def _rev_argsort_bins(lst: list[float]) -> list[int]: else: is_tuple_list = False + bins: list[Any] if isinstance(items, dict): # get keys and values (weights) keys = list(items) @@ -455,6 +456,28 @@ def calculate_memory_scalers( ) +def _split_autobatch_units[T: SimState](state: T) -> list[T]: + """Split a state into independent optimizer groups for autobatching.""" + if state.n_groups == state.n_systems and torch.equal( + state.group_idx, + torch.arange(state.n_systems, device=state.device, dtype=torch.int64), + ): + return state.split() + return [ + state[torch.where(state.group_idx == group_idx)[0]] + for group_idx in range(state.n_groups) + ] + + +def _unit_memory_scaler( + state: SimState, + memory_scales_with: MemoryScaling, + cutoff: float, +) -> float: + """Estimate memory for one autobatching unit, which may contain many systems.""" + return float(sum(calculate_memory_scalers(state, memory_scales_with, cutoff))) + + def estimate_max_memory_scaler( states: SimState | Sequence[SimState], model: ModelInterface, @@ -494,13 +517,14 @@ def estimate_max_memory_scaler( The returned value will be the minimum of the two estimates. 
""" metric_values = torch.tensor(metric_values) + units = _split_autobatch_units(states) if isinstance(states, SimState) else states # select one state with the min n_atoms min_metric = metric_values.min() max_metric = metric_values.max() - min_state = states[int(metric_values.argmin())] - max_state = states[int(metric_values.argmax())] + min_state = units[int(metric_values.argmin())] + max_state = units[int(metric_values.argmax())] print( # noqa: T201 "Model Memory Estimation: Estimating memory from worst case of " @@ -644,9 +668,11 @@ def load_states(self, states: T | Sequence[T]) -> float: batched = ( states if isinstance(states, SimState) else ts.concatenate_states(states) ) - self.memory_scalers = calculate_memory_scalers( - batched, self.memory_scales_with, self.cutoff - ) + units = _split_autobatch_units(batched) + self.memory_scalers = [ + _unit_memory_scaler(unit, self.memory_scales_with, self.cutoff) + for unit in units + ] if not self.max_memory_scaler: self.max_memory_scaler = estimate_max_memory_scaler( batched, @@ -676,12 +702,14 @@ def load_states(self, states: T | Sequence[T]) -> float: ) # list[dict[original_index: int, memory_scale:float]] # Convert to list of lists of indices self.index_bins = [list(batch.keys()) for batch in index_bins] - self.batched_states = [[batched[index_bin]] for index_bin in self.index_bins] + self.batched_states = [ + [units[idx] for idx in index_bin] for index_bin in self.index_bins + ] self.current_state_bin = 0 logger.info( "BinningAutoBatcher: %d systems → %d batch(es), max_memory_scaler=%.3g", - len(self.memory_scalers), + batched.n_systems, len(self.index_bins), self.max_memory_scaler, ) @@ -794,7 +822,7 @@ def restore_original_order(self, batched_states: Sequence[T]) -> list[T]: ordered_results = batcher.restore_original_order(results) """ - all_states = [state.split() for state in batched_states] + all_states = [_split_autobatch_units(state) for state in batched_states] all_states = 
list(chain.from_iterable(all_states)) original_indices = list(chain.from_iterable(self.index_bins)) @@ -950,9 +978,12 @@ def load_states(self, states: Sequence[T] | Iterator[T] | T) -> float | None: This method resets the current state indices and completed state tracking, so any ongoing processing will be restarted when this method is called. """ + state_units: Sequence[T] | Iterator[T] if isinstance(states, SimState): - states = states.split() - self.states_iterator = iter(states) + state_units = _split_autobatch_units(cast("T", states)) + else: + state_units = states + self.states_iterator = iter(state_units) self.current_scalers = [] self.current_idx = [] @@ -978,9 +1009,7 @@ def _get_next_states(self) -> list[T]: new_idx: list[int] = [] new_states: list[T] = [] for state in self.states_iterator: - metric = calculate_memory_scalers( - state, self.memory_scales_with, self.cutoff - )[0] + metric = _unit_memory_scaler(state, self.memory_scales_with, self.cutoff) if metric > self.max_memory_scaler: # ty: ignore[unsupported-operator] raise ValueError( f"State {metric=} is greater than max_metric {self.max_memory_scaler}" @@ -1036,7 +1065,9 @@ def _get_first_batch(self) -> T: # we need to sample a state and use it to estimate the max metric # for the first batch first_state = next(self.states_iterator) - first_metric = calculate_memory_scalers(first_state, self.memory_scales_with)[0] + first_metric = _unit_memory_scaler( + first_state, self.memory_scales_with, self.cutoff + ) self.current_scalers += [first_metric] self.current_idx += [0] self.iteration_count.append(0) # Initialize attempt counter for first state @@ -1158,9 +1189,21 @@ def next_batch( # noqa: C901 # Force convergence for states that have reached max attempts convergence_tensor[cur_idx] = torch.tensor(True) # noqa: FBT003 - completed_idx = torch.where(convergence_tensor)[0].tolist() - - completed_states = updated_state.pop(completed_idx) + completed_idx = [] + completed_system_indices = [] + for 
group_idx in range(updated_state.n_groups): + system_mask = updated_state.group_idx == group_idx + if convergence_tensor[system_mask].all(): + completed_idx.append(group_idx) + completed_system_indices.extend(torch.where(system_mask)[0].tolist()) + + completed_states = ( + _split_autobatch_units(updated_state[completed_system_indices]) + if completed_system_indices + else [] + ) + if completed_system_indices: + updated_state.pop(completed_system_indices) # necessary to ensure states that finish at the same time are ordered properly completed_states.reverse() diff --git a/torch_sim/optimizers/cell_filters.py b/torch_sim/optimizers/cell_filters.py index 8a58f8d40..b82701fad 100644 --- a/torch_sim/optimizers/cell_filters.py +++ b/torch_sim/optimizers/cell_filters.py @@ -462,9 +462,12 @@ class CellFireState(CellOptimState, FireState): _system_attributes = ( CellOptimState._system_attributes # noqa: SLF001 - | FireState._system_attributes # noqa: SLF001 | {"cell_velocities"} ) + _group_attributes = ( + CellOptimState._group_attributes # noqa: SLF001 + | FireState._group_attributes # noqa: SLF001 + ) @dataclass(kw_only=True) diff --git a/torch_sim/optimizers/fire.py b/torch_sim/optimizers/fire.py index b4e23b95e..aa02dbf7e 100644 --- a/torch_sim/optimizers/fire.py +++ b/torch_sim/optimizers/fire.py @@ -18,6 +18,13 @@ from torch_sim.optimizers.cell_filters import CellFilter, CellFilterFuncs +def _group_sum( + values: torch.Tensor, group_idx: torch.Tensor, n_groups: int +) -> torch.Tensor: + summed = torch.zeros(n_groups, device=values.device, dtype=values.dtype) + return summed.scatter_add(0, group_idx, values) + + @dcite("10.1103/PhysRevLett.97.170201") def fire_init( state: SimState, @@ -60,7 +67,7 @@ def fire_init( device: torch.device = model.device dtype: torch.dtype = model.dtype - n_systems = state.n_systems + n_groups = state.n_groups # Get initial forces and energy from model model_output = model(state) @@ -72,12 +79,20 @@ def fire_init( dt_start_t = 
torch.as_tensor(dt_start, device=device, dtype=dtype) if dt_start_t.ndim == 0: # NOTE: clone needed as this is overwritten/assigned later by a masked_fill - dt_start_t = dt_start_t.expand(n_systems).clone() + dt_start_t = dt_start_t.expand(n_groups).clone() + elif dt_start_t.numel() != n_groups: + raise ValueError( + f"dt_start must have {n_groups} values, got {dt_start_t.numel()}" + ) alpha_start_t = torch.as_tensor(alpha_start, device=device, dtype=dtype) if alpha_start_t.ndim == 0: # NOTE: clone needed as this is overwritten/assigned later by a masked_fill - alpha_start_t = alpha_start_t.expand(n_systems).clone() + alpha_start_t = alpha_start_t.expand(n_groups).clone() + elif alpha_start_t.numel() != n_groups: + raise ValueError( + f"alpha_start must have {n_groups} values, got {alpha_start_t.numel()}" + ) # FIRE-specific additional attributes fire_attrs = { @@ -89,7 +104,7 @@ def fire_init( ), "dt": dt_start_t, "alpha": alpha_start_t, - "n_pos": torch.zeros((n_systems,), device=model.device, dtype=torch.int32), + "n_pos": torch.zeros((n_groups,), device=model.device, dtype=torch.int32), } if cell_filter is not None: # Create cell optimization state @@ -189,7 +204,9 @@ def _vv_fire_step[T: "FireState | CellFireState"]( # noqa: PLR0915 eps: float, ) -> T: """Perform one Velocity-Verlet based FIRE optimization step.""" - n_systems, device, dtype = state.n_systems, state.device, state.dtype + n_systems, n_groups = state.n_systems, state.n_groups + device, dtype = state.device, state.dtype + atom_group_idx = state.group_idx[state.system_idx] # Initialize velocities if NaN nan_velocities = state.velocities.isnan().any(dim=1) @@ -200,11 +217,11 @@ def _vv_fire_step[T: "FireState | CellFireState"]( # noqa: PLR0915 state.cell_velocities[nan_cell_vel] = 0 alpha_start_system = torch.full( - (n_systems,), alpha_start.item(), device=device, dtype=dtype + (n_groups,), alpha_start.item(), device=device, dtype=dtype ) # First half of velocity update - atom_wise_dt = 
state.dt[state.system_idx].unsqueeze(-1) + atom_wise_dt = state.dt[atom_group_idx].unsqueeze(-1) state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) # Position update @@ -227,15 +244,19 @@ def _vv_fire_step[T: "FireState | CellFireState"]( # noqa: PLR0915 # Second half of velocity update state.velocities += 0.5 * atom_wise_dt * state.forces / state.masses.unsqueeze(-1) if isinstance(state, CellFireState): - cell_wise_dt = state.dt.view(n_systems, 1, 1) + cell_wise_dt = state.dt[state.group_idx].view(n_systems, 1, 1) state.cell_velocities += ( 0.5 * cell_wise_dt * state.cell_forces / state.cell_masses.unsqueeze(-1) ) # Calculate power - system_power = tsm.batched_vdot(state.forces, state.velocities, state.system_idx) + system_power = tsm.batched_vdot(state.forces, state.velocities, atom_group_idx) if isinstance(state, CellFireState): - system_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) + system_power += _group_sum( + (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)), + state.group_idx, + n_groups, + ) # Update dt, alpha, n_pos pos_mask_system = system_power > 0.0 @@ -252,32 +273,44 @@ def _vv_fire_step[T: "FireState | CellFireState"]( # noqa: PLR0915 # Velocity mixing v_scaling_system = tsm.batched_vdot( - state.velocities, state.velocities, state.system_idx + state.velocities, state.velocities, atom_group_idx ) - f_scaling_system = tsm.batched_vdot(state.forces, state.forces, state.system_idx) + f_scaling_system = tsm.batched_vdot(state.forces, state.forces, atom_group_idx) if isinstance(state, CellFireState): - v_scaling_system += state.cell_velocities.pow(2).sum(dim=(1, 2)) - f_scaling_system += state.cell_forces.pow(2).sum(dim=(1, 2)) + v_scaling_system += _group_sum( + state.cell_velocities.pow(2).sum(dim=(1, 2)), + state.group_idx, + n_groups, + ) + f_scaling_system += _group_sum( + state.cell_forces.pow(2).sum(dim=(1, 2)), + state.group_idx, + n_groups, + ) - v_scaling_cell = 
torch.sqrt(v_scaling_system.view(n_systems, 1, 1)) - f_scaling_cell = torch.sqrt(f_scaling_system.view(n_systems, 1, 1)) + v_scaling_cell = torch.sqrt( + v_scaling_system[state.group_idx].view(n_systems, 1, 1) + ) + f_scaling_cell = torch.sqrt( + f_scaling_system[state.group_idx].view(n_systems, 1, 1) + ) v_mixing_cell = state.cell_forces / (f_scaling_cell + eps) * v_scaling_cell - alpha_cell_bc = state.alpha.view(n_systems, 1, 1) + alpha_cell_bc = state.alpha[state.group_idx].view(n_systems, 1, 1) state.cell_velocities = torch.where( - pos_mask_system.view(n_systems, 1, 1), + pos_mask_system[state.group_idx].view(n_systems, 1, 1), (1.0 - alpha_cell_bc) * state.cell_velocities + alpha_cell_bc * v_mixing_cell, torch.zeros_like(state.cell_velocities), ) - v_scaling_atom = torch.sqrt(v_scaling_system[state.system_idx].unsqueeze(-1)) - f_scaling_atom = torch.sqrt(f_scaling_system[state.system_idx].unsqueeze(-1)) + v_scaling_atom = torch.sqrt(v_scaling_system[atom_group_idx].unsqueeze(-1)) + f_scaling_atom = torch.sqrt(f_scaling_system[atom_group_idx].unsqueeze(-1)) v_mixing_atom = state.forces * (v_scaling_atom / (f_scaling_atom + eps)) - alpha_atom = state.alpha[state.system_idx].unsqueeze(-1) + alpha_atom = state.alpha[atom_group_idx].unsqueeze(-1) state.velocities = torch.where( - pos_mask_system[state.system_idx].unsqueeze(-1), + pos_mask_system[atom_group_idx].unsqueeze(-1), (1.0 - alpha_atom) * state.velocities + alpha_atom * v_mixing_atom, torch.zeros_like(state.velocities), ) @@ -301,7 +334,9 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 """Perform one ASE-style FIRE optimization step.""" from torch_sim.optimizers import CellFireState - n_systems, device, dtype = state.n_systems, state.device, state.dtype + n_systems, n_groups = state.n_systems, state.n_groups + device, dtype = state.device, state.dtype + atom_group_idx = state.group_idx[state.system_idx] # Per-atom NaN detection before zeroing: needed to decide whether to skip # 
FIRE mixing (all NaN = first step) vs run it (partial NaN = autobatcher swap). @@ -315,7 +350,7 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 forces = state.forces else: alpha_start_system = torch.full( - (n_systems,), alpha_start.item(), device=device, dtype=dtype + (n_groups,), alpha_start.item(), device=device, dtype=dtype ) # Transform forces for cell optimization @@ -331,9 +366,13 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 forces = state.forces # Calculate power (newly zeroed systems will have power=0 → neg_mask) - system_power = tsm.batched_vdot(forces, state.velocities, state.system_idx) + system_power = tsm.batched_vdot(forces, state.velocities, atom_group_idx) if isinstance(state, CellFireState): - system_power += (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)) + system_power += _group_sum( + (state.cell_forces * state.cell_velocities).sum(dim=(1, 2)), + state.group_idx, + n_groups, + ) # Update dt, alpha, n_pos pos_mask_system = system_power > 0.0 @@ -350,55 +389,74 @@ def _ase_fire_step[T: "FireState | CellFireState"]( # noqa: C901, PLR0915 # Velocity mixing BEFORE acceleration (ASE ordering) v_scaling_system = tsm.batched_vdot( - state.velocities, state.velocities, state.system_idx + state.velocities, state.velocities, atom_group_idx ) - f_scaling_system = tsm.batched_vdot(forces, forces, state.system_idx) + f_scaling_system = tsm.batched_vdot(forces, forces, atom_group_idx) if isinstance(state, CellFireState): - v_scaling_system += state.cell_velocities.pow(2).sum(dim=(1, 2)) - f_scaling_system += state.cell_forces.pow(2).sum(dim=(1, 2)) + v_scaling_system += _group_sum( + state.cell_velocities.pow(2).sum(dim=(1, 2)), + state.group_idx, + n_groups, + ) + f_scaling_system += _group_sum( + state.cell_forces.pow(2).sum(dim=(1, 2)), + state.group_idx, + n_groups, + ) - v_scaling_cell = torch.sqrt(v_scaling_system.view(n_systems, 1, 1)) - f_scaling_cell = 
torch.sqrt(f_scaling_system.view(n_systems, 1, 1)) + v_scaling_cell = torch.sqrt( + v_scaling_system[state.group_idx].view(n_systems, 1, 1) + ) + f_scaling_cell = torch.sqrt( + f_scaling_system[state.group_idx].view(n_systems, 1, 1) + ) v_mixing_cell = state.cell_forces / (f_scaling_cell + eps) * v_scaling_cell - alpha_cell_bc = state.alpha.view(n_systems, 1, 1) + alpha_cell_bc = state.alpha[state.group_idx].view(n_systems, 1, 1) state.cell_velocities = torch.where( - pos_mask_system.view(n_systems, 1, 1), + pos_mask_system[state.group_idx].view(n_systems, 1, 1), (1.0 - alpha_cell_bc) * state.cell_velocities + alpha_cell_bc * v_mixing_cell, torch.zeros_like(state.cell_velocities), ) - v_scaling_atom = torch.sqrt(v_scaling_system[state.system_idx].unsqueeze(-1)) - f_scaling_atom = torch.sqrt(f_scaling_system[state.system_idx].unsqueeze(-1)) + v_scaling_atom = torch.sqrt(v_scaling_system[atom_group_idx].unsqueeze(-1)) + f_scaling_atom = torch.sqrt(f_scaling_system[atom_group_idx].unsqueeze(-1)) v_mixing_atom = forces * (v_scaling_atom / (f_scaling_atom + eps)) - alpha_atom = state.alpha[state.system_idx].unsqueeze(-1) + alpha_atom = state.alpha[atom_group_idx].unsqueeze(-1) state.velocities = torch.where( - pos_mask_system[state.system_idx].unsqueeze(-1), + pos_mask_system[atom_group_idx].unsqueeze(-1), (1.0 - alpha_atom) * state.velocities + alpha_atom * v_mixing_atom, torch.zeros_like(state.velocities), ) # Acceleration (single forward-Euler, no mass for ASE FIRE) - state.velocities += forces * state.dt[state.system_idx].unsqueeze(-1) - dr_atom = state.velocities * state.dt[state.system_idx].unsqueeze(-1) - dr_scaling_system = tsm.batched_vdot(dr_atom, dr_atom, state.system_idx) + state.velocities += forces * state.dt[atom_group_idx].unsqueeze(-1) + dr_atom = state.velocities * state.dt[atom_group_idx].unsqueeze(-1) + dr_scaling_system = tsm.batched_vdot(dr_atom, dr_atom, atom_group_idx) if isinstance(state, CellFireState): - state.cell_velocities += 
state.cell_forces * state.dt.view(n_systems, 1, 1) - dr_cell = state.cell_velocities * state.dt.view(n_systems, 1, 1) - - dr_scaling_system += dr_cell.pow(2).sum(dim=(1, 2)) - dr_scaling_cell = torch.sqrt(dr_scaling_system).view(n_systems, 1, 1) + cell_wise_dt = state.dt[state.group_idx].view(n_systems, 1, 1) + state.cell_velocities += state.cell_forces * cell_wise_dt + dr_cell = state.cell_velocities * cell_wise_dt + + dr_scaling_system += _group_sum( + dr_cell.pow(2).sum(dim=(1, 2)), + state.group_idx, + n_groups, + ) + dr_scaling_cell = torch.sqrt( + dr_scaling_system[state.group_idx].view(n_systems, 1, 1) + ) dr_cell = torch.where( dr_scaling_cell > max_step, max_step * dr_cell / (dr_scaling_cell + eps), dr_cell, ) - dr_scaling_atom = torch.sqrt(dr_scaling_system)[state.system_idx].unsqueeze(-1) + dr_scaling_atom = torch.sqrt(dr_scaling_system)[atom_group_idx].unsqueeze(-1) dr_atom = torch.where( dr_scaling_atom > max_step, max_step * dr_atom / (dr_scaling_atom + eps), diff --git a/torch_sim/optimizers/state.py b/torch_sim/optimizers/state.py index fb09795d0..2b0bc480b 100644 --- a/torch_sim/optimizers/state.py +++ b/torch_sim/optimizers/state.py @@ -48,7 +48,7 @@ class FireState(OptimState): n_pos: torch.Tensor _atom_attributes = OptimState._atom_attributes | {"velocities"} # noqa: SLF001 - _system_attributes = OptimState._system_attributes | {"dt", "alpha", "n_pos"} # noqa: SLF001 + _group_attributes = OptimState._group_attributes | {"dt", "alpha", "n_pos"} # noqa: SLF001 @dataclass(kw_only=True) diff --git a/torch_sim/state.py b/torch_sim/state.py index 6584480a3..1c3d67998 100644 --- a/torch_sim/state.py +++ b/torch_sim/state.py @@ -146,6 +146,8 @@ class SimState: atomic_numbers (torch.Tensor): Atomic numbers with shape (n_atoms,) system_idx (torch.Tensor): Maps each atom index to its system index. Has shape (n_atoms,), must be unique consecutive integers starting from 0. + group_idx (torch.Tensor): Maps each system index to its optimizer group index. 
+ Has shape (n_systems,). Defaults to one group per system. constraints (list["Constraint"] | None): List of constraints applied to the system. Constraints affect degrees of freedom and modify positions. @@ -180,6 +182,7 @@ class SimState: pbc: torch.Tensor # coerced from bool/list[bool] by __setattr__ atomic_numbers: torch.Tensor system_idx: torch.Tensor = field(default=None) # type: ignore[assignment] # coerced from None by __setattr__ + group_idx: torch.Tensor = field(default=None) # type: ignore[assignment] # coerced from None by __setattr__ _constraints: list["Constraint"] = field(default_factory=list) _system_extras: dict[str, torch.Tensor] = field(default_factory=dict) _atom_extras: dict[str, torch.Tensor] = field(default_factory=dict) @@ -196,6 +199,7 @@ def __init__( # noqa: D107 pbc: torch.Tensor | list[bool] | bool, atomic_numbers: torch.Tensor, system_idx: torch.Tensor | None = None, + group_idx: torch.Tensor | None = None, _constraints: list[Constraint] | None = None, _rng: PRNGLike = None, **kwargs: Any, @@ -207,7 +211,8 @@ def __init__( # noqa: D107 "atomic_numbers", "system_idx", } - _system_attributes: ClassVar[set[str]] = {"cell"} + _system_attributes: ClassVar[set[str]] = {"cell", "group_idx"} + _group_attributes: ClassVar[set[str]] = set() _global_attributes: ClassVar[set[str]] = {"pbc", "_rng"} @property @@ -251,6 +256,13 @@ def __setattr__(self, name: str, value: object) -> None: # noqa: C901 _, counts = torch.unique_consecutive(value, return_counts=True) if not torch.all(counts == torch.bincount(value)): raise ValueError("System indices must be unique consecutive integers") + elif name == "group_idx": + if isinstance(value, torch.Tensor): + value = value.to(dtype=torch.int64) + if value.ndim != 1: + raise ValueError("Group indices must be a 1D tensor") + if value.numel() > 0 and value.min() < 0: + raise ValueError("Group indices must be non-negative") super().__setattr__(name, value) def __post_init__(self) -> None: # noqa: C901 @@ -270,6 
+282,10 @@ def __post_init__(self) -> None: # noqa: C901 # Get n_systems from system_idx (now guaranteed to be non-None) _, counts = torch.unique_consecutive(self.system_idx, return_counts=True) n_systems = len(counts) + if self.group_idx is None: + self.group_idx = torch.arange( + n_systems, device=self.device, dtype=torch.int64 + ) if self.constraints: validate_constraints(self.constraints, state=self) @@ -285,6 +301,17 @@ def __post_init__(self) -> None: # noqa: C901 f"Cell must have shape (n_systems={n_systems}, 3, 3), " f"got {self.cell.shape}" ) + if self.group_idx.shape[0] != n_systems: + raise ValueError( + f"Group indices must have shape (n_systems={n_systems},), " + f"got {self.group_idx.shape}" + ) + unique_groups = torch.unique(self.group_idx) + expected_groups = torch.arange( + unique_groups.numel(), device=self.device, dtype=self.group_idx.dtype + ) + if not torch.equal(unique_groups, expected_groups): + raise ValueError("Group indices must be consecutive integers starting from 0") # if devices aren't all the same, raise an error, in a clean way devices = { @@ -296,6 +323,7 @@ def __post_init__(self) -> None: # noqa: C901 "atomic_numbers", "pbc", "system_idx", + "group_idx", ) } if len(set(devices.values())) > 1: @@ -313,6 +341,14 @@ def __post_init__(self) -> None: # noqa: C901 f"System extra '{key}' leading dim must be " f"n_systems={n_systems}, got {val.shape[0]}" ) + for name, val in get_attrs_for_scope(self, "per-group"): + if not isinstance(val, torch.Tensor): + continue + if val.shape[0] != self.n_groups: + raise ValueError( + f"Group attribute '{name}' leading dim must be " + f"n_groups={self.n_groups}, got {val.shape[0]}" + ) for key, val in self._atom_extras.items(): if key in all_attrs or hasattr(type(self), key): raise ValueError(f"Atom extra '{key}' shadows an existing attribute") @@ -330,6 +366,7 @@ def _get_all_attributes(cls) -> set[str]: return ( cls._atom_attributes | cls._system_attributes + | cls._group_attributes | 
cls._global_attributes | {"_constraints", "_system_extras", "_atom_extras"} ) @@ -439,6 +476,11 @@ def n_systems(self) -> int: """Number of systems in the system.""" return torch.unique(self.system_idx).shape[0] + @property + def n_groups(self) -> int: + """Number of optimizer groups in the state.""" + return int(self.group_idx.max().item()) + 1 if self.group_idx.numel() else 0 + @property def volume(self) -> torch.Tensor: """Volume of the system.""" @@ -636,6 +678,7 @@ def from_state(cls, state: "SimState", **additional_attrs: Any) -> Self: all_known = cls._get_all_attributes() n_atoms = state.n_atoms n_systems = state.n_systems + n_groups = state.n_groups for key, val in additional_attrs.items(): if key in all_known: attrs[key] = val @@ -645,6 +688,8 @@ def from_state(cls, state: "SimState", **additional_attrs: Any) -> Self: attrs.setdefault("_atom_extras", {})[key] = val elif leading == n_systems: attrs.setdefault("_system_extras", {})[key] = val + elif leading == n_groups: + attrs[key] = val else: raise ValueError(f"Attribute '{key}' has invalid leading dimension") else: @@ -778,7 +823,7 @@ def _assert_no_tensor_attributes_can_be_none(cls) -> None: # exceptions exist because the type hint doesn't actually reflect the real type # (since we change their type in the post_init) - exceptions = {"system_idx"} + exceptions = {"system_idx", "group_idx"} type_hints = typing.get_type_hints(cls) for attr_name, attr_type_hint in type_hints.items(): @@ -803,13 +848,19 @@ def _assert_no_tensor_attributes_can_be_none(cls) -> None: @classmethod def _assert_all_attributes_have_defined_scope(cls) -> None: all_defined_attributes = ( - cls._atom_attributes | cls._system_attributes | cls._global_attributes + cls._atom_attributes + | cls._system_attributes + | cls._group_attributes + | cls._global_attributes ) # 1) assert that no attribute is defined twice in all_defined_attributes duplicates = ( (cls._atom_attributes & cls._system_attributes) + | (cls._atom_attributes & 
cls._group_attributes) | (cls._atom_attributes & cls._global_attributes) + | (cls._system_attributes & cls._group_attributes) | (cls._system_attributes & cls._global_attributes) + | (cls._group_attributes & cls._global_attributes) ) if duplicates: raise TypeError( @@ -1004,13 +1055,13 @@ def _state_to_device[T: SimState]( # noqa: C901 def get_attrs_for_scope( - state: SimState, scope: Literal["per-atom", "per-system", "global"] + state: SimState, scope: Literal["per-atom", "per-system", "per-group", "global"] ) -> Generator[tuple[str, Any], None, None]: """Get attributes for a given scope. Args: state (SimState): The state to get attributes for - scope (Literal["per-atom", "per-system", "global"]): The scope to get + scope (Literal["per-atom", "per-system", "per-group", "global"]): Scope to get attributes for Returns: @@ -1021,6 +1072,8 @@ def get_attrs_for_scope( attr_names = state._atom_attributes # noqa: SLF001 case "per-system": attr_names = state._system_attributes # noqa: SLF001 + case "per-group": + attr_names = state._group_attributes # noqa: SLF001 case "global": attr_names = state._global_attributes # noqa: SLF001 case _: @@ -1034,7 +1087,23 @@ def get_attrs_for_scope( yield from state.atom_extras.items() -def _filter_attrs_by_index( +def _remap_indices_by_first_occurrence( + indices: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Remap integer indices to consecutive values in first-occurrence order.""" + unique_indices: list[int] = [] + remapped: list[int] = [] + for idx in indices.tolist(): + if idx not in unique_indices: + unique_indices.append(idx) + remapped.append(unique_indices.index(idx)) + return ( + torch.tensor(remapped, device=indices.device, dtype=torch.int64), + torch.tensor(unique_indices, device=indices.device, dtype=torch.int64), + ) + + +def _filter_attrs_by_index( # noqa: C901 state: SimState, atom_indices: torch.Tensor, system_indices: torch.Tensor, @@ -1060,12 +1129,17 @@ def _filter_attrs_by_index( 
atom_remap[atom_indices] = torch.arange(len(atom_indices), device=state.device) if len(system_indices) == 0: system_remap = torch.empty(0, device=state.device, dtype=torch.long) + old_group_indices = torch.empty(0, device=state.device, dtype=torch.long) + new_group_idx = torch.empty(0, device=state.device, dtype=torch.long) else: max_idx = int(system_indices.max().item()) + 1 system_remap = torch.empty(max_idx, device=state.device, dtype=torch.long) system_remap[system_indices] = torch.arange( len(system_indices), device=state.device ) + new_group_idx, old_group_indices = _remap_indices_by_first_occurrence( + state.group_idx[system_indices] + ) # select_constraint uses boolean masks (which lose ordering), so we must # remap constraint atom_idx / system_idx afterward to match the actual @@ -1097,8 +1171,16 @@ def _filter_attrs_by_index( for name, val in get_attrs_for_scope(state, "per-system"): if name in state.system_extras: continue + if name == "group_idx": + filtered_attrs[name] = new_group_idx + else: + filtered_attrs[name] = ( + val[system_indices] if isinstance(val, torch.Tensor) else val + ) + + for name, val in get_attrs_for_scope(state, "per-group"): filtered_attrs[name] = ( - val[system_indices] if isinstance(val, torch.Tensor) else val + val[old_group_indices] if isinstance(val, torch.Tensor) else val ) filtered_attrs["_system_extras"] = { @@ -1134,13 +1216,28 @@ def _split_state[T: SimState](state: T) -> list[T]: # noqa: C901 split_per_system = {} for attr_name, attr_value in get_attrs_for_scope(state, "per-system"): - if attr_name in state.system_extras: + if attr_name == "group_idx" or attr_name in state.system_extras: continue if isinstance(attr_value, torch.Tensor): split_per_system[attr_name] = torch.split(attr_value, 1, dim=0) else: # Non-tensor attributes are replicated for each split split_per_system[attr_name] = [attr_value] * state.n_systems + split_per_group = {} + for attr_name, attr_value in get_attrs_for_scope(state, "per-group"): + if 
isinstance(attr_value, torch.Tensor): + split_per_group[attr_name] = [ + attr_value[ + int(state.group_idx[sys_idx].item()) : int( + state.group_idx[sys_idx].item() + ) + + 1 + ] + for sys_idx in range(state.n_systems) + ] + else: + split_per_group[attr_name] = [attr_value] * state.n_systems + global_attrs = dict(get_attrs_for_scope(state, "global")) split_system_extras: dict[str, list[torch.Tensor]] = {} @@ -1168,6 +1265,7 @@ def _split_state[T: SimState](state: T) -> list[T]: # noqa: C901 "system_idx": torch.zeros( system_sizes[sys_idx], device=state.device, dtype=torch.int64 ), + "group_idx": torch.zeros(1, device=state.device, dtype=torch.int64), # Add the split per-atom attributes **{ attr_name: split_per_atom[attr_name][sys_idx] @@ -1175,6 +1273,10 @@ def _split_state[T: SimState](state: T) -> list[T]: # noqa: C901 }, # Add the split per-system attributes (with unpadding applied) **per_system_dict, + **{ + attr_name: split_per_group[attr_name][sys_idx] + for attr_name in split_per_group + }, # Add the global attributes **global_attrs, "_system_extras": { @@ -1329,10 +1431,13 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 # Pre-allocate lists for tensors to concatenate per_atom_tensors = defaultdict(list) per_system_tensors = defaultdict(list) + per_group_tensors = defaultdict(list) system_extras_tensors: dict[str, list[torch.Tensor]] = defaultdict(list) atom_extras_tensors: dict[str, list[torch.Tensor]] = defaultdict(list) new_system_indices = [] + new_group_indices = [] system_offset = 0 + group_offset = 0 num_atoms_per_state = [] # Process all states in a single pass @@ -1350,10 +1455,13 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 # Collect per-system properties for prop, val in get_attrs_for_scope(state, "per-system"): - if prop in state.system_extras: + if prop == "group_idx" or prop in state.system_extras: continue per_system_tensors[prop].append(val) + for prop, val in get_attrs_for_scope(state, "per-group"): + 
per_group_tensors[prop].append(val) + # Collect extras for key, val in state.system_extras.items(): system_extras_tensors[key].append(val) @@ -1364,9 +1472,11 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 num_systems = state.n_systems new_indices = state.system_idx + system_offset new_system_indices.append(new_indices) + new_group_indices.append(state.group_idx + group_offset) num_atoms_per_state.append(state.n_atoms) system_offset += num_systems + group_offset += state.n_groups # Concatenate collected tensors for prop, tensors in per_atom_tensors.items(): @@ -1427,8 +1537,16 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915 else: # Non-tensor attributes, take first one (they should all be identical) concatenated[prop] = tensors[0] + for prop, tensors in per_group_tensors.items(): + concatenated[prop] = ( + torch.cat(tensors, dim=0) + if isinstance(tensors[0], torch.Tensor) + else tensors[0] + ) + # Concatenate system indices concatenated["system_idx"] = torch.cat(new_system_indices) + concatenated["group_idx"] = torch.cat(new_group_indices) # Concatenate extras concatenated["_system_extras"] = { diff --git a/torch_sim/workflows/neb.py b/torch_sim/workflows/neb.py index f40f9dc03..b4007a2f8 100644 --- a/torch_sim/workflows/neb.py +++ b/torch_sim/workflows/neb.py @@ -1,89 +1,59 @@ -"""Nudged Elastic Band (NEB) workflow. - -This module implements the Nudged Elastic Band method for finding minimum energy -paths between two given atomic configurations. 
-""" +"""Nudged Elastic Band (NEB) workflow.""" import inspect import logging -import os # Import os for fsync -import pickle # Import pickle from collections.abc import Callable -from contextlib import nullcontext from dataclasses import dataclass, field +from functools import partial from typing import Any, Literal import torch from torch_sim.models.interface import ModelInterface from torch_sim.optimizers import ( - CellFireState, - FireState, OptimState, fire_init, fire_step, gradient_descent_init, gradient_descent_step, ) -from torch_sim.optimizers.cell_filters import CellFilter +from torch_sim.runners import optimize from torch_sim.state import SimState, concatenate_states, initialize_state -from torch_sim.trajectory import TorchSimTrajectory from torch_sim.transforms import minimum_image_displacement from torch_sim.typing import StateLike logger = logging.getLogger(__name__) -# Add epsilon for numerical stability _EPS = torch.finfo(torch.float64).eps +OptimizerType = Literal["fire", "gd", "ase_fire"] + def _extract_kwargs_from_params( params: dict[str, Any], func: Callable[..., Any], exclude: set[str] | None = None ) -> dict[str, Any]: - """Extract kwargs from params dict that match function signature. 
- - Args: - params: Dictionary of parameters to filter - func: Function to extract parameters for - exclude: Set of parameter names to exclude (e.g., 'state', 'model') - - Returns: - Dictionary of parameters that match the function signature - """ + """Return the entries in ``params`` accepted by ``func``.""" exclude = exclude or {"state", "model"} sig = inspect.signature(func) return {k: v for k, v in params.items() if k in sig.parameters and k not in exclude} -@dataclass +@dataclass(frozen=True) class _OptimizerConfig: - """Configuration for an optimizer type.""" + """Functional optimizer pair and argument modifiers.""" - init_fn: Callable[..., Any] - step_fn: Callable[..., Any] - state_type: type + init_fn: Callable[..., OptimState] + step_fn: Callable[..., OptimState] init_kwargs_modifier: Callable[[dict[str, Any]], dict[str, Any]] | None = None step_kwargs_modifier: Callable[[dict[str, Any]], dict[str, Any]] | None = None -# Registry of optimizer configurations -_OPTIMIZER_REGISTRY: dict[str, _OptimizerConfig] = { - "fire": _OptimizerConfig( - init_fn=fire_init, - step_fn=fire_step, - state_type=FireState, - ), - "frechet_cell_fire": _OptimizerConfig( - init_fn=fire_init, - step_fn=fire_step, - state_type=CellFireState, - init_kwargs_modifier=lambda kwargs: {**kwargs, "cell_filter": CellFilter.frechet}, - ), +_OPTIMIZER_REGISTRY: dict[OptimizerType, _OptimizerConfig] = { + "fire": _OptimizerConfig(init_fn=fire_init, step_fn=fire_step), "gd": _OptimizerConfig( init_fn=gradient_descent_init, step_fn=gradient_descent_step, - state_type=OptimState, step_kwargs_modifier=lambda kwargs: ( kwargs if "pos_lr" in kwargs else {**kwargs, "pos_lr": kwargs.get("lr", 0.01)} ), @@ -91,7 +61,6 @@ class _OptimizerConfig: "ase_fire": _OptimizerConfig( init_fn=fire_init, step_fn=fire_step, - state_type=FireState, init_kwargs_modifier=lambda kwargs: ( kwargs if "fire_flavor" in kwargs else {**kwargs, "fire_flavor": "ase_fire"} ), @@ -102,585 +71,336 @@ class _OptimizerConfig: } 
+def validate_endpoints(initial_state: SimState, final_state: SimState) -> None: + """Validate that endpoints define a fixed-cell single-chain NEB path.""" + if initial_state.n_systems != 1 or final_state.n_systems != 1: + raise ValueError("Initial and final states must each contain one system.") + if initial_state.n_atoms != final_state.n_atoms: + raise ValueError( + f"Initial ({initial_state.n_atoms}) and final ({final_state.n_atoms}) " + "states must have the same number of atoms." + ) + if not torch.equal(initial_state.atomic_numbers, final_state.atomic_numbers): + raise ValueError("Initial and final states must have the same atom types.") + if not torch.equal(initial_state.pbc, final_state.pbc): + raise ValueError("Initial and final states must have the same PBC setting.") + if not torch.allclose(initial_state.cell, final_state.cell): + raise ValueError("Fixed-cell NEB requires matching endpoint cells.") + + +def interpolate_path( + initial_state: SimState, final_state: SimState, n_images: int +) -> SimState: + """Linearly interpolate movable NEB images using the minimum image convention.""" + validate_endpoints(initial_state, final_state) + if n_images < 1: + raise ValueError("n_images must be at least 1.") + + n_atoms = initial_state.n_atoms + displacement = minimum_image_displacement( + dr=final_state.positions - initial_state.positions, + cell=initial_state.cell[0], + pbc=initial_state.pbc, + ).reshape(n_atoms, 3) + factors = torch.linspace( + 0.0, + 1.0, + steps=n_images + 2, + device=initial_state.device, + dtype=initial_state.dtype, + )[1:-1] + positions = ( + initial_state.positions.unsqueeze(0) + + factors.view(-1, 1, 1) * displacement.unsqueeze(0) + ).reshape(-1, 3) + system_idx = torch.repeat_interleave( + torch.arange(n_images, device=initial_state.device, dtype=torch.int64), + repeats=n_atoms, + ) + return SimState( + positions=positions, + masses=initial_state.masses.repeat(n_images), + cell=initial_state.cell.repeat(n_images, 1, 1), + 
pbc=initial_state.pbc, + atomic_numbers=initial_state.atomic_numbers.repeat(n_images), + system_idx=system_idx, + group_idx=torch.zeros(n_images, device=initial_state.device, dtype=torch.int64), + ) + + +def as_sim_state(state: SimState) -> SimState: + """Drop optimizer-only fields while preserving the atomistic state.""" + return SimState( + positions=state.positions, + masses=state.masses, + cell=state.cell, + pbc=state.pbc, + atomic_numbers=state.atomic_numbers, + system_idx=state.system_idx, + group_idx=state.group_idx, + _constraints=state.constraints, + ) + + +def assemble_path( + initial_state: SimState, movable_state: SimState, final_state: SimState +) -> SimState: + """Return the full NEB path as endpoints plus movable images.""" + return concatenate_states( + [ + as_sim_state(initial_state), + as_sim_state(movable_state), + as_sim_state(final_state), + ] + ) + + +def compute_tangents( + all_positions: torch.Tensor, + all_energies: torch.Tensor, + cell: torch.Tensor, + *, + pbc: torch.Tensor, +) -> torch.Tensor: + """Compute improved normalized tangents for the intermediate NEB images.""" + n_total_images, n_atoms, _ = all_positions.shape + n_intermediate = n_total_images - 2 + tangents = torch.zeros( + (n_intermediate, n_atoms, 3), + device=all_positions.device, + dtype=all_positions.dtype, + ) + displacements = minimum_image_displacement( + dr=all_positions[1:] - all_positions[:-1], + cell=cell, + pbc=pbc, + ).reshape(n_total_images - 1, n_atoms, 3) + dE_forward = all_energies[1:] - all_energies[:-1] + + for i in range(n_intermediate): + image_idx = i + 1 + dR_plus = displacements[image_idx] + dR_minus = displacements[image_idx - 1] + dE_plus = dE_forward[image_idx] + dE_minus = dE_forward[image_idx - 1] + + if dE_plus > 0 and dE_minus > 0: + tangent = dR_plus + elif dE_plus < 0 and dE_minus < 0: + tangent = dR_minus + else: + abs_dE_plus = torch.abs(dE_plus) + abs_dE_minus = torch.abs(dE_minus) + delta_max = torch.maximum(abs_dE_plus, abs_dE_minus) + 
delta_min = torch.minimum(abs_dE_plus, abs_dE_minus) + if (dE_plus + dE_minus) > 0: + tangent = dR_plus * delta_max + dR_minus * delta_min + else: + tangent = dR_plus * delta_min + dR_minus * delta_max + + norm = torch.linalg.norm(tangent) + if norm > _EPS: + tangents[i] = tangent / norm + + return tangents + + +def calculate_neb_forces( + path_state: SimState, + true_forces: torch.Tensor, + true_energies: torch.Tensor, + initial_energy: torch.Tensor, + final_energy: torch.Tensor, + *, + spring_constant: float, + use_climbing_image: bool, +) -> torch.Tensor: + """Calculate NEB forces for the movable images in a single path.""" + n_total_images = path_state.n_systems + n_intermediate = n_total_images - 2 + if n_intermediate <= 0: + raise ValueError("A NEB path must include at least one movable image.") + if true_energies.shape[0] != n_intermediate: + raise ValueError(f"{true_energies.shape[0]=} does not match {n_intermediate=}.") + + n_atoms = path_state.n_atoms // n_total_images + all_positions = path_state.positions.reshape(n_total_images, n_atoms, 3) + all_energies = torch.cat( + [initial_energy.reshape(1), true_energies, final_energy.reshape(1)] + ) + true_forces_by_image = true_forces.reshape(n_intermediate, n_atoms, 3) + cell = path_state.cell[0] + + tangents = compute_tangents( + all_positions, + all_energies, + cell, + pbc=path_state.pbc, + ) + displacements = minimum_image_displacement( + dr=all_positions[1:] - all_positions[:-1], + cell=cell, + pbc=path_state.pbc, + ).reshape(n_total_images - 1, n_atoms, 3) + segment_lengths = torch.linalg.norm(displacements, dim=(-1, -2)) + + true_dot_tangent = (true_forces_by_image * tangents).sum(dim=(-1, -2), keepdim=True) + perpendicular_forces = true_forces_by_image - true_dot_tangent * tangents + spring_magnitude = spring_constant * (segment_lengths[1:] - segment_lengths[:-1]) + spring_forces = spring_magnitude.view(-1, 1, 1) * tangents + neb_forces = perpendicular_forces + spring_forces + + if use_climbing_image: + 
climbing_idx = int(torch.argmax(true_energies).item()) + neb_forces[climbing_idx] = true_forces_by_image[climbing_idx] - ( + 2 * true_dot_tangent[climbing_idx] * tangents[climbing_idx] + ) + + return neb_forces.reshape(-1, 3) + + +def _endpoint_energies( + initial_state: SimState, final_state: SimState, model: ModelInterface +) -> tuple[torch.Tensor, torch.Tensor]: + output = model( + concatenate_states([as_sim_state(initial_state), as_sim_state(final_state)]) + ) + return output["energy"][0], output["energy"][1] + + +def _store_neb_force_metadata(state: OptimState, neb_forces: torch.Tensor) -> None: + state.forces = neb_forces + state.neb_forces = neb_forces + state.neb_max_force = torch.linalg.norm(neb_forces, dim=-1).max() + + +def neb_init( + state: SimState, + model: ModelInterface, + *, + initial_state: SimState, + final_state: SimState, + initial_energy: torch.Tensor, + final_energy: torch.Tensor, + base_init_fn: Callable[..., OptimState], + base_init_kwargs: dict[str, Any] | None = None, + spring_constant: float = 0.1, + use_climbing_image: bool = False, +) -> OptimState: + """Initialize the base optimizer state and replace true forces with NEB forces.""" + opt_state = base_init_fn(state, model, **(base_init_kwargs or {})) + full_path = assemble_path(initial_state, opt_state, final_state) + neb_forces = calculate_neb_forces( + full_path, + opt_state.forces, + opt_state.energy, + initial_energy, + final_energy, + spring_constant=spring_constant, + use_climbing_image=use_climbing_image, + ) + _store_neb_force_metadata(opt_state, neb_forces) + return opt_state + + +def neb_step( + state: OptimState, + model: ModelInterface, + *, + initial_state: SimState, + final_state: SimState, + initial_energy: torch.Tensor, + final_energy: torch.Tensor, + base_step_fn: Callable[..., OptimState], + base_step_kwargs: dict[str, Any] | None = None, + spring_constant: float = 0.1, + use_climbing_image: bool = False, +) -> OptimState: + """Advance one NEB step by delegating 
position updates to a base optimizer.""" + state = base_step_fn(state, model, **(base_step_kwargs or {})) + true_forces = state.forces.clone() + full_path = assemble_path(initial_state, state, final_state) + neb_forces = calculate_neb_forces( + full_path, + true_forces, + state.energy, + initial_energy, + final_energy, + spring_constant=spring_constant, + use_climbing_image=use_climbing_image, + ) + state.true_forces = true_forces + _store_neb_force_metadata(state, neb_forces) + return state + + +def neb_convergence_fn( + state: OptimState, last_energy: torch.Tensor, *, fmax: float +) -> torch.Tensor: + """Return all-or-nothing NEB convergence for the movable images.""" + del last_energy + neb_max_force = getattr( + state, + "neb_max_force", + torch.linalg.norm(state.forces, dim=-1).max(), + ) + converged = neb_max_force < fmax + return torch.full( + (state.n_systems,), + bool(converged.item()), + device=state.device, + dtype=torch.bool, + ) + + @dataclass class NEB: - """Nudged Elastic Band (NEB) optimizer. - - Finds the minimum energy path (MEP) between an initial and final state using - the NEB algorithm. - - Attributes: - model: The energy/force model (e.g., MACE) wrapped in a ModelInterface. - n_images: Number of intermediate images between initial and final states. - spring_constant: Spring constant connecting adjacent images (eV/Ang^2). - use_climbing_image: Whether to use a climbing image. - optimizer_type: Type of optimizer to use. - optimizer_params: Parameters for the chosen optimizer. - trajectory_filename: Optional filename for saving the NEB trajectory. - device: Computation device (e.g., 'cpu', 'cuda'). If None, uses model device. - dtype: Computation data type (e.g., torch.float32). If None, uses model dtype. 
- """ + """Single-chain Nudged Elastic Band workflow.""" model: ModelInterface n_images: int - spring_constant: float = 0.1 # eV/Ang^2, typical ASE default + spring_constant: float = 0.1 use_climbing_image: bool = False - optimizer_type: Literal["fire", "gd", "frechet_cell_fire", "ase_fire"] = "fire" + optimizer_type: OptimizerType = "ase_fire" optimizer_params: dict[str, Any] = field(default_factory=dict) trajectory_filename: str | None = None device: torch.device | None = None dtype: torch.dtype | None = None def __post_init__(self) -> None: - """Initializes device, dtype, and optimizer functions after dataclass creation.""" + """Initialize device, dtype, and optimizer configuration.""" if self.device is None: self.device = self.model.device if self.dtype is None: self.dtype = self.model.dtype - - # Initialize variable to store step 0 debug output - self._step0_debug_output = None - - # Get optimizer configuration from registry if self.optimizer_type not in _OPTIMIZER_REGISTRY: raise ValueError( - f"Unsupported optimizer_type: {self.optimizer_type}. " - f"Supported types: {list(_OPTIMIZER_REGISTRY.keys())}" + f"Unsupported optimizer_type={self.optimizer_type!r}; expected one of " + f"{list(_OPTIMIZER_REGISTRY)}." ) config = _OPTIMIZER_REGISTRY[self.optimizer_type] - self._init_fn = config.init_fn - self._step_fn = config.step_fn - self._OptimizerStateType = config.state_type - - # Automatically extract kwargs from optimizer_params based on function signatures - # For init: exclude 'state' and 'model' (positional args) - # For step: exclude 'state' and 'model' (positional args) init_kwargs = _extract_kwargs_from_params( self.optimizer_params, config.init_fn, exclude={"state", "model"} ) step_kwargs = _extract_kwargs_from_params( self.optimizer_params, config.step_fn, exclude={"state", "model"} ) - - # Apply modifiers if provided (for special cases like cell_filter, defaults, etc.) 
- if config.init_kwargs_modifier: + if config.init_kwargs_modifier is not None: init_kwargs = config.init_kwargs_modifier(init_kwargs) - if config.step_kwargs_modifier: + if config.step_kwargs_modifier is not None: step_kwargs = config.step_kwargs_modifier(step_kwargs) + self._init_fn = config.init_fn + self._step_fn = config.step_fn self._init_kwargs = init_kwargs self._step_kwargs = step_kwargs def _interpolate_path( self, initial_state: SimState, final_state: SimState ) -> SimState: - """Linearly interpolate the initial path between states using MIC. - - Generates `n_images` intermediate states between the initial and final states - by linear interpolation of atomic positions, respecting periodic boundary - conditions via the Minimum Image Convention (MIC). - - Args: - initial_state (SimState): The starting SimState (must be single-batch). - final_state (SimState): The ending SimState (must be single-batch). - - Returns: - SimState: A single SimState containing all interpolated intermediate - images, batched together. The batch index corresponds to the image - index (0 to n_images-1). - - Raises: - ValueError: If initial and final states are incompatible (e.g., different - number of atoms, atom types, PBC settings, or if they are not - single-batch states). - """ - # --- Input Validation --- - if initial_state.n_systems != 1 or final_state.n_systems != 1: - raise ValueError("Initial and final states must be single-system SimStates.") - if initial_state.n_atoms != final_state.n_atoms: - raise ValueError( - f"Initial ({initial_state.n_atoms}) and final ({final_state.n_atoms}) " - "states must have the same number of atoms." 
- ) - if not torch.equal(initial_state.atomic_numbers, final_state.atomic_numbers): - # Comparing floats might be tricky, but atomic numbers should be exact - raise ValueError("Initial and final states must have the same atom types.") - # Compare PBC values properly (can be bool, list, or tensor) - pbc_match = False - if isinstance(initial_state.pbc, torch.Tensor) and isinstance( - final_state.pbc, torch.Tensor - ): - pbc_match = torch.equal(initial_state.pbc, final_state.pbc) - elif isinstance(initial_state.pbc, torch.Tensor) or isinstance( - final_state.pbc, torch.Tensor - ): - # One is tensor, one is not - convert both to tensors for comparison - initial_pbc_tensor = ( - initial_state.pbc - if isinstance(initial_state.pbc, torch.Tensor) - else torch.tensor(initial_state.pbc, device=initial_state.device) - ) - final_pbc_tensor = ( - final_state.pbc - if isinstance(final_state.pbc, torch.Tensor) - else torch.tensor(final_state.pbc, device=final_state.device) - ) - pbc_match = torch.equal(initial_pbc_tensor, final_pbc_tensor) - else: - # Both are bools or lists - pbc_match = initial_state.pbc == final_state.pbc - if not pbc_match: - # TODO: Could potentially support different PBCs, but complex for NEB. - raise ValueError("Initial and final states must have the same PBC setting.") - # For fixed-cell NEB, cells should ideally be identical. Warn if not? 
- # if not torch.allclose(initial_state.cell, final_state.cell): - - n_atoms_per_image = initial_state.n_atoms - - # --- Interpolation --- - initial_pos = initial_state.positions - final_pos = final_state.positions - - # Calculate displacement using Minimum Image Convention - displacement = minimum_image_displacement( - dr=final_pos - initial_pos, - cell=initial_state.cell[0], # Use cell from initial state - pbc=initial_state.pbc, - ) - # Ensure shape is correct [n_atoms, 3] - displacement = displacement.reshape(n_atoms_per_image, 3) - - # Generate interpolation factors (e.g., for n_images=3: 0.25, 0.5, 0.75) - factors = torch.linspace( - 0.0, 1.0, steps=self.n_images + 2, device=self.device, dtype=self.dtype - )[1:-1] # Exclude 0.0 and 1.0 # Ensure dtype - factors = factors.view(-1, 1, 1) # Shape: [n_images, 1, 1] - - # Calculate interpolated positions: initial + factor * displacement - # Broadcasting: [N_atoms, 3] + [N_images, 1, 1] * [N_atoms, 3] -> [N_images, N_atoms, 3] - interpolated_pos = initial_pos.unsqueeze(0) + factors * displacement.unsqueeze(0) - - # Reshape to [n_images * n_atoms_per_image, 3] - all_positions = interpolated_pos.reshape(-1, 3) - - # --- Create Batched State --- - # Repeat other attributes for each image - all_atomic_numbers = initial_state.atomic_numbers.repeat(self.n_images) - all_masses = initial_state.masses.repeat(self.n_images) - # Use initial state's cell, repeated for each image - all_cells = initial_state.cell.repeat( - self.n_images, 1, 1 - ) # Shape: [n_images, 3, 3] - - # Create system_idx tensor: [0, 0, ..., 1, 1, ..., n_images-1, ...] 
- system_indices = torch.arange( - self.n_images, device=self.device, dtype=torch.int64 - ) - all_system_idx = torch.repeat_interleave( - system_indices, repeats=n_atoms_per_image - ) - - return SimState( - positions=all_positions, - atomic_numbers=all_atomic_numbers, - masses=all_masses, - cell=all_cells, - pbc=initial_state.pbc, - system_idx=all_system_idx, - ) - - def _compute_tangents( - self, - all_pos: torch.Tensor, # Shape: [n_total_images, n_atoms, 3] - all_energies: torch.Tensor, # Shape: [n_total_images] - cell: torch.Tensor, # Shape: [3, 3] - *, # Make pbc keyword-only - pbc: bool, - ) -> torch.Tensor: - """Compute normalized tangent vectors for intermediate NEB images. - - Implements the improved tangent estimate of Henkelman and Jónsson (2000) - to determine the local tangent direction at each intermediate image based - on the positions and energies of its neighbors. - - Args: - all_pos (torch.Tensor): Atomic configurations for all images in the path - (initial + intermediate + final), shape [n_total_images, n_atoms, 3]. - all_energies (torch.Tensor): Potential energy of each image, shape - [n_total_images]. - cell (torch.Tensor): Unit cell vectors (shape [3, 3]), assumed constant - for the path. - pbc (bool): Flag indicating if periodic boundary conditions are active. - - Returns: - torch.Tensor: Normalized local tangent vectors for the intermediate - images only, shape [n_images, n_atoms, 3]. Tangents are zero for - numerically identical adjacent images. 
- """ - n_total_images, n_atoms_per_image, _ = all_pos.shape - n_intermediate_images = n_total_images - 2 - device = all_pos.device - dtype = all_pos.dtype - - # Initialize tangents for intermediate images only - tangents = torch.zeros( - (n_intermediate_images, n_atoms_per_image, 3), - device=device, - dtype=self.dtype, # Use self.dtype - ) - - # Calculate displacements between adjacent images using MIC - # dR_forward[i] = R_{i+1} - R_i - displacements = minimum_image_displacement( - dr=all_pos[1:] - all_pos[:-1], cell=cell, pbc=pbc - ) - # Ensure shape is correct after MIC if needed - displacements = displacements.reshape(n_total_images - 1, n_atoms_per_image, 3) - - # Energy differences V_{i+1} - V_i - dE_forward = all_energies[1:] - all_energies[:-1] # Shape: [n_total_images - 1] - - # Compute tangents for intermediate images (indices 1 to N in all_pos) - for i in range(n_intermediate_images): - img_idx = i + 1 # Index in all_pos, all_energies - - # Displacements adjacent to image `img_idx` - # Note: displacements[k] is R_{k+1} - R_k - dR_plus = displacements[img_idx] # R_{i+1} - R_i (where i = img_idx) - dR_minus = displacements[img_idx - 1] # R_i - R_{i-1} (where i = img_idx) - - # Energy differences adjacent to image `img_idx` - dE_plus = dE_forward[img_idx] # V_{i+1} - V_i - dE_minus = dE_forward[img_idx - 1] # V_i - V_{i-1} - - # Select tangent based on energy profile (Henkelman & Jónsson criteria) - tangent_i = torch.zeros_like(dR_plus) - - # Condition 1: Ascending segment (minimum) V_{i+1}>V_i and V_i>V_{i-1} => dE_plus>0 and dE_minus>0 - if dE_plus > 0 and dE_minus > 0: - tangent_i = ( - dR_plus # ASE uses forward difference (dR_plus = R[i+1] - R[i]) - ) - - # Condition 2: Descending segment (maximum) V_{i+1} dE_plus<0 and dE_minus<0 - elif ( - dE_plus < 0 and dE_minus < 0 - ): # Check if dE_minus comparison is correct (<0 vs >0) - # tangent_i = dR_plus if abs(dE_plus) < abs(dE_minus) else dR_minus # Old complex version - # ASE logic: if E[i+1] < E[i] < 
E[i-1], tangent = dR_minus (spring1.t) -> Mismatch? - # Let's assume torch-sim should match ASE exactly: - tangent_i = ( - dR_minus # ASE uses backward difference (dR_minus = R[i] - R[i-1]) - ) - - # Condition 3: Other cases (weighted average in ASE) - else: - # Implement ASE's weighting logic precisely - # Note: ASE uses absolute values for deltavmax/min calculation - abs_dE_plus = torch.abs(dE_plus) - abs_dE_minus = torch.abs(dE_minus) - - deltavmax = torch.maximum(abs_dE_plus, abs_dE_minus) - deltavmin = torch.minimum(abs_dE_plus, abs_dE_minus) - - # Check E[i+1] vs E[i-1] - # E[i+1] - E[i-1] = dE_plus + dE_minus - if (dE_plus + dE_minus) > 0: # E[i+1] > E[i-1] - tangent_i = dR_plus * deltavmax + dR_minus * deltavmin - else: # E[i+1] <= E[i-1] - tangent_i = dR_plus * deltavmin + dR_minus * deltavmax - - # Normalize the tangent vector *within* the loop - norm_i = torch.linalg.norm(tangent_i) - if norm_i > _EPS: - tangents[i] = tangent_i / norm_i - # else: tangent remains zero if norm is too small - - return tangents - - def _calculate_neb_forces( - self, - path_state: SimState, - true_forces: torch.Tensor, - true_energies: torch.Tensor, - initial_energy: torch.Tensor, - final_energy: torch.Tensor, - step: int, - ) -> tuple[torch.Tensor, dict | None]: # Return forces and optional debug data - """Calculate the NEB forces for intermediate images. - - The NEB force is composed of the true force perpendicular to the path tangent - and the spring force parallel to the path tangent. Handles climbing image - force modification if enabled. - - Args: - path_state (SimState): SimState containing the full path (initial + - intermediate + final images). Batches are assumed to be ordered. - true_forces (torch.Tensor): Forces from the potential energy model for - the *intermediate* images only, shape [n_movable_atoms, 3]. - true_energies (torch.Tensor): Potential energies for the *intermediate* - images only, shape [n_images]. 
- initial_energy (torch.Tensor): Potential energy of the initial state - (scalar tensor). - final_energy (torch.Tensor): Potential energy of the final state - (scalar tensor). - step (int): Current optimization step number (used for climbing image delay). - - Returns: - torch.Tensor: Calculated NEB forces for the intermediate images, ready to - be passed to the optimizer, shape [n_movable_atoms, 3]. - """ - n_total_images = path_state.n_systems - n_intermediate_images = n_total_images - 2 - assert n_intermediate_images == self.n_images - n_atoms_per_image = path_state.n_atoms // n_total_images - - # --- Reshape inputs --- - # Positions for all images: [n_total_images, n_atoms, 3] - all_pos = path_state.positions.reshape(n_total_images, n_atoms_per_image, 3) - # True forces for intermediate images: [n_images, n_atoms, 3] - true_forces_reshaped = true_forces.reshape( - n_intermediate_images, n_atoms_per_image, 3 - ) - # Cell vectors (assuming fixed cell for now, take from first batch) - cell = path_state.cell[0] # Shape [3, 3] - # Convert pbc to bool if it's a tensor (for _compute_tangents) - if isinstance(path_state.pbc, torch.Tensor): - pbc_bool: bool = bool( - path_state.pbc.any().item() - ) # True if any dimension has PBC - elif isinstance(path_state.pbc, bool): - pbc_bool = path_state.pbc - elif isinstance(path_state.pbc, list): - pbc_bool = bool(any(path_state.pbc)) - else: - pbc_bool = True - pbc = path_state.pbc # Keep original for minimum_image_displacement - - # --- Get Energies for Tangent Calculation --- - all_energies = torch.cat( - [ - initial_energy.unsqueeze(0), - true_energies, - final_energy.unsqueeze(0), - ] - ) - - # --- Setup for Debugging Step 0 --- - log_step_0 = step == 0 - debug_img_idx = ( - n_intermediate_images // 2 - ) # Index within intermediates (0 to n_images-1) - debug_img_idx_all = debug_img_idx + 1 # Index within all_pos (0 to n_images+1) - debug_data_ts = {} # Initialize debug dict - - if log_step_0: - debug_data_ts = { - "step": 
0, - "image_index_intermediate": debug_img_idx, - "image_index_absolute": debug_img_idx_all, - "inputs": {}, - "outputs": {}, - "error": None, - } - debug_data_ts["inputs"]["energies_all"] = all_energies # Monty handles tensor - debug_data_ts["inputs"]["cell"] = cell - debug_data_ts["inputs"]["pbc"] = pbc_bool # Store Python bool - debug_data_ts["inputs"]["positions_image_minus_1"] = all_pos[ - debug_img_idx_all - 1 - ] - debug_data_ts["inputs"]["positions_image"] = all_pos[debug_img_idx_all] - debug_data_ts["inputs"]["positions_image_plus_1"] = all_pos[ - debug_img_idx_all + 1 - ] - debug_data_ts["inputs"]["true_forces_image"] = true_forces_reshaped[ - debug_img_idx - ] - - # --- Calculate Tangents (tau) using the improved method --- - # tangents shape: [n_images, n_atoms, 3] - tangents = self._compute_tangents(all_pos, all_energies, cell, pbc=pbc_bool) - logger.debug( - f" Step {step}: Tangent norms per image: {torch.linalg.norm(tangents, dim=(-1, -2))}" - ) - if log_step_0: - # Note: ASE tangent might not be normalized if norm is ~0, TS tangent should be. - tangent_img = tangents[debug_img_idx] - tangent_norm_img = torch.linalg.norm(tangent_img) - debug_data_ts["outputs"]["tangent_vector"] = tangent_img - debug_data_ts["outputs"]["tangent_norm"] = tangent_norm_img - - # --- Calculate Displacements for Spring Force --- - # Recalculate here or reuse from _compute_tangents if efficient - displacements = minimum_image_displacement( - dr=all_pos[1:] - all_pos[:-1], cell=cell, pbc=pbc - ) - displacements = displacements.reshape(n_total_images - 1, n_atoms_per_image, 3) - if log_step_0: - # Save displacements relevant to the middle image's tangent/spring calculation - debug_data_ts["outputs"]["mic_displacement_1"] = displacements[ - debug_img_idx_all - 1 - ] # R(i) - R(i-1) - debug_data_ts["outputs"]["mic_displacement_2"] = displacements[ - debug_img_idx_all - ] # R(i+1) - R(i) - - # --- Calculate NEB Force Components --- - - # 1. 
Perpendicular component of true force - # F_perp = F_true - (F_true . tau) * tau - # Dot product (sum over atoms and dims): [n_images] - F_true_dot_tau = (true_forces_reshaped * tangents).sum(dim=(-1, -2), keepdim=True) - F_perp = true_forces_reshaped - F_true_dot_tau * tangents - logger.debug( - f" Step {step}: F_perp norms per image: {torch.linalg.norm(F_perp, dim=(-1, -2))}" - ) - if log_step_0: - f_perp_img = F_perp[debug_img_idx] - f_perp_norm_img = torch.linalg.norm(f_perp_img) - debug_data_ts["outputs"]["f_true_dot_tau"] = F_true_dot_tau[ - debug_img_idx - ].item() # scalar - debug_data_ts["outputs"]["f_perp_vector"] = f_perp_img - debug_data_ts["outputs"]["f_perp_norm"] = f_perp_norm_img - - # 2. Parallel component of spring force - # F_spring_par = k * (|R_{i+1}-R_i| - |R_i-R_{i-1}|) * tau_i - # Segment lengths (scalar magnitude per segment): [n_images+1] - segment_lengths = torch.linalg.norm( - displacements, dim=(-1, -2) - ) # Cleaner way [n_total_images-1] - # Spring force magnitude (scalar per intermediate image): [n_images] - F_spring_mag = self.spring_constant * (segment_lengths[1:] - segment_lengths[:-1]) - # Project onto tangent: [n_images, 1, 1] -> [n_images, n_atoms, 3] - F_spring_par = F_spring_mag.view(-1, 1, 1) * tangents - logger.debug( - f" Step {step}: F_spring_par norms per image: {torch.linalg.norm(F_spring_par, dim=(-1, -2))}" - ) - if log_step_0: - f_spring_par_img = F_spring_par[debug_img_idx] - f_spring_par_norm_img = torch.linalg.norm(f_spring_par_img) - debug_data_ts["outputs"]["segment_lengths"] = segment_lengths # Full list - debug_data_ts["outputs"]["spring_force_magnitude_term"] = F_spring_mag[ - debug_img_idx - ].item() # scalar - debug_data_ts["outputs"]["f_spring_par_vector"] = f_spring_par_img - debug_data_ts["outputs"]["f_spring_par_norm"] = f_spring_par_norm_img - - # --- Combine Components for NEB Force --- - neb_forces = F_perp + F_spring_par - if log_step_0: - # --- Direct Debug Logs for Step 0 --- - f_perp_img = 
F_perp[debug_img_idx] - f_spring_par_img = F_spring_par[debug_img_idx] - neb_force_img = neb_forces[debug_img_idx] - logger.debug(" --- DIRECT DEBUG LOG (TORCH-SIM STEP 0) ---") - logger.debug(f" f_perp_norm: {torch.linalg.norm(f_perp_img)}") - logger.debug(f" f_perp_vec[0]: {f_perp_img[0]}") - # segment_lengths shape: [n_total_images - 1] - # segment_lengths[debug_img_idx] corresponds to spring2 length - # segment_lengths[debug_img_idx-1] corresponds to spring1 length - len1 = segment_lengths[debug_img_idx - 1] - len2 = segment_lengths[debug_img_idx] - len_diff = len2 - len1 - logger.debug( - f" spring1_length (R[{debug_img_idx_all}]-R[{debug_img_idx_all - 1}]): {len1}" - ) - logger.debug( - f" spring2_length (R[{debug_img_idx_all + 1}]-R[{debug_img_idx_all}]): {len2}" - ) - logger.debug(f" Length Diff (len2 - len1): {len_diff}") - logger.debug(f" f_spring_par_norm: {torch.linalg.norm(f_spring_par_img)}") - logger.debug(f" f_spring_par_vec[0]: {f_spring_par_img[0]}") - logger.debug( - f" neb_force_before_climb_norm: {torch.linalg.norm(neb_force_img)}" - ) - # -------------------------------------- - # Store a *copy* detached from the graph to prevent modification by climbing image logic - debug_data_ts["outputs"]["neb_force_before_climb_vector"] = ( - neb_forces[debug_img_idx].clone().detach() - ) - debug_data_ts["outputs"]["neb_force_before_climb_norm"] = torch.linalg.norm( - neb_forces[debug_img_idx] - ) # Norm calculation is fine - - # --- Log the vector right before it would be saved --- - logger.debug( - f" Value assigned to debug_data[neb_force_before_climb_vector][0]: {neb_forces[debug_img_idx][0]}" - ) - # ----------------------------------------------------- - - # --- Handle Climbing Image --- - climbing_delay_steps = 10 # Example value - if ( - self.use_climbing_image and n_intermediate_images > 0 - ): # and step >= climbing_delay_steps: # Check step number - REMOVED DELAY - # Find index of highest energy image among intermediates - climbing_image_idx = 
torch.argmax( - true_energies - ).item() # Index from 0 to n_images-1 - # Calculate the climbing force: F_climb = F_true - 2 * (F_true . tau) * tau - F_climb = true_forces_reshaped[climbing_image_idx] - ( - 2 * F_true_dot_tau[climbing_image_idx] * tangents[climbing_image_idx] - ) - # Replace the NEB force for the climbing image with F_climb - # This overwrites the spring force component for this image, as required. - neb_forces[climbing_image_idx] = F_climb - logger.debug( - f" Step {step}: Climbing image index: {climbing_image_idx}, " - f"Climbing Force Norm: {torch.linalg.norm(F_climb)}" - ) - if log_step_0: - debug_data_ts["outputs"]["is_climbing_image"] = ( - climbing_image_idx == debug_img_idx - ) - debug_data_ts["outputs"]["imax"] = climbing_image_idx - debug_data_ts["outputs"]["climbing_force_vector"] = neb_forces[ - climbing_image_idx - ] - debug_data_ts["outputs"]["climbing_force_norm"] = torch.linalg.norm( - neb_forces[climbing_image_idx] - ) - - # --- Logging (Optional) --- - # logger.debug( - # " Max True Force Mag: " - # f"{torch.linalg.norm(true_forces_reshaped, dim=(-1,-2)).max().item():.4f}" - # ) - # logger.debug( - # " Max F_perp Mag: " - # f"{torch.linalg.norm(F_perp, dim=(-1,-2)).max().item():.4f}" - # ) - # logger.debug( - # " Max F_spring_par Mag: " - # f"{torch.linalg.norm(F_spring_par, dim=(-1,-2)).max().item():.4f}" - # ) - # logger.debug( - # " Max NEB Force Mag: " - # f"{torch.linalg.norm(neb_forces, dim=(-1,-2)).max().item():.4f}" - # ) - logger.debug( - f" Step {step}: NEB force norms per image: {torch.linalg.norm(neb_forces, dim=(-1, -2))}" - ) - logger.debug(f" Step {step}: Intermediate energies: {true_energies}") - if log_step_0 and not ( - self.use_climbing_image and climbing_image_idx == debug_img_idx - ): # Avoid logging twice if climbing image was logged - # If not the climbing image, the final force is the one before modification - pass # Already stored neb_force_before_climb - - if log_step_0: - 
debug_data_ts["outputs"]["final_neb_force_vector"] = neb_forces[debug_img_idx] - debug_data_ts["outputs"]["final_neb_force_norm"] = torch.linalg.norm( - neb_forces[debug_img_idx] - ) - - # --- Reshape output --- - final_neb_forces = neb_forces.reshape(-1, 3) # [n_movable_atoms, 3] - - # Return forces and the debug dictionary if step 0 - return final_neb_forces, debug_data_ts if log_step_0 else None + """Compatibility wrapper for existing callers/tests.""" + return interpolate_path(initial_state, final_state, self.n_images) def run( self, @@ -688,253 +408,56 @@ def run( final_system: StateLike, max_steps: int = 100, fmax: float = 0.05, - # TODO: add convergence criteria, batching options, output frequency etc. ) -> SimState: - """Run the Nudged Elastic Band optimization. - - Optimizes the path between the initial and final systems to find the - Minimum Energy Path (MEP). - - Args: - initial_system (StateLike): The starting configuration (can be ASE Atoms, - SimState, or other compatible format recognized by initialize_state). - final_system (StateLike): The ending configuration. - max_steps (int): Maximum number of optimization steps allowed. - fmax (float): Convergence criterion based on the maximum NEB force component - acting on any single atom across all intermediate images (in eV/Ang). - - Returns: - SimState: The final optimized NEB path, including the initial, - intermediate, and final images, concatenated into a single SimState. - SimState: The final optimized NEB path, including the initial, - intermediate, and final images, concatenated into a single SimState. - """ + """Run a single-chain NEB optimization through ``ts.optimize``.""" logger.info("Starting NEB optimization") - - # Reset step 0 debug output storage for this run - self._step0_debug_output = None - - # 1. 
Initialize initial and final states initial_state = initialize_state(initial_system, self.device, self.dtype) final_state = initialize_state(final_system, self.device, self.dtype) - # TODO: Add checks (e.g., same number of atoms, atom types) - # Ensure endpoints are single-system SimStates - # (They should already be from initialize_state, but verify) - if initial_state.n_systems != 1: - raise ValueError("Initial state must be a single-system SimState") - if final_state.n_systems != 1: - raise ValueError("Final state must be a single-system SimState") - - # 1b. Calculate endpoint energies/forces (needed for tangent calculation) - # Note: Forces aren't strictly needed here but model usually returns both - logger.info("Calculating endpoint energies...") - # Concatenate expects a list of SimStates (or subclasses) - endpoint_states = concatenate_states([initial_state, final_state]) - endpoint_output = self.model(endpoint_states) - initial_energy = endpoint_output["energy"][0] - final_energy = endpoint_output["energy"][1] - # Distribute model extras (e.g. 
interaction_energy) back onto the - # endpoint states so that subsequent concatenate_states calls with - # opt_state (which carries those extras) produce consistent leading dims - n_init_atoms = initial_state.n_atoms - n_final_atoms = final_state.n_atoms - init_extras: dict[str, torch.Tensor] = {} - final_extras: dict[str, torch.Tensor] = {} - for key, val in endpoint_output.items(): - if key in {"energy", "forces", "stress"} or not isinstance(val, torch.Tensor): - continue - if val.shape[0] == 2: - init_extras[key] = val[:1] - final_extras[key] = val[1:] - elif val.shape[0] == n_init_atoms + n_final_atoms: - init_extras[key] = val[:n_init_atoms] - final_extras[key] = val[n_init_atoms:] - initial_state.store_model_extras(init_extras) - final_state.store_model_extras(final_extras) + validate_endpoints(initial_state, final_state) + initial_energy, final_energy = _endpoint_energies( + initial_state, final_state, self.model + ) + movable_images = interpolate_path(initial_state, final_state, self.n_images) logger.info( - f"Initial Energy: {initial_energy:.4f}, Final Energy: {final_energy:.4f}" + "Running NEB for max %d steps or fmax < %.4f eV/Ang.", + max_steps, + fmax, ) - # 2. Create initial interpolated path (movable images only) - interpolated_images = self._interpolate_path(initial_state, final_state) - - # 3. Initialize optimizer state for the movable images - # Use the generic initializer with model parameter - opt_state = self._init_fn(interpolated_images, self.model, **self._init_kwargs) - - # 4. 
Optimization loop - logger.info(f"Running NEB for max {max_steps} steps or fmax < {fmax} eV/Ang.") - - # Context manager for trajectory writing - traj_context = ( - TorchSimTrajectory(self.trajectory_filename, mode="w") - if self.trajectory_filename - else nullcontext() # Use a dummy context if no filename + endpoint_kwargs: dict[str, Any] = { + "initial_state": as_sim_state(initial_state), + "final_state": as_sim_state(final_state), + "initial_energy": initial_energy, + "final_energy": final_energy, + "spring_constant": self.spring_constant, + "use_climbing_image": self.use_climbing_image, + } + trajectory_reporter = ( + {"filenames": self.trajectory_filename} + if self.trajectory_filename is not None + else None ) - - def _opt_state_as_simstate(state: SimState) -> SimState: - """Project an OptimState/FireState down to a plain SimState. - - Concatenating an OptimState/FireState with plain SimState endpoints - collapses to the first state's class (SimState), causing optimizer- - specific fields like velocities/forces/energy to be misrouted into - extras with mismatched leading dims. We strip those here and - preserve only model-derived extras (interaction_energy, etc.) that - were also populated on the endpoints. - """ - optim_only_atom = {"forces"} - optim_only_system = {"energy", "stress", "dt", "alpha", "n_pos"} - sys_extras = { - k: v for k, v in state.system_extras.items() if k not in optim_only_system - } - atom_extras = { - k: v for k, v in state.atom_extras.items() if k not in optim_only_atom - } - return SimState( - positions=state.positions, - masses=state.masses, - cell=state.cell, - pbc=state.pbc, - atomic_numbers=state.atomic_numbers, - system_idx=state.system_idx, - _system_extras=sys_extras, - _atom_extras=atom_extras, - ) - - with traj_context as traj: - for step in range(max_steps): - # a. Get current true forces and energies - true_forces = opt_state.forces - true_energies = opt_state.energy - - # b. 
Calculate NEB forces - # Concatenate states - ensures consistent group ID (0 for single NEB) - full_path_state_calc = concatenate_states( - [initial_state, _opt_state_as_simstate(opt_state), final_state] - ) - # Store true forces *before* calculating NEB forces - true_forces_for_traj = opt_state.forces.clone() - - # Get forces and potentially the step 0 debug data - neb_forces, step0_debug_data = self._calculate_neb_forces( - full_path_state_calc, - true_forces, # Pass the forces from the start of the step - true_energies, - initial_energy, - final_energy, - step=step, - ) - - # c. Update the forces in the FIRE state object with NEB forces - opt_state.forces = neb_forces - neb_forces_for_traj = neb_forces.clone() - - # d. Perform optimization step - # Use the generic step function with model parameter - opt_state = self._step_fn(opt_state, self.model, **self._step_kwargs) - - # *** Store Step 0 Debug Data AFTER optimizer step *** - if step == 0 and step0_debug_data: - logger.info("Storing Step 0 TorchSim debug data.") - self._step0_debug_output = step0_debug_data - # *************************************************** - - # e. 
Write to trajectory (if enabled) - if self.trajectory_filename is not None: # Use explicit check - # Create the full path state for writing (including endpoints) - current_full_path = concatenate_states( - [ - initial_state, - _opt_state_as_simstate(opt_state), - final_state, - ] - ) - # Write arrays directly using traj.write_arrays - data_to_write = { - "positions": current_full_path.positions, - # Add forces - Need to handle endpoints (no NEB forces) - # Pad NEB forces with zeros for endpoints - "neb_forces": torch.cat( - [ - torch.zeros_like(initial_state.positions), - neb_forces_for_traj, - torch.zeros_like(final_state.positions), - ], - dim=0, - ), - # True forces are only calculated for intermediate images - # Need forces for endpoints from the initial calculation - "true_forces": torch.cat( - [ - endpoint_output["forces"][ - : initial_state.n_atoms - ], # Initial forces - true_forces_for_traj, # Intermediate forces - endpoint_output["forces"][ - initial_state.n_atoms : - ], # Final forces - ], - dim=0, - ), - "energies": torch.cat( - [ - initial_energy.unsqueeze(0), - opt_state.energy, # Energies *after* the step - final_energy.unsqueeze(0), - ], - dim=0, - ), - } - if step == 0: # Write static data only on the first step - # Assuming fixed cell NEB, cell is static - data_to_write["cell"] = current_full_path.cell - # These should also be static for the whole band - data_to_write["atomic_numbers"] = current_full_path.atomic_numbers - data_to_write["masses"] = current_full_path.masses - # Convert bool to tensor for saving - data_to_write["pbc"] = torch.tensor(current_full_path.pbc) - # Save the system_idx tensor to map atoms to images - data_to_write["image_indices"] = current_full_path.system_idx - - traj.write_arrays(data_to_write, steps=step) - - # f. 
Check convergence - max_force_magnitude = torch.sqrt((neb_forces**2).sum(dim=-1)).max() - max_intermediate_energy = opt_state.energy.max() - logger.info( - f"Step {step + 1:4d}: Max Force = {max_force_magnitude:.4f} Max Energy = {max_intermediate_energy:.4f}" - # f"Energy = {fire_state.energy.mean():.4f} eV (mean per image), " # Removed mean energy for brevity - ) - if max_force_magnitude < fmax: - logger.info("NEB optimization converged.") - break - else: # Loop finished without break - logger.warning("NEB optimization did not converge within max_steps.") - - # 5. Return the final path (including endpoints) - # --- Write Step 0 Debug Dictionary AFTER loop finishes --- - if self._step0_debug_output: - output_filename_ts = "torchsim_step0_debug.pkl" # Change extension - logger.info( - f"Attempting to write final Step 0 TorchSim debug data to {output_filename_ts}" - ) - try: - with open(output_filename_ts, "wb") as f: # Use 'wb' for pickle - pickle.dump(self._step0_debug_output, f) - f.flush() - os.fsync(f.fileno()) - logger.info( - f"--- TorchSim NEB Debug Info (Step 0) saved to {output_filename_ts} ---" - ) - except Exception as e: - logger.error( - f"ERROR WRITING FINAL TORCHSIM STEP 0 DEBUG PICKLE: {e}", - exc_info=True, - ) - else: - logger.warning("No Step 0 TorchSim debug data was stored to write.") - # ---------------------------------------------------------- - - return concatenate_states( - [initial_state, _opt_state_as_simstate(opt_state), final_state] + opt_state = optimize( + movable_images, + self.model, + optimizer=(neb_init, neb_step), + convergence_fn=partial(neb_convergence_fn, fmax=fmax), + max_steps=max_steps, + steps_between_swaps=1, + trajectory_reporter=trajectory_reporter, + autobatcher=False, + init_kwargs={ + **endpoint_kwargs, + "base_init_fn": self._init_fn, + "base_init_kwargs": self._init_kwargs, + }, + **endpoint_kwargs, + base_step_fn=self._step_fn, + base_step_kwargs=self._step_kwargs, ) + final_neb_max_force = 
torch.linalg.norm(opt_state.forces, dim=-1).max() + if final_neb_max_force >= fmax: + logger.warning("NEB optimization did not converge within max_steps.") + else: + logger.info("NEB optimization converged.") + return assemble_path(initial_state, opt_state, final_state) From 816837a2333b7d4fdff5f740c10181401cae0df4 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 27 May 2026 09:55:48 -0400 Subject: [PATCH 5/7] add batched neb tutorial --- examples/scripts/10_fire.py | 207 -------- examples/scripts/9_neb.py | 305 ------------ examples/tutorials/nudged_elastic_band.py | 573 ++++++++++++++++++++++ 3 files changed, 573 insertions(+), 512 deletions(-) delete mode 100644 examples/scripts/10_fire.py delete mode 100644 examples/scripts/9_neb.py create mode 100644 examples/tutorials/nudged_elastic_band.py diff --git a/examples/scripts/10_fire.py b/examples/scripts/10_fire.py deleted file mode 100644 index 4c0fa5061..000000000 --- a/examples/scripts/10_fire.py +++ /dev/null @@ -1,207 +0,0 @@ -"""Compare plain ASE FIRE and torch-sim ase_fire on one analytic system.""" -# ruff: noqa: D101, D102, D103, D107 - -# %% -# /// script -# dependencies = [ -# "ase", -# "matplotlib", -# ] -# /// - -from dataclasses import dataclass -from typing import ClassVar - -import matplotlib.pyplot as plt -import numpy as np -import torch -from ase import Atoms -from ase.calculators.calculator import Calculator, all_changes -from ase.optimize import FIRE - -import torch_sim as ts -from torch_sim.models.interface import ModelInterface - - -@dataclass(frozen=True) -class PotentialParams: - valley_scale: float = 5.0 - valley_curve: float = 0.5 - - -def energy_forces( - positions: torch.Tensor, params: PotentialParams -) -> tuple[torch.Tensor, torch.Tensor]: - """Return per-atom energies and forces for a curved double well.""" - x = positions[:, 0] - y = positions[:, 1] - z = positions[:, 2] - u = x**2 - 1.0 - v = y - params.valley_curve * u - energy = u**2 + params.valley_scale * v**2 + z**2 - 
dE_dx = 4.0 * x * u - 4.0 * params.valley_scale * params.valley_curve * x * v - dE_dy = 2.0 * params.valley_scale * v - dE_dz = 2.0 * z - forces = -torch.stack([dE_dx, dE_dy, dE_dz], dim=1) - return energy, forces - - -class TorchModel(ModelInterface): - def __init__(self, params: PotentialParams) -> None: - super().__init__() - self._device = torch.device("cpu") - self._dtype = torch.float64 - self._compute_forces = True - self._compute_stress = True - self.params = params - - def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]: - del kwargs - per_atom_energy, forces = energy_forces(state.positions, self.params) - energy = torch.zeros(state.n_systems, device=state.device, dtype=state.dtype) - energy.scatter_add_(0, state.system_idx, per_atom_energy) - return { - "energy": energy, - "forces": forces, - "stress": torch.zeros( - state.n_systems, 3, 3, device=state.device, dtype=state.dtype - ), - } - - -class ASECalculator(Calculator): - implemented_properties: ClassVar[list[str]] = ["energy", "forces"] - - def __init__(self, params: PotentialParams) -> None: - super().__init__() - self.params = params - - def calculate( - self, - atoms: Atoms | None = None, - properties: list[str] | None = None, - system_changes: list[str] = all_changes, - ) -> None: - super().calculate(atoms, properties, system_changes) - positions = torch.tensor(self.atoms.positions, dtype=torch.float64) - per_atom_energy, forces = energy_forces(positions, self.params) - self.results["energy"] = float(per_atom_energy.sum()) - self.results["forces"] = forces.detach().cpu().numpy() - - -def make_state(position: tuple[float, float, float]) -> ts.SimState: - return ts.SimState( - positions=torch.tensor([position], dtype=torch.float64), - masses=torch.ones(1, dtype=torch.float64), - cell=torch.eye(3, dtype=torch.float64).unsqueeze(0) * 10.0, - pbc=False, - atomic_numbers=torch.tensor([18]), - system_idx=torch.zeros(1, dtype=torch.long), - ) - - -def run_torch_fire( - 
state: ts.SimState, model: ModelInterface, *, steps: int, fmax: float -) -> tuple[ts.SimState, list[float], list[float]]: - energy_history: list[float] = [] - fmax_history: list[float] = [] - - def record(state: ts.OptimState) -> None: - energy_history.append(float(state.energy[0])) - fmax_history.append(float(torch.linalg.norm(state.forces, dim=1).max())) - - initial_opt_state = ts.fire_init(state, model, fire_flavor="ase_fire") - record(initial_opt_state) - - def convergence_fn(state: ts.OptimState, last_energy: torch.Tensor) -> torch.Tensor: - del last_energy - record(state) - return ts.generate_force_convergence_fn(force_tol=fmax)(state, state.energy) - - result = ts.optimize( - state, - model, - optimizer=ts.Optimizer.fire, - convergence_fn=convergence_fn, - max_steps=steps, - steps_between_swaps=1, - autobatcher=False, - fire_flavor="ase_fire", - ) - return result, energy_history, fmax_history - - -def run_ase_fire( - atoms: Atoms, *, params: PotentialParams, steps: int, fmax: float -) -> tuple[Atoms, list[float], list[float]]: - atoms = atoms.copy() - atoms.calc = ASECalculator(params) - optimizer = FIRE(atoms, logfile=None) - energy_history: list[float] = [] - fmax_history: list[float] = [] - - def record() -> None: - energy_history.append(float(atoms.get_potential_energy())) - fmax_history.append(float(np.linalg.norm(atoms.get_forces(), axis=1).max())) - - optimizer.attach(record, interval=1) - optimizer.run(fmax=fmax, steps=steps) - return atoms, energy_history, fmax_history - - -params = PotentialParams() -steps = 80 -fmax = 0.03 -initial_position = (-0.2, 0.9, 0.0) -model = TorchModel(params) -state = make_state(initial_position) -atoms = Atoms("Ar", positions=[initial_position], cell=np.eye(3) * 10.0, pbc=False) - -ts_final, ts_energy, ts_force = run_torch_fire(state, model, steps=steps, fmax=fmax) -ase_final, ase_energy, ase_force = run_ase_fire( - atoms, params=params, steps=steps, fmax=fmax -) - -ts_position = 
ts_final.positions.detach().cpu().numpy()[0] -ase_position = ase_final.positions[0] -print(f"torch-sim steps: {len(ts_force)}") -print(f"ASE steps: {len(ase_force)}") -print(f"final position abs diff: {np.max(np.abs(ts_position - ase_position)):.3e}") -print(f"final energy abs diff: {abs(ts_energy[-1] - ase_energy[-1]):.3e}") -print(f"final fmax ts/ase: {ts_force[-1]:.3e} / {ase_force[-1]:.3e}") - -common_steps = min(len(ts_energy), len(ase_energy)) -step_axis = np.arange(common_steps) -energy_residual = np.array(ts_energy[:common_steps]) - np.array(ase_energy[:common_steps]) -force_residual = np.array(ts_force[:common_steps]) - np.array(ase_force[:common_steps]) - -fig, axes = plt.subplots(2, 2, figsize=(10, 7), sharex="col") -axes[0, 0].plot(ts_energy, label="torch-sim") -axes[0, 0].plot(ase_energy, "--", label="ASE") -axes[0, 0].set_ylabel("Energy") -axes[0, 0].set_title("Plain FIRE energy") -axes[0, 0].legend() - -axes[0, 1].plot(ts_force, label="torch-sim") -axes[0, 1].plot(ase_force, "--", label="ASE") -axes[0, 1].axhline(fmax, color="k", linestyle=":", label="fmax") -axes[0, 1].set_ylabel("Max force") -axes[0, 1].set_yscale("log") -axes[0, 1].set_title("Plain FIRE convergence") -axes[0, 1].legend() - -axes[1, 0].axhline(0.0, color="k", linewidth=0.8) -axes[1, 0].plot(step_axis, energy_residual) -axes[1, 0].set_xlabel("Optimization step") -axes[1, 0].set_ylabel("TS - ASE") -axes[1, 0].set_title("Energy residual") - -axes[1, 1].axhline(0.0, color="k", linewidth=0.8) -axes[1, 1].plot(step_axis, force_residual) -axes[1, 1].set_xlabel("Optimization step") -axes[1, 1].set_ylabel("TS - ASE") -axes[1, 1].set_title("Max-force residual") - -fig.tight_layout() -fig.savefig("fire_ase_torchsim_comparison.png", dpi=200) -print("Saved comparison plot to fire_ase_torchsim_comparison.png") diff --git a/examples/scripts/9_neb.py b/examples/scripts/9_neb.py deleted file mode 100644 index 5462e3acb..000000000 --- a/examples/scripts/9_neb.py +++ /dev/null @@ -1,305 +0,0 @@ 
-"""Compare torch-sim and ASE Nudged Elastic Band trajectories.""" -# ruff: noqa: D101, D102, D103, D107 - -# %% -# /// script -# dependencies = [ -# "ase", -# "matplotlib", -# ] -# /// - -from dataclasses import dataclass -from typing import ClassVar - -import matplotlib.pyplot as plt -import numpy as np -import torch -from ase import Atoms -from ase.calculators.calculator import Calculator, all_changes -from ase.mep import NEB as ASENEB -from ase.optimize import FIRE - -import torch_sim as ts -from torch_sim.models.interface import ModelInterface -from torch_sim.optimizers import fire_init, fire_step -from torch_sim.workflows.neb import ( - as_sim_state, - assemble_path, - interpolate_path, - neb_convergence_fn, - neb_init, - neb_step, -) - - -@dataclass(frozen=True) -class CurvedDoubleWellParams: - valley_scale: float = 5.0 - valley_curve: float = 0.5 - - -def curved_double_well( - positions: torch.Tensor, params: CurvedDoubleWellParams -) -> tuple[torch.Tensor, torch.Tensor]: - """Return per-atom energies and forces for a curved double-well surface.""" - x = positions[:, 0] - y = positions[:, 1] - z = positions[:, 2] - u = x**2 - 1.0 - v = y - params.valley_curve * u - energy = u**2 + params.valley_scale * v**2 + z**2 - dE_dx = 4.0 * x * u - 4.0 * params.valley_scale * params.valley_curve * x * v - dE_dy = 2.0 * params.valley_scale * v - dE_dz = 2.0 * z - forces = -torch.stack([dE_dx, dE_dy, dE_dz], dim=1) - return energy, forces - - -class TorchCurvedDoubleWellModel(ModelInterface): - def __init__( - self, - *, - device: torch.device, - dtype: torch.dtype, - params: CurvedDoubleWellParams, - ) -> None: - super().__init__() - self._device = device - self._dtype = dtype - self._compute_forces = True - self._compute_stress = True - self.params = params - - def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]: - del kwargs - per_atom_energy, forces = curved_double_well(state.positions, self.params) - energy = 
torch.zeros(state.n_systems, device=state.device, dtype=state.dtype) - energy.scatter_add_(0, state.system_idx, per_atom_energy) - return { - "energy": energy, - "forces": forces, - "stress": torch.zeros( - state.n_systems, 3, 3, device=state.device, dtype=state.dtype - ), - } - - -class ASECurvedDoubleWellCalculator(Calculator): - implemented_properties: ClassVar[list[str]] = ["energy", "forces"] - - def __init__(self, params: CurvedDoubleWellParams) -> None: - super().__init__() - self.params = params - - def calculate( - self, - atoms: Atoms | None = None, - properties: list[str] | None = None, - system_changes: list[str] = all_changes, - ) -> None: - super().calculate(atoms, properties, system_changes) - positions = torch.tensor(self.atoms.positions, dtype=torch.float64) - per_atom_energy, forces = curved_double_well(positions, self.params) - self.results["energy"] = float(per_atom_energy.sum().item()) - self.results["forces"] = forces.detach().cpu().numpy() - - -def make_state( - position: tuple[float, float, float], *, device: torch.device -) -> ts.SimState: - return ts.SimState( - positions=torch.tensor([position], device=device, dtype=torch.float64), - masses=torch.ones(1, device=device, dtype=torch.float64), - cell=torch.eye(3, device=device, dtype=torch.float64).unsqueeze(0) * 10.0, - pbc=False, - atomic_numbers=torch.tensor([18], device=device), - system_idx=torch.zeros(1, device=device, dtype=torch.long), - ) - - -def relative_energies_torch(state: ts.SimState, model: ModelInterface) -> np.ndarray: - energies = model(state)["energy"].detach().cpu().numpy() - return energies - energies[0] - - -def run_torch_sim_neb( - initial_state: ts.SimState, - final_state: ts.SimState, - model: ModelInterface, - *, - n_images: int, - spring_constant: float, - max_steps: int, - fmax: float, -) -> tuple[ts.SimState, list[np.ndarray], list[float]]: - movable_images = interpolate_path(initial_state, final_state, n_images) - endpoint_output = model( - 
ts.concatenate_states([as_sim_state(initial_state), as_sim_state(final_state)]) - ) - endpoint_kwargs = { - "initial_state": as_sim_state(initial_state), - "final_state": as_sim_state(final_state), - "initial_energy": endpoint_output["energy"][0], - "final_energy": endpoint_output["energy"][1], - "spring_constant": spring_constant, - "use_climbing_image": True, - } - energy_history: list[np.ndarray] = [] - max_force_history: list[float] = [] - - def record(state: ts.SimState) -> None: - full_path = assemble_path(initial_state, state, final_state) - energy_history.append(relative_energies_torch(full_path, model)) - max_force_history.append(float(torch.linalg.norm(state.forces, dim=1).max())) - - def convergence(state: ts.OptimState, last_energy: torch.Tensor) -> torch.Tensor: - record(state) - return neb_convergence_fn(state, last_energy, fmax=fmax) - - initial_opt_state = neb_init( - movable_images, - model, - **endpoint_kwargs, - base_init_fn=fire_init, - base_init_kwargs={"fire_flavor": "ase_fire"}, - ) - record(initial_opt_state) - - final_movable = ts.optimize( - movable_images, - model, - optimizer=(neb_init, neb_step), - convergence_fn=convergence, - max_steps=max_steps, - steps_between_swaps=1, - autobatcher=False, - init_kwargs={ - **endpoint_kwargs, - "base_init_fn": fire_init, - "base_init_kwargs": {"fire_flavor": "ase_fire"}, - }, - **endpoint_kwargs, - base_step_fn=fire_step, - base_step_kwargs={"fire_flavor": "ase_fire"}, - ) - final_path = assemble_path(initial_state, final_movable, final_state) - return final_path, energy_history, max_force_history - - -def run_ase_neb( - initial_atoms: Atoms, - final_atoms: Atoms, - *, - params: CurvedDoubleWellParams, - n_images: int, - spring_constant: float, - max_steps: int, - fmax: float, -) -> tuple[list[Atoms], list[np.ndarray], list[float]]: - images = [initial_atoms.copy()] - images.extend(initial_atoms.copy() for _ in range(n_images)) - images.append(final_atoms.copy()) - for image in images: - image.calc 
= ASECurvedDoubleWellCalculator(params) - - neb = ASENEB(images, k=spring_constant, climb=True, method="improvedtangent") - neb.interpolate(mic=True) - optimizer = FIRE(neb, logfile=None) - energy_history: list[np.ndarray] = [] - max_force_history: list[float] = [] - - def record() -> None: - energies = np.array([image.get_potential_energy() for image in images]) - energy_history.append(energies - energies[0]) - forces = neb.get_forces().reshape(-1, 3) - max_force_history.append(float(np.linalg.norm(forces, axis=1).max())) - - optimizer.attach(record, interval=1) - optimizer.run(fmax=fmax, steps=max_steps) - return images, energy_history, max_force_history - - -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -params = CurvedDoubleWellParams() -n_images = 7 -spring_constant = 0.1 -max_steps = 200 -fmax = 0.03 - -initial_state = make_state((-1.0, 0.0, 0.0), device=device) -final_state = make_state((1.0, 0.0, 0.0), device=device) -model = TorchCurvedDoubleWellModel(device=device, dtype=torch.float64, params=params) - -initial_atoms = Atoms("Ar", positions=[[-1.0, 0.0, 0.0]], cell=np.eye(3) * 10.0) -final_atoms = Atoms("Ar", positions=[[1.0, 0.0, 0.0]], cell=np.eye(3) * 10.0) - -torch_path, torch_energy_history, torch_fmax = run_torch_sim_neb( - initial_state, - final_state, - model, - n_images=n_images, - spring_constant=spring_constant, - max_steps=max_steps, - fmax=fmax, -) -ase_images, ase_energy_history, ase_fmax = run_ase_neb( - initial_atoms, - final_atoms, - params=params, - n_images=n_images, - spring_constant=spring_constant, - max_steps=max_steps, - fmax=fmax, -) - -torch_final = relative_energies_torch(torch_path, model) -ase_final = ase_energy_history[-1] -reaction_coordinate = np.linspace(0.0, 1.0, n_images + 2) - -print("Final relative energies (eV)") -print("image torch-sim ASE abs diff") -for idx, (torch_energy, ase_energy) in enumerate( - zip(torch_final, ase_final, strict=True) -): - print( - f"{idx:5d} {torch_energy: .8f} 
{ase_energy: .8f} " - f"{abs(torch_energy - ase_energy):.3e}" - ) -print(f"Barrier difference: {abs(torch_final.max() - ase_final.max()):.3e} eV") - -common_steps = min(len(torch_fmax), len(ase_fmax)) -step_axis = np.arange(common_steps) -final_energy_residual = torch_final - ase_final -force_residual = np.array(torch_fmax[:common_steps]) - np.array(ase_fmax[:common_steps]) - -fig, axes = plt.subplots(2, 2, figsize=(10, 7)) -axes[0, 0].plot(reaction_coordinate, torch_final, "o-", label="torch-sim") -axes[0, 0].plot(reaction_coordinate, ase_final, "s--", label="ASE") -axes[0, 0].set_ylabel("Relative energy") -axes[0, 0].set_title("Final NEB profile") -axes[0, 0].legend() - -axes[0, 1].plot(torch_fmax, label="torch-sim") -axes[0, 1].plot(ase_fmax, label="ASE") -axes[0, 1].axhline(fmax, color="k", linestyle=":", label="fmax") -axes[0, 1].set_ylabel("Max NEB force") -axes[0, 1].set_yscale("log") -axes[0, 1].set_title("Convergence") -axes[0, 1].legend() - -axes[1, 0].axhline(0.0, color="k", linewidth=0.8) -axes[1, 0].plot(reaction_coordinate, final_energy_residual, "o-") -axes[1, 0].set_xlabel("Reaction coordinate") -axes[1, 0].set_ylabel("TS - ASE") -axes[1, 0].set_title("Final energy residual") - -axes[1, 1].axhline(0.0, color="k", linewidth=0.8) -axes[1, 1].plot(step_axis, force_residual) -axes[1, 1].set_xlabel("Optimization step") -axes[1, 1].set_ylabel("TS - ASE") -axes[1, 1].set_title("Max-force residual") - -fig.tight_layout() -fig.savefig("neb_ase_torchsim_comparison.png", dpi=200) -print("Saved comparison plot to neb_ase_torchsim_comparison.png") diff --git a/examples/tutorials/nudged_elastic_band.py b/examples/tutorials/nudged_elastic_band.py new file mode 100644 index 000000000..26dcbff5f --- /dev/null +++ b/examples/tutorials/nudged_elastic_band.py @@ -0,0 +1,573 @@ +"""Tutorial: run batched torch-sim Nudged Elastic Band calculations.""" +# ruff: noqa: D101, D102, D103, D107 + +# %% +# /// script +# dependencies = [ +# "ase", +# "matplotlib", +# ] +# /// + 
+# %% +from dataclasses import dataclass +from time import perf_counter +from typing import ClassVar + +import matplotlib.pyplot as plt +import numpy as np +import torch +from ase import Atoms +from ase.calculators.calculator import Calculator, all_changes +from ase.mep import NEB as ASENEB +from ase.optimize import FIRE + +import torch_sim as ts +from torch_sim.models.interface import ModelInterface +from torch_sim.optimizers import fire_init, fire_step +from torch_sim.workflows.neb import ( + as_sim_state, + assemble_path, + calculate_neb_forces, + interpolate_path, +) + +# %% [markdown] +""" +# Nudged Elastic Band + +The nudged elastic band (NEB) method finds a minimum-energy pathway between two +endpoints. The useful TorchSim pattern is not just "run one NEB", but "run many +related NEBs together": each path is an optimizer group inside one batched +`SimState`, so model evaluations and optimizer bookkeeping happen together. + +ASE is kept in this tutorial as a reference implementation. It makes the result +easier to trust, and it gives a clear baseline for reporting the speedup from +batching several independent paths in TorchSim. +""" + + +@dataclass(frozen=True) +class NEBCase: + name: str + valley_scale: float + valley_curve: float + + +def curved_double_well( + positions: torch.Tensor, + valley_scale: torch.Tensor | float, + valley_curve: torch.Tensor | float, +) -> tuple[torch.Tensor, torch.Tensor]: + x = positions[:, 0] + y = positions[:, 1] + z = positions[:, 2] + u = x**2 - 1.0 + v = y - valley_curve * u + energy = u**2 + valley_scale * v**2 + z**2 + dE_dx = 4.0 * x * u - 4.0 * valley_scale * valley_curve * x * v + dE_dy = 2.0 * valley_scale * v + dE_dz = 2.0 * z + forces = -torch.stack([dE_dx, dE_dy, dE_dz], dim=1) + return energy, forces + + +# %% [markdown] +""" +The analytic surface below keeps the tutorial fast and reproducible. Every case +has minima at `x = -1` and `x = 1`, but each one has a different valley shape. 
+That gives us several independent NEB calculations with different barriers and +curvatures. +""" + + +class TorchBatchedDoubleWellModel(ModelInterface): + def __init__( + self, + cases: list[NEBCase], + *, + device: torch.device, + dtype: torch.dtype, + ) -> None: + super().__init__() + self._device = device + self._dtype = dtype + self._compute_forces = True + self._compute_stress = True + self.cases = cases + self.valley_scale = torch.tensor( + [case.valley_scale for case in cases], device=device, dtype=dtype + ) + self.valley_curve = torch.tensor( + [case.valley_curve for case in cases], device=device, dtype=dtype + ) + + def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]: + del kwargs + case_idx = state.group_idx[state.system_idx] + per_atom_energy, forces = curved_double_well( + state.positions, self.valley_scale[case_idx], self.valley_curve[case_idx] + ) + energy = torch.zeros(state.n_systems, device=state.device, dtype=state.dtype) + energy.scatter_add_(0, state.system_idx, per_atom_energy) + return { + "energy": energy, + "forces": forces, + "stress": torch.zeros( + state.n_systems, 3, 3, device=state.device, dtype=state.dtype + ), + } + + +class ASEDoubleWellCalculator(Calculator): + implemented_properties: ClassVar[list[str]] = ["energy", "forces"] + + def __init__(self, case: NEBCase) -> None: + super().__init__() + self.case = case + + def calculate( + self, + atoms: Atoms | None = None, + properties: list[str] | None = None, + system_changes: list[str] = all_changes, + ) -> None: + super().calculate(atoms, properties, system_changes) + positions = torch.tensor(self.atoms.positions, dtype=torch.float64) + per_atom_energy, forces = curved_double_well( + positions, self.case.valley_scale, self.case.valley_curve + ) + self.results["energy"] = float(per_atom_energy.sum().item()) + self.results["forces"] = forces.detach().cpu().numpy() + + +# %% [markdown] +""" +The same potential is exposed through two interfaces. 
TorchSim gets a vectorized +model that chooses the right parameters from `group_idx`; ASE gets one calculator +per case. In a real workflow, the TorchSim model would usually be an ML +interatomic potential and the groups would be different reactions, defects, or +starting guesses. +""" + + +def make_state(position: tuple[float, float, float], device: torch.device) -> ts.SimState: + return ts.SimState( + positions=torch.tensor([position], device=device, dtype=torch.float64), + masses=torch.ones(1, device=device, dtype=torch.float64), + cell=torch.eye(3, device=device, dtype=torch.float64).unsqueeze(0) * 10.0, + pbc=False, + atomic_numbers=torch.tensor([18], device=device), + system_idx=torch.zeros(1, device=device, dtype=torch.long), + ) + + +def make_endpoint_batch( + cases: list[NEBCase], + position: tuple[float, float, float], + device: torch.device, +) -> ts.SimState: + states = [make_state(position, device) for _ in cases] + return ts.concatenate_states(states) + + +def state_for_group(state: ts.SimState, group_idx: int) -> ts.SimState: + system_indices = torch.where(state.group_idx == group_idx)[0] + return state[system_indices] + + +def interpolate_batched_paths( + initial_state: ts.SimState, + final_state: ts.SimState, + n_images: int, +) -> ts.SimState: + paths = [ + interpolate_path( + state_for_group(initial_state, group_idx), + state_for_group(final_state, group_idx), + n_images, + ) + for group_idx in range(initial_state.n_groups) + ] + return ts.concatenate_states(paths) + + +def assemble_batched_paths( + initial_state: ts.SimState, + movable_state: ts.SimState, + final_state: ts.SimState, +) -> ts.SimState: + paths = [] + for group_idx in range(initial_state.n_groups): + path = assemble_path( + state_for_group(initial_state, group_idx), + state_for_group(movable_state, group_idx), + state_for_group(final_state, group_idx), + ) + path.group_idx = torch.zeros(path.n_systems, device=path.device, dtype=torch.long) + paths.append(path) + return 
ts.concatenate_states(paths) + + +def relative_energies_by_group( + state: ts.SimState, + model: ModelInterface, + n_groups: int, +) -> np.ndarray: + energies = model(state)["energy"].detach().cpu().numpy() + profiles = energies.reshape(n_groups, -1) + return profiles - profiles[:, :1] + + +def store_batched_neb_forces( + state: ts.OptimState, + neb_forces: torch.Tensor, +) -> None: + max_force_by_group = torch.zeros( + state.n_groups, device=state.device, dtype=state.dtype + ) + atom_group_idx = state.group_idx[state.system_idx] + for group_idx in range(state.n_groups): + group_mask = atom_group_idx == group_idx + max_force_by_group[group_idx] = torch.linalg.norm( + neb_forces[group_mask], dim=1 + ).max() + state.forces = neb_forces + state.neb_forces = neb_forces + state.neb_max_force_by_group = max_force_by_group + state.neb_max_force = max_force_by_group.max() + + +def calculate_batched_neb_forces( + state: ts.SimState, + true_forces: torch.Tensor, + true_energies: torch.Tensor, + initial_state: ts.SimState, + final_state: ts.SimState, + initial_energies: torch.Tensor, + final_energies: torch.Tensor, + *, + spring_constant: float, + use_climbing_image: bool, +) -> torch.Tensor: + neb_forces = torch.zeros_like(true_forces) + atom_group_idx = state.group_idx[state.system_idx] + for group_idx in range(initial_state.n_groups): + system_indices = torch.where(state.group_idx == group_idx)[0] + atom_mask = atom_group_idx == group_idx + path_state = assemble_path( + state_for_group(initial_state, group_idx), + state[system_indices], + state_for_group(final_state, group_idx), + ) + neb_forces[atom_mask] = calculate_neb_forces( + path_state, + true_forces[atom_mask], + true_energies[system_indices], + initial_energies[group_idx], + final_energies[group_idx], + spring_constant=spring_constant, + use_climbing_image=use_climbing_image, + ) + return neb_forces + + +def batched_neb_init( + state: ts.SimState, + model: ModelInterface, + *, + initial_state: ts.SimState, + 
final_state: ts.SimState, + initial_energies: torch.Tensor, + final_energies: torch.Tensor, + spring_constant: float, + use_climbing_image: bool, +) -> ts.OptimState: + opt_state = fire_init(state, model, fire_flavor="ase_fire") + neb_forces = calculate_batched_neb_forces( + opt_state, + opt_state.forces, + opt_state.energy, + initial_state, + final_state, + initial_energies, + final_energies, + spring_constant=spring_constant, + use_climbing_image=use_climbing_image, + ) + store_batched_neb_forces(opt_state, neb_forces) + return opt_state + + +def batched_neb_step( + state: ts.OptimState, + model: ModelInterface, + *, + initial_state: ts.SimState, + final_state: ts.SimState, + initial_energies: torch.Tensor, + final_energies: torch.Tensor, + spring_constant: float, + use_climbing_image: bool, +) -> ts.OptimState: + state = fire_step(state, model, fire_flavor="ase_fire") + true_forces = state.forces.clone() + neb_forces = calculate_batched_neb_forces( + state, + true_forces, + state.energy, + initial_state, + final_state, + initial_energies, + final_energies, + spring_constant=spring_constant, + use_climbing_image=use_climbing_image, + ) + state.true_forces = true_forces + store_batched_neb_forces(state, neb_forces) + return state + + +def run_torch_sim_batched_neb( + initial_state: ts.SimState, + final_state: ts.SimState, + model: ModelInterface, + *, + n_images: int, + spring_constant: float, + max_steps: int, + fmax: float, +) -> tuple[ts.SimState, list[np.ndarray], list[np.ndarray], float]: + start_time = perf_counter() + movable_images = interpolate_batched_paths(initial_state, final_state, n_images) + initial_energies = model(initial_state)["energy"] + final_energies = model(final_state)["energy"] + energy_history: list[np.ndarray] = [] + max_force_history: list[np.ndarray] = [] + + state = batched_neb_init( + movable_images, + model, + initial_state=as_sim_state(initial_state), + final_state=as_sim_state(final_state), + initial_energies=initial_energies, + 
final_energies=final_energies, + spring_constant=spring_constant, + use_climbing_image=True, + ) + + def record(current_state: ts.OptimState) -> None: + full_path = assemble_batched_paths(initial_state, current_state, final_state) + energy_history.append( + relative_energies_by_group(full_path, model, initial_state.n_groups) + ) + max_force_history.append( + current_state.neb_max_force_by_group.detach().cpu().numpy() + ) + + record(state) + for _ in range(max_steps): + state = batched_neb_step( + state, + model, + initial_state=as_sim_state(initial_state), + final_state=as_sim_state(final_state), + initial_energies=initial_energies, + final_energies=final_energies, + spring_constant=spring_constant, + use_climbing_image=True, + ) + record(state) + if bool((state.neb_max_force_by_group < fmax).all()): + break + + elapsed = perf_counter() - start_time + final_path = assemble_batched_paths(initial_state, state, final_state) + return final_path, energy_history, max_force_history, elapsed + + +def run_ase_neb( + initial_atoms: Atoms, + final_atoms: Atoms, + *, + case: NEBCase, + n_images: int, + spring_constant: float, + max_steps: int, + fmax: float, +) -> tuple[list[Atoms], list[np.ndarray], list[float], float]: + start_time = perf_counter() + images = [initial_atoms.copy()] + images.extend(initial_atoms.copy() for _ in range(n_images)) + images.append(final_atoms.copy()) + for image in images: + image.calc = ASEDoubleWellCalculator(case) + + neb = ASENEB(images, k=spring_constant, climb=True, method="improvedtangent") + neb.interpolate(mic=True) + optimizer = FIRE(neb, logfile=None) + energy_history: list[np.ndarray] = [] + max_force_history: list[float] = [] + + def record() -> None: + energies = np.array([image.get_potential_energy() for image in images]) + energy_history.append(energies - energies[0]) + forces = neb.get_forces().reshape(-1, 3) + max_force_history.append(float(np.linalg.norm(forces, axis=1).max())) + + optimizer.attach(record, interval=1) + 
optimizer.run(fmax=fmax, steps=max_steps) + elapsed = perf_counter() - start_time + return images, energy_history, max_force_history, elapsed + + +# %% [markdown] +""" +Now set up a small batch. Each case is one independent NEB calculation with the +same endpoints and a different curved valley. TorchSim will optimize all of +these paths together; ASE will run the same cases sequentially. +""" + +# %% +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +cases = [ + NEBCase("reference", valley_scale=5.0, valley_curve=0.50), + NEBCase("shallow", valley_scale=3.0, valley_curve=0.25), + NEBCase("steep", valley_scale=8.0, valley_curve=0.35), + NEBCase("left bend", valley_scale=5.5, valley_curve=-0.45), + NEBCase("tight bend", valley_scale=7.0, valley_curve=0.70), + NEBCase("wide bend", valley_scale=4.0, valley_curve=-0.75), + NEBCase("soft", valley_scale=2.5, valley_curve=0.60), + NEBCase("stiff", valley_scale=9.0, valley_curve=-0.25), +] +n_images = 7 +spring_constant = 0.1 +max_steps = 200 +fmax = 0.03 + +initial_state = make_endpoint_batch(cases, (-1.0, 0.0, 0.0), device) +final_state = make_endpoint_batch(cases, (1.0, 0.0, 0.0), device) +model = TorchBatchedDoubleWellModel(cases, device=device, dtype=torch.float64) + +initial_atoms = Atoms("Ar", positions=[[-1.0, 0.0, 0.0]], cell=np.eye(3) * 10.0) +final_atoms = Atoms("Ar", positions=[[1.0, 0.0, 0.0]], cell=np.eye(3) * 10.0) + + +# %% [markdown] +""" +The TorchSim call below is the main workflow. The movable images for every case +are stored in one state, and FIRE uses `group_idx` so each NEB keeps its own +optimizer state while still sharing one batched model call per step. +""" + +# %% +torch_path, torch_energy_history, torch_fmax, torch_elapsed = run_torch_sim_batched_neb( + initial_state, + final_state, + model, + n_images=n_images, + spring_constant=spring_constant, + max_steps=max_steps, + fmax=fmax, +) + + +# %% [markdown] +""" +ASE is the validation baseline. 
It does not batch these independent NEBs here, +so we run the same cases one after another and compare final profiles. +""" + +# %% +ase_energy_history: list[list[np.ndarray]] = [] +ase_fmax: list[list[float]] = [] +ase_elapsed_by_case: list[float] = [] +ase_start = perf_counter() +for case in cases: + _, case_energy_history, case_fmax, case_elapsed = run_ase_neb( + initial_atoms, + final_atoms, + case=case, + n_images=n_images, + spring_constant=spring_constant, + max_steps=max_steps, + fmax=fmax, + ) + ase_energy_history.append(case_energy_history) + ase_fmax.append(case_fmax) + ase_elapsed_by_case.append(case_elapsed) +ase_elapsed = perf_counter() - ase_start + + +# %% [markdown] +""" +Finally, compare accuracy and runtime. The per-case barrier differences should +be tiny; once that is true, the runtime comparison shows why batching many NEBs +together is the more relevant TorchSim workflow. +""" + +# %% +n_cases = len(cases) +reaction_coordinate = np.linspace(0.0, 1.0, n_images + 2) +torch_final = relative_energies_by_group(torch_path, model, n_cases) +ase_final = np.stack([history[-1] for history in ase_energy_history]) +barrier_difference = np.abs(torch_final.max(axis=1) - ase_final.max(axis=1)) +max_barrier_difference = barrier_difference.max() +speedup = ase_elapsed / torch_elapsed if torch_elapsed > 0 else float("inf") + +print("Final barrier comparison") +print("case torch-sim ASE abs diff") +for case, torch_profile, ase_profile, diff in zip( + cases, torch_final, ase_final, barrier_difference, strict=True +): + print( + f"{case.name:10s} {torch_profile.max(): .8f} {ase_profile.max(): .8f} " + f"{diff:.3e}" + ) +print(f"Max barrier difference: {max_barrier_difference:.3e} eV") +print(f"torch-sim batched runtime: {torch_elapsed:.3f} s") +print(f"ASE sequential runtime: {ase_elapsed:.3f} s") +print(f"torch-sim speedup vs sequential ASE: {speedup:.2f}x") + + +# %% +fig, axes = plt.subplots(2, 2, figsize=(11, 8)) +colors = plt.cm.viridis(np.linspace(0.05, 
0.95, n_cases)) + +for idx, (case, color) in enumerate(zip(cases, colors, strict=True)): + axes[0, 0].plot( + reaction_coordinate, + torch_final[idx], + color=color, + linewidth=2, + label=case.name, + ) + axes[0, 0].plot(reaction_coordinate, ase_final[idx], "o", color=color, markersize=3) +axes[0, 0].set_ylabel("Relative energy") +axes[0, 0].set_title("Final profiles: lines are torch-sim, dots are ASE") +axes[0, 0].legend(fontsize=8) + +axes[0, 1].bar([case.name for case in cases], barrier_difference) +axes[0, 1].set_ylabel("Barrier |torch-sim - ASE|") +axes[0, 1].set_yscale("log") +axes[0, 1].tick_params(axis="x", rotation=45) +axes[0, 1].set_title("Validation error by NEB case") + +torch_force_history = np.stack(torch_fmax) +axes[1, 0].plot(torch_force_history.max(axis=1), label="torch-sim batch max") +for case, case_fmax in zip(cases, ase_fmax, strict=True): + axes[1, 0].plot(case_fmax, alpha=0.35, linewidth=1, label=f"ASE {case.name}") +axes[1, 0].axhline(fmax, color="k", linestyle=":", label="fmax") +axes[1, 0].set_xlabel("Optimization step") +axes[1, 0].set_ylabel("Max NEB force") +axes[1, 0].set_yscale("log") +axes[1, 0].set_title("Convergence") +axes[1, 0].legend(fontsize=7) + +axes[1, 1].bar(["torch-sim batched", "ASE sequential"], [torch_elapsed, ase_elapsed]) +axes[1, 1].set_ylabel("Runtime (s)") +axes[1, 1].set_title(f"Batching speedup: {speedup:.2f}x") + +fig.tight_layout() +fig.savefig("neb_ase_torchsim_comparison.png", dpi=200) +print("Saved comparison plot to neb_ase_torchsim_comparison.png") From 6e4617361623c7c1796b32e5457e399967a727d6 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 27 May 2026 15:25:31 -0400 Subject: [PATCH 6/7] clean --- examples/tutorials/nudged_elastic_band.py | 35 ++++++++++++------- tests/test_autobatching.py | 29 ++++++++++++++++ tests/workflows/test_neb.py | 42 ++++++++++++++++++++++- torch_sim/autobatching.py | 3 +- torch_sim/workflows/neb.py | 37 ++++++-------------- 5 files changed, 105 insertions(+), 41 
deletions(-) diff --git a/examples/tutorials/nudged_elastic_band.py b/examples/tutorials/nudged_elastic_band.py index 26dcbff5f..86546e88b 100644 --- a/examples/tutorials/nudged_elastic_band.py +++ b/examples/tutorials/nudged_elastic_band.py @@ -50,12 +50,14 @@ @dataclass(frozen=True) class NEBCase: name: str + barrier_height: float valley_scale: float valley_curve: float def curved_double_well( positions: torch.Tensor, + barrier_height: torch.Tensor | float, valley_scale: torch.Tensor | float, valley_curve: torch.Tensor | float, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -64,8 +66,8 @@ def curved_double_well( z = positions[:, 2] u = x**2 - 1.0 v = y - valley_curve * u - energy = u**2 + valley_scale * v**2 + z**2 - dE_dx = 4.0 * x * u - 4.0 * valley_scale * valley_curve * x * v + energy = barrier_height * u**2 + valley_scale * v**2 + z**2 + dE_dx = 4.0 * barrier_height * x * u - 4.0 * valley_scale * valley_curve * x * v dE_dy = 2.0 * valley_scale * v dE_dz = 2.0 * z forces = -torch.stack([dE_dx, dE_dy, dE_dz], dim=1) @@ -101,12 +103,18 @@ def __init__( self.valley_curve = torch.tensor( [case.valley_curve for case in cases], device=device, dtype=dtype ) + self.barrier_height = torch.tensor( + [case.barrier_height for case in cases], device=device, dtype=dtype + ) def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]: del kwargs case_idx = state.group_idx[state.system_idx] per_atom_energy, forces = curved_double_well( - state.positions, self.valley_scale[case_idx], self.valley_curve[case_idx] + state.positions, + self.barrier_height[case_idx], + self.valley_scale[case_idx], + self.valley_curve[case_idx], ) energy = torch.zeros(state.n_systems, device=state.device, dtype=state.dtype) energy.scatter_add_(0, state.system_idx, per_atom_energy) @@ -135,7 +143,10 @@ def calculate( super().calculate(atoms, properties, system_changes) positions = torch.tensor(self.atoms.positions, dtype=torch.float64) per_atom_energy, forces = curved_double_well( 
- positions, self.case.valley_scale, self.case.valley_curve + positions, + self.case.barrier_height, + self.case.valley_scale, + self.case.valley_curve, ) self.results["energy"] = float(per_atom_energy.sum().item()) self.results["forces"] = forces.detach().cpu().numpy() @@ -431,14 +442,14 @@ def record() -> None: # %% device = torch.device("cuda" if torch.cuda.is_available() else "cpu") cases = [ - NEBCase("reference", valley_scale=5.0, valley_curve=0.50), - NEBCase("shallow", valley_scale=3.0, valley_curve=0.25), - NEBCase("steep", valley_scale=8.0, valley_curve=0.35), - NEBCase("left bend", valley_scale=5.5, valley_curve=-0.45), - NEBCase("tight bend", valley_scale=7.0, valley_curve=0.70), - NEBCase("wide bend", valley_scale=4.0, valley_curve=-0.75), - NEBCase("soft", valley_scale=2.5, valley_curve=0.60), - NEBCase("stiff", valley_scale=9.0, valley_curve=-0.25), + NEBCase("0.8 eV", barrier_height=0.8, valley_scale=5.0, valley_curve=0.50), + NEBCase("0.9 eV", barrier_height=0.9, valley_scale=3.0, valley_curve=0.25), + NEBCase("1.0 eV", barrier_height=1.0, valley_scale=8.0, valley_curve=0.35), + NEBCase("1.1 eV", barrier_height=1.1, valley_scale=5.5, valley_curve=-0.45), + NEBCase("1.2 eV", barrier_height=1.2, valley_scale=7.0, valley_curve=0.70), + NEBCase("1.3 eV", barrier_height=1.3, valley_scale=4.0, valley_curve=-0.75), + NEBCase("1.4 eV", barrier_height=1.4, valley_scale=2.5, valley_curve=0.60), + NEBCase("1.5 eV", barrier_height=1.5, valley_scale=9.0, valley_curve=-0.25), ] n_images = 7 spring_constant = 0.1 diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index a05e06089..4890e4e6a 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -715,6 +715,35 @@ def test_in_flight_max_iterations( assert batcher.iteration_count[idx] == max_iterations +def test_in_flight_max_iterations_completes_whole_group( + si_double_sim_state: ts.SimState, + lj_model: LennardJonesModel, +) -> None: + grouped_state = 
si_double_sim_state.clone() + grouped_state.group_idx = torch.zeros( + grouped_state.n_systems, device=grouped_state.device, dtype=torch.long + ) + batcher = InFlightAutoBatcher( + model=lj_model, + memory_scales_with="n_atoms", + max_memory_scaler=800.0, + max_iterations=1, + ) + batcher.load_states(grouped_state) + + state, [] = batcher.next_batch(None, None) + assert state is not None + assert state.n_systems == grouped_state.n_systems + assert state.n_groups == 1 + + convergence_tensor = torch.zeros(state.n_systems, dtype=torch.bool) + next_state, completed_states = batcher.next_batch(state, convergence_tensor) + + assert next_state is None + assert len(completed_states) == 1 + assert completed_states[0].n_systems == grouped_state.n_systems + + @pytest.mark.parametrize( "num_steps_per_batch", [ diff --git a/tests/workflows/test_neb.py b/tests/workflows/test_neb.py index e7db3e24e..bb3a4277f 100644 --- a/tests/workflows/test_neb.py +++ b/tests/workflows/test_neb.py @@ -7,7 +7,12 @@ import torch_sim as ts from tests.conftest import DEVICE, DTYPE from torch_sim.models.interface import ModelInterface -from torch_sim.workflows.neb import NEB, calculate_neb_forces, interpolate_path +from torch_sim.workflows.neb import ( + NEB, + assemble_path, + calculate_neb_forces, + interpolate_path, +) class HarmonicModel(ModelInterface): @@ -32,6 +37,13 @@ def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tenso } +class GroupIndexedModel(HarmonicModel): + def forward(self, state: ts.SimState, **kwargs: object) -> dict[str, torch.Tensor]: + if state.group_idx.max() >= 1: + raise AssertionError("NEB endpoint/path assembly should preserve one group.") + return super().forward(state, **kwargs) + + def _single_atom_state(position: float) -> ts.SimState: return ts.SimState( positions=torch.tensor([[position, 0.0, 0.0]], device=DEVICE, dtype=DTYPE), @@ -43,6 +55,18 @@ def _single_atom_state(position: float) -> ts.SimState: ) +def 
test_assemble_path_preserves_one_optimizer_group() -> None: + initial = _single_atom_state(0.0) + final = _single_atom_state(1.0) + movable = interpolate_path(initial, final, n_images=3) + + path = assemble_path(initial, movable, final) + + assert path.n_systems == 5 + assert path.n_groups == 1 + assert torch.equal(path.group_idx, torch.zeros(5, device=DEVICE, dtype=torch.long)) + + def test_interpolate_path_uses_movable_images_only() -> None: initial = _single_atom_state(0.0) final = _single_atom_state(1.0) @@ -159,3 +183,19 @@ def test_neb_run_uses_single_chain_optimize_without_moving_endpoints() -> None: result.positions[:, 0], torch.tensor([0.0, 0.5, 1.0], device=DEVICE, dtype=DTYPE), ) + assert result.n_groups == 1 + + +def test_neb_run_does_not_offset_endpoint_groups() -> None: + initial = _single_atom_state(0.0) + final = _single_atom_state(1.0) + neb = NEB( + model=GroupIndexedModel(), + n_images=1, + optimizer_type="gd", + optimizer_params={"pos_lr": 0.1}, + ) + + result = neb.run(initial, final, max_steps=1, fmax=1e-12) + + assert result.n_groups == 1 diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index 121683374..016ef6511 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -1116,8 +1116,7 @@ def next_batch( # noqa: C901 if self.max_iterations is not None and ( self.iteration_count[abs_idx] >= self.max_iterations ): - # Force convergence for states that have reached max attempts - convergence_tensor[cur_idx] = torch.tensor(True) # noqa: FBT003 + convergence_tensor[updated_state.group_idx == cur_idx] = True completed_idx = [] completed_system_indices = [] diff --git a/torch_sim/workflows/neb.py b/torch_sim/workflows/neb.py index b4007a2f8..c74f72aa1 100644 --- a/torch_sim/workflows/neb.py +++ b/torch_sim/workflows/neb.py @@ -130,29 +130,22 @@ def interpolate_path( def as_sim_state(state: SimState) -> SimState: """Drop optimizer-only fields while preserving the atomistic state.""" - return SimState( - 
positions=state.positions, - masses=state.masses, - cell=state.cell, - pbc=state.pbc, - atomic_numbers=state.atomic_numbers, - system_idx=state.system_idx, - group_idx=state.group_idx, - _constraints=state.constraints, - ) + return SimState.from_state(state) def assemble_path( initial_state: SimState, movable_state: SimState, final_state: SimState ) -> SimState: """Return the full NEB path as endpoints plus movable images.""" - return concatenate_states( + path = concatenate_states( [ as_sim_state(initial_state), as_sim_state(movable_state), as_sim_state(final_state), ] ) + path.group_idx = torch.zeros(path.n_systems, device=path.device, dtype=torch.long) + return path def compute_tangents( @@ -220,6 +213,8 @@ def calculate_neb_forces( n_intermediate = n_total_images - 2 if n_intermediate <= 0: raise ValueError("A NEB path must include at least one movable image.") + if path_state.n_atoms % n_total_images != 0: + raise ValueError("NEB path images must contain the same number of atoms.") if true_energies.shape[0] != n_intermediate: raise ValueError(f"{true_energies.shape[0]=} does not match {n_intermediate=}.") @@ -262,10 +257,10 @@ def calculate_neb_forces( def _endpoint_energies( initial_state: SimState, final_state: SimState, model: ModelInterface ) -> tuple[torch.Tensor, torch.Tensor]: - output = model( - concatenate_states([as_sim_state(initial_state), as_sim_state(final_state)]) + return ( + model(as_sim_state(initial_state))["energy"][0], + model(as_sim_state(final_state))["energy"][0], ) - return output["energy"][0], output["energy"][1] def _store_neb_force_metadata(state: OptimState, neb_forces: torch.Tensor) -> None: @@ -339,18 +334,8 @@ def neb_convergence_fn( ) -> torch.Tensor: """Return all-or-nothing NEB convergence for the movable images.""" del last_energy - neb_max_force = getattr( - state, - "neb_max_force", - torch.linalg.norm(state.forces, dim=-1).max(), - ) - converged = neb_max_force < fmax - return torch.full( - (state.n_systems,), - 
bool(converged.item()), - device=state.device, - dtype=torch.bool, - ) + converged = torch.linalg.norm(state.forces, dim=-1).max() < fmax + return converged.expand(state.n_systems) @dataclass From 3fc3f6a10814222784e80cb8f0747a27f4f3268c Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Wed, 27 May 2026 16:49:06 -0400 Subject: [PATCH 7/7] fix docs --- examples/tutorials/nudged_elastic_band.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/examples/tutorials/nudged_elastic_band.py b/examples/tutorials/nudged_elastic_band.py index 86546e88b..544104d43 100644 --- a/examples/tutorials/nudged_elastic_band.py +++ b/examples/tutorials/nudged_elastic_band.py @@ -47,6 +47,7 @@ """ +# %% @dataclass(frozen=True) class NEBCase: name: str @@ -83,6 +84,7 @@ def curved_double_well( """ +# %% class TorchBatchedDoubleWellModel(ModelInterface): def __init__( self, @@ -162,6 +164,7 @@ def calculate( """ +# %% def make_state(position: tuple[float, float, float], device: torch.device) -> ts.SimState: return ts.SimState( positions=torch.tensor([position], device=device, dtype=torch.float64), @@ -565,7 +568,13 @@ def record() -> None: axes[0, 1].set_title("Validation error by NEB case") torch_force_history = np.stack(torch_fmax) -axes[1, 0].plot(torch_force_history.max(axis=1), label="torch-sim batch max") +axes[1, 0].plot( + torch_force_history.max(axis=1), + color="tab:blue", + linestyle="--", + linewidth=2, + label="torch-sim batch max", +) for case, case_fmax in zip(cases, ase_fmax, strict=True): axes[1, 0].plot(case_fmax, alpha=0.35, linewidth=1, label=f"ASE {case.name}") axes[1, 0].axhline(fmax, color="k", linestyle=":", label="fmax")