From 8062bc5a4ee0fb667454f0886635db44d1df5f87 Mon Sep 17 00:00:00 2001 From: Rhys Goodall Date: Fri, 22 May 2026 09:16:40 -0400 Subject: [PATCH] fea: use float cutoffs --- examples/benchmarking/neighborlists.py | 4 +- examples/scripts/7_others.py | 2 +- tests/models/test_soft_sphere.py | 4 +- tests/test_neighbors.py | 42 ++++++++++------ torch_sim/autobatching.py | 3 +- torch_sim/models/dispersion.py | 16 +++--- torch_sim/models/electrostatics.py | 69 ++++++++++++++------------ torch_sim/models/pair_potential.py | 6 +-- torch_sim/neighbors/__init__.py | 4 +- torch_sim/neighbors/alchemiops.py | 14 +++--- torch_sim/neighbors/torch_nl.py | 16 +++--- torch_sim/neighbors/vesin.py | 29 ++++++++--- torch_sim/transforms.py | 30 +++++++++++ 13 files changed, 149 insertions(+), 90 deletions(-) diff --git a/examples/benchmarking/neighborlists.py b/examples/benchmarking/neighborlists.py index ce709ac18..780863d73 100644 --- a/examples/benchmarking/neighborlists.py +++ b/examples/benchmarking/neighborlists.py @@ -238,7 +238,7 @@ def _benchmark_backend( positions: torch.Tensor, cell: torch.Tensor, pbc: torch.Tensor, - cutoff: torch.Tensor, + cutoff: float, system_idx: torch.Tensor, n_repeats: int, device: str, @@ -291,7 +291,7 @@ def run_benchmark(args: argparse.Namespace) -> dict[str, Any]: positions, cell, pbc, system_idx = _build_tensors( structures, dtype=torch_dtype, device=args.device ) - cutoff = torch.tensor(args.cutoff, dtype=torch_dtype, device=args.device) + cutoff = float(args.cutoff) n_atoms = positions.shape[0] backends = args.nl_backend if isinstance(args.nl_backend, list) else [args.nl_backend] diff --git a/examples/scripts/7_others.py b/examples/scripts/7_others.py index 972a3edf1..18cc629e7 100644 --- a/examples/scripts/7_others.py +++ b/examples/scripts/7_others.py @@ -53,7 +53,7 @@ pos, cell, pbc = state.positions, state.cell, state.pbc system_idx = state.system_idx n_atoms = state.n_atoms -cutoff = torch.tensor(4.0, dtype=pos.dtype) +cutoff = 4.0 self_interaction = False # Ensure pbc has the correct shape [n_systems, 3] diff --git a/tests/models/test_soft_sphere.py b/tests/models/test_soft_sphere.py index daa32102f..55a5766ac 100644 --- a/tests/models/test_soft_sphere.py +++ b/tests/models/test_soft_sphere.py @@ -151,7 +151,7 @@ def test_multispecies_initialization_custom() -> None: assert torch.allclose(model.sigma_matrix, sigma_matrix) assert torch.allclose(model.epsilon_matrix, epsilon_matrix) assert torch.allclose(model.alpha_matrix, alpha_matrix) - assert model.cutoff.item() == 3.0 + assert model.cutoff == 3.0 def test_multispecies_matrix_validation() -> None: @@ -196,7 +196,7 @@ def test_multispecies_cutoff_default() -> None: model = SoftSphereMultiModel( atomic_numbers=torch.tensor([0, 1, 2]), sigma_matrix=sigma_matrix ) - assert model.cutoff.item() == 3.0 + assert model.cutoff == 3.0 def test_multispecies_evaluation() -> None: diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py index ec6a8bb52..5e775c53c 100644 --- a/tests/test_neighbors.py +++ b/tests/test_neighbors.py @@ -257,9 +257,8 @@ def test_neighbor_list_invariant_under_lattice_image_shifts( assert not torch.allclose(pos_shifted, pos_wrapped, rtol=0.0, atol=1e-12), ( "expected non-trivial lattice shifts along periodic axes" ) - c_tensor = torch.tensor(cutoff, dtype=DTYPE, device=DEVICE) map_w, sys_w, sh_w = nl_implementation( - cutoff=c_tensor, + cutoff=cutoff, positions=pos_wrapped, cell=cell_b, pbc=pbc_b, @@ -267,7 +266,7 @@ def test_neighbor_list_invariant_under_lattice_image_shifts( self_interaction=self_interaction, ) map_s, sys_s, sh_s = nl_implementation( - cutoff=c_tensor, + cutoff=cutoff, positions=pos_shifted, cell=cell_b, pbc=pbc_b, @@ -311,7 +310,7 @@ def test_neighbor_list_implementations( atoms_list, device=DEVICE, dtype=DTYPE ) mapping, mapping_system, shifts_idx = nl_implementation( - cutoff=torch.tensor(cutoff, dtype=DTYPE, device=DEVICE), + cutoff=cutoff, positions=pos, cell=row_vector_cell, pbc=pbc, @@ -377,7 +376,7 @@ def test_nl_pbc_edge_cases( """ pos = torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]], device=DEVICE, dtype=DTYPE) cell = torch.eye(3, device=DEVICE, dtype=DTYPE) * 2.0 - cutoff = torch.tensor(1.5, device=DEVICE, dtype=DTYPE) + cutoff = 1.5 pbc = torch.tensor([pbc_val, pbc_val, pbc_val], device=DEVICE) system_idx = torch.zeros(2, dtype=torch.long, device=DEVICE) @@ -400,7 +399,7 @@ def test_vesin_nl_float32() -> None: [[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]], device=DEVICE, dtype=torch.float32 ) cell = torch.eye(3, device=DEVICE, dtype=torch.float32) * 2.0 - cutoff = torch.tensor(1.5, device=DEVICE, dtype=torch.float32) + cutoff = 1.5 pbc = torch.tensor([True, True, True], device=DEVICE) system_idx = torch.zeros(2, dtype=torch.long, device=DEVICE) @@ -412,12 +411,12 @@ def test_vesin_nl_float32() -> None: def _minimal_neighbor_list_inputs( device: torch.device, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, float, torch.Tensor]: """Create minimal valid tensor inputs for neighbor-list API smoke checks.""" positions = torch.zeros((1, 3), dtype=torch.float32, device=device) cell = torch.eye(3, dtype=torch.float32, device=device) pbc = torch.tensor([False, False, False], dtype=torch.bool, device=device) - cutoff = torch.tensor(1.0, dtype=torch.float32, device=device) + cutoff = 1.0 system_idx = torch.zeros(1, dtype=torch.long, device=device) return positions, cell, pbc, cutoff, system_idx @@ -480,7 +479,7 @@ def test_fallback_when_alchemiops_unavailable(monkeypatch: pytest.MonkeyPatch) - ) cell = torch.eye(3, device=device, dtype=dtype) * 3.0 pbc = torch.tensor([False, False, False], device=device) - cutoff = torch.tensor(1.5, device=device, dtype=dtype) + cutoff = 1.5 system_idx = torch.zeros(4, dtype=torch.long, device=device) # Use monkeypatch to temporarily disable alchemiops @@ -520,7 +519,7 @@ def test_torchsim_nl_gpu() -> None: ) cell = torch.eye(3, device=device, dtype=dtype) * 3.0 pbc = torch.tensor([True, True, True], device=device) - cutoff = torch.tensor(1.5, device=device, dtype=dtype) + cutoff = 1.5 system_idx = torch.zeros(2, dtype=torch.long, device=device) # Should work on GPU regardless of implementation availability @@ -557,7 +556,7 @@ def test_torchsim_nl_fallback_when_vesin_unavailable( ) cell = torch.eye(3, device=device, dtype=dtype) * 3.0 pbc = torch.tensor([False, False, False], device=device) - cutoff = torch.tensor(1.5, device=device, dtype=dtype) + cutoff = 1.5 system_idx = torch.zeros(4, dtype=torch.long, device=device) # Monkeypatch both availability flags to False @@ -582,9 +581,9 @@ def test_torchsim_nl_fallback_when_vesin_unavailable( def _no_neighbor_inputs() -> tuple[ - torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor + torch.Tensor, torch.Tensor, torch.Tensor, float, torch.Tensor ]: - """Build a simple no-neighbor system.""" + """Two atoms far apart in a large non-periodic box; cutoff too small for any pair.""" positions = torch.tensor( [[0.0, 0.0, 0.0], [10.0, 10.0, 10.0]], device=DEVICE, @@ -592,8 +591,21 @@ def _no_neighbor_inputs() -> tuple[ ) cell = torch.eye(3, device=DEVICE, dtype=DTYPE) * 20.0 pbc = torch.tensor([False, False, False], device=DEVICE) - cutoff = torch.tensor(1.0, device=DEVICE, dtype=DTYPE) - return positions, cell, pbc, cutoff + cutoff = 1.0 + system_idx = torch.zeros(2, dtype=torch.long, device=DEVICE) + return positions, cell, pbc, cutoff, system_idx + + +@pytest.mark.parametrize("nl_implementation", _all_nl_backends()) +def test_neighbor_list_empty_when_all_pairs_beyond_cutoff( + nl_implementation: Callable[..., tuple[torch.Tensor, torch.Tensor, torch.Tensor]], +) -> None: + """Every backend returns a valid empty neighbor list when no pair is within cutoff.""" + positions, cell, pbc, cutoff, system_idx = _no_neighbor_inputs() + mapping, sys_map, shifts = nl_implementation(positions, cell, pbc, cutoff, system_idx) + assert mapping.shape == (2, 0) + assert sys_map.numel() == 0 + assert shifts.shape == (0, 3) def test_strict_nl_edge_cases() -> None: diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index 9671ed973..848d1f81e 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -338,12 +338,11 @@ def determine_max_batch_size( def _n_edges_scalers(state: SimState, cutoff: float) -> list[float]: """Return per-system edge counts from the neighbor list as memory scalers.""" - cutoff_tensor = torch.tensor(cutoff, dtype=state.dtype, device=state.device) _, system_mapping, _ = torchsim_nl( positions=state.positions, cell=state.cell, pbc=state.pbc, - cutoff=cutoff_tensor, + cutoff=cutoff, system_idx=state.system_idx, ) return system_mapping.bincount(minlength=state.n_systems).float().tolist() diff --git a/torch_sim/models/dispersion.py b/torch_sim/models/dispersion.py index 972305565..a60ede7c6 100644 --- a/torch_sim/models/dispersion.py +++ b/torch_sim/models/dispersion.py @@ -20,6 +20,7 @@ import torch +from torch_sim import transforms from torch_sim._duecredit import dcite from torch_sim.models.interface import ModelInterface from torch_sim.neighbors import torchsim_nl @@ -141,18 +142,17 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] self.cutoff, state.system_idx, ) - n_atoms = state.positions.shape[0] - neighbor_ptr = torch.zeros( - n_atoms + 1, dtype=torch.int32, device=state.positions.device - ) - neighbor_ptr[1:] = ( - torch.bincount(edge_index[0], minlength=n_atoms).cumsum(0).to(torch.int32) + edge_index_int, neighbor_ptr, unit_shifts_int = ( + transforms.build_csr_neighbor_list( + edge_index, + _mapping_system, + unit_shifts, + state.positions.shape[0], + ) ) positions_bohr = state.positions * UnitConversion.Ang_to_Bohr cell_bohr = state.row_vector_cell.contiguous() * UnitConversion.Ang_to_Bohr numbers = state.atomic_numbers.to(torch.int32) - unit_shifts_int = unit_shifts.to(torch.int32) - edge_index_int = edge_index.to(torch.int32) d3_out = nvalchemiops_dftd3( positions=positions_bohr, numbers=numbers, diff --git a/torch_sim/models/electrostatics.py b/torch_sim/models/electrostatics.py index 85fb9fd65..2ea9863b9 100644 --- a/torch_sim/models/electrostatics.py +++ b/torch_sim/models/electrostatics.py @@ -17,6 +17,7 @@ particle_mesh_ewald, ) +from torch_sim import transforms from torch_sim._duecredit import dcite from torch_sim.models.interface import ModelInterface from torch_sim.neighbors import torchsim_nl @@ -47,32 +48,6 @@ def _zero_result( return results -def _build_csr( - state: SimState, - cutoff: float, - neighbor_list_fn: Callable, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Build a CSR neighbor list and integer unit-shift tensor.""" - edge_index, _mapping, unit_shifts = neighbor_list_fn( - state.positions, - state.row_vector_cell, - state.pbc, - cutoff, - state.system_idx, - ) - n_atoms = state.positions.shape[0] - dev = state.positions.device - neighbor_ptr = torch.zeros(n_atoms + 1, dtype=torch.int32, device=dev) - neighbor_ptr[1:] = ( - torch.bincount(edge_index[0], minlength=n_atoms).cumsum(0).to(torch.int32) - ) - return ( - edge_index.to(torch.int32), - neighbor_ptr, - unit_shifts.to(torch.int32), - ) - - class DSFCoulombModel(ModelInterface): """Damped Shifted Force electrostatics as a :class:`ModelInterface`. @@ -133,8 +108,18 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] raise ValueError("Partial charges are required for DSF Coulomb summation.") charges = state.partial_charges - edge_index, neighbor_ptr, unit_shifts = _build_csr( - state, self.cutoff, self.neighbor_list_fn + edge_index, _mapping, unit_shifts = self.neighbor_list_fn( + state.positions, + state.row_vector_cell, + state.pbc, + self.cutoff, + state.system_idx, + ) + edge_index, neighbor_ptr, unit_shifts = transforms.build_csr_neighbor_list( + edge_index, + _mapping, + unit_shifts, + state.positions.shape[0], ) cell = state.row_vector_cell.contiguous() dsf_args: dict = dict( @@ -233,8 +218,18 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] state, self._dtype, self._compute_forces, self._compute_stress ) charges = state.partial_charges - edge_index, neighbor_ptr, unit_shifts = _build_csr( - state, self.cutoff, self.neighbor_list_fn + edge_index, _mapping, unit_shifts = self.neighbor_list_fn( + state.positions, + state.row_vector_cell, + state.pbc, + self.cutoff, + state.system_idx, + ) + edge_index, neighbor_ptr, unit_shifts = transforms.build_csr_neighbor_list( + edge_index, + _mapping, + unit_shifts, + state.positions.shape[0], ) cell = state.row_vector_cell.contiguous() out = ewald_summation( @@ -345,8 +340,18 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor] state, self._dtype, self._compute_forces, self._compute_stress ) charges = state.partial_charges - edge_index, neighbor_ptr, unit_shifts = _build_csr( - state, self.cutoff, self.neighbor_list_fn + edge_index, _mapping, unit_shifts = self.neighbor_list_fn( + state.positions, + state.row_vector_cell, + state.pbc, + self.cutoff, + state.system_idx, + ) + edge_index, neighbor_ptr, unit_shifts = transforms.build_csr_neighbor_list( + edge_index, + _mapping, + unit_shifts, + state.positions.shape[0], ) cell = state.row_vector_cell.contiguous() batch_idx = state.system_idx.to(torch.int32) if state.n_systems > 1 else None diff --git a/torch_sim/models/pair_potential.py b/torch_sim/models/pair_potential.py index bfc4d474d..fe3ea9e41 100644 --- a/torch_sim/models/pair_potential.py +++ b/torch_sim/models/pair_potential.py @@ -130,7 +130,7 @@ def full_to_half_list( def _prepare_pairs( state: SimState, *, - cutoff: torch.Tensor, + cutoff: float, neighbor_list_fn: Callable, reduce_to_half_list: bool, device: torch.device, @@ -344,7 +344,7 @@ def __init__( self.per_atom_stresses = per_atom_stresses self.pair_fn = pair_fn self.neighbor_list_fn = neighbor_list_fn - self.cutoff = torch.tensor(cutoff, dtype=dtype, device=self._device) + self.cutoff = cutoff self.reduce_to_half_list = reduce_to_half_list self.retain_graph = retain_graph @@ -529,7 +529,7 @@ def __init__( self.per_atom_stresses = per_atom_stresses self.force_fn = force_fn self.neighbor_list_fn = neighbor_list_fn - self.cutoff = torch.tensor(cutoff, dtype=dtype, device=self._device) + self.cutoff = cutoff self.reduce_to_half_list = reduce_to_half_list def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]: diff --git a/torch_sim/neighbors/__init__.py b/torch_sim/neighbors/__init__.py index 5d08c589c..f47f4e315 100644 --- a/torch_sim/neighbors/__init__.py +++ b/torch_sim/neighbors/__init__.py @@ -47,7 +47,7 @@ def torchsim_nl( positions: torch.Tensor, cell: torch.Tensor, pbc: torch.Tensor, - cutoff: torch.Tensor, + cutoff: float, system_idx: torch.Tensor, self_interaction: bool = False, # noqa: FBT001, FBT002 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -63,7 +63,7 @@ def torchsim_nl( positions: Atomic positions tensor [n_atoms, 3] cell: Unit cell vectors [n_systems, 3, 3] or [3, 3] pbc: Boolean tensor [n_systems, 3] or [3] - cutoff: Maximum distance (scalar tensor) for considering atoms as neighbors + cutoff: Maximum distance for considering atoms as neighbors system_idx: Tensor [n_atoms] indicating which system each atom belongs to self_interaction: If True, include self-pairs. Default: False diff --git a/torch_sim/neighbors/alchemiops.py b/torch_sim/neighbors/alchemiops.py index f9759c036..565a1173e 100644 --- a/torch_sim/neighbors/alchemiops.py +++ b/torch_sim/neighbors/alchemiops.py @@ -49,7 +49,7 @@ def alchemiops_nl_n2( positions: torch.Tensor, cell: torch.Tensor, pbc: torch.Tensor, - cutoff: torch.Tensor, + cutoff: float, system_idx: torch.Tensor, self_interaction: bool = False, # noqa: FBT001, FBT002 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -59,14 +59,13 @@ def alchemiops_nl_n2( positions: Atomic positions tensor [n_atoms, 3] cell: Unit cell vectors [n_systems, 3, 3] or [3, 3] pbc: Boolean tensor [n_systems, 3] or [3] - cutoff: Maximum distance (scalar tensor) + cutoff: Maximum distance system_idx: Tensor [n_atoms] indicating system assignment self_interaction: If True, include self-pairs Returns: (mapping, system_mapping, shifts_idx) """ - r_max = cutoff.item() if isinstance(cutoff, torch.Tensor) else cutoff n_systems = int(system_idx.max().item()) + 1 cell, pbc = normalize_inputs(cell, pbc, n_systems) @@ -74,7 +73,7 @@ def alchemiops_nl_n2( raise RuntimeError("nvalchemiops neighbor list is unavailable") res = _batch_naive_neighbor_list( positions=positions, - cutoff=r_max, + cutoff=cutoff, batch_idx=system_idx.to(torch.int32), cell=cell, pbc=pbc.to(torch.bool), @@ -120,7 +119,7 @@ def alchemiops_nl_cell_list( positions: torch.Tensor, cell: torch.Tensor, pbc: torch.Tensor, - cutoff: torch.Tensor, + cutoff: float, system_idx: torch.Tensor, self_interaction: bool = False, # noqa: FBT001, FBT002 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -130,14 +129,13 @@ def alchemiops_nl_cell_list( positions: Atomic positions tensor [n_atoms, 3] cell: Unit cell vectors [n_systems, 3, 3] or [3, 3] pbc: Boolean tensor [n_systems, 3] or [3] - cutoff: Maximum distance (scalar tensor) + cutoff: Maximum distance system_idx: Tensor [n_atoms] indicating system assignment self_interaction: If True, include self-pairs Returns: (mapping, system_mapping, shifts_idx) """ - r_max = cutoff.item() if isinstance(cutoff, torch.Tensor) else cutoff n_systems = int(system_idx.max().item()) + 1 cell, pbc = normalize_inputs(cell, pbc, n_systems) @@ -157,7 +155,7 @@ def alchemiops_nl_cell_list( raise RuntimeError("nvalchemiops cell list is unavailable") res = _batch_cell_list( positions=positions, - cutoff=r_max, + cutoff=cutoff, batch_idx=system_idx.to(torch.int32), cell=cell, pbc=pbc.to(torch.bool), diff --git a/torch_sim/neighbors/torch_nl.py b/torch_sim/neighbors/torch_nl.py index 50bf94c79..d2856f081 100644 --- a/torch_sim/neighbors/torch_nl.py +++ b/torch_sim/neighbors/torch_nl.py @@ -95,7 +95,7 @@ def torch_nl_n2( positions: torch.Tensor, cell: torch.Tensor, pbc: torch.Tensor, - cutoff: torch.Tensor, + cutoff: float, system_idx: torch.Tensor, self_interaction: bool = False, # noqa: FBT001, FBT002 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -114,7 +114,7 @@ def torch_nl_n2( cell (torch.Tensor [n_systems, 3, 3]): Unit cell vectors. pbc (torch.Tensor [n_systems, 3] bool): A tensor indicating the periodic boundary conditions to apply. - cutoff (torch.Tensor): + cutoff (float): The cutoff radius used for the neighbor search. system_idx (torch.Tensor [n_atom,] torch.long): A tensor containing the index of the structure to which each atom belongs. @@ -144,10 +144,10 @@ def torch_nl_n2( n_atoms = torch.bincount(system_idx) mapping, system_mapping, shifts_idx = transforms.build_naive_neighborhood( - wrapped, cell, pbc, cutoff.item(), n_atoms, self_interaction + wrapped, cell, pbc, cutoff, n_atoms, self_interaction ) mapping, mapping_system, shifts_idx = strict_nl( - cutoff.item(), wrapped, cell, mapping, system_mapping, shifts_idx + cutoff, wrapped, cell, mapping, system_mapping, shifts_idx ) shifts_idx = shifts_idx + wrap_shifts[mapping[0]] - wrap_shifts[mapping[1]] return mapping, mapping_system, shifts_idx @@ -157,7 +157,7 @@ def torch_nl_linked_cell( positions: torch.Tensor, cell: torch.Tensor, pbc: torch.Tensor, - cutoff: torch.Tensor, + cutoff: float, system_idx: torch.Tensor, self_interaction: bool = False, # noqa: FBT001, FBT002 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -177,7 +177,7 @@ def torch_nl_linked_cell( cell (torch.Tensor [n_systems, 3, 3]): Unit cell vectors. pbc (torch.Tensor [n_systems, 3] bool): A tensor indicating the periodic boundary conditions to apply. - cutoff (torch.Tensor): + cutoff (float): The cutoff radius used for the neighbor search. system_idx (torch.Tensor [n_atom,] torch.long): A tensor containing the index of the structure to which each atom belongs. @@ -206,10 +206,10 @@ def torch_nl_linked_cell( n_atoms = torch.bincount(system_idx) mapping, system_mapping, shifts_idx = transforms.build_linked_cell_neighborhood( - wrapped, cell, pbc, cutoff.item(), n_atoms, self_interaction + wrapped, cell, pbc, cutoff, n_atoms, self_interaction ) mapping, mapping_system, shifts_idx = strict_nl( - cutoff.item(), wrapped, cell, mapping, system_mapping, shifts_idx + cutoff, wrapped, cell, mapping, system_mapping, shifts_idx ) shifts_idx = shifts_idx + wrap_shifts[mapping[0]] - wrap_shifts[mapping[1]] return mapping, mapping_system, shifts_idx diff --git a/torch_sim/neighbors/vesin.py b/torch_sim/neighbors/vesin.py index c4165d086..733a25b35 100644 --- a/torch_sim/neighbors/vesin.py +++ b/torch_sim/neighbors/vesin.py @@ -8,6 +8,7 @@ import torch +from torch_sim import transforms from torch_sim.neighbors.utils import normalize_inputs @@ -25,13 +26,14 @@ VESIN_AVAILABLE = VesinNeighborList is not None VESIN_TORCHSCRIPT_AVAILABLE = VesinNeighborListTorch is not None + if VESIN_AVAILABLE: def vesin_nl( positions: torch.Tensor, cell: torch.Tensor, pbc: torch.Tensor, - cutoff: float | torch.Tensor, + cutoff: float, system_idx: torch.Tensor, self_interaction: bool = False, # noqa: FBT001, FBT002 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -80,6 +82,9 @@ def vesin_nl( dtype = positions.dtype n_systems = int(system_idx.max().item()) + 1 cell, pbc = normalize_inputs(cell, pbc, n_systems) + wrapped, wrap_shifts = transforms.pbc_wrap_batched_and_get_lattice_shifts( + positions, cell, system_idx, pbc + ) # Process each system's neighbor list separately edge_indices = [] @@ -103,15 +108,16 @@ def vesin_nl( ) # Convert tensors to CPU and float64 without gradients - positions_cpu = positions[system_mask].detach().cpu().to(dtype=torch.float64) + positions_cpu = wrapped[system_mask].detach().cpu().to(dtype=torch.float64) cell_cpu = cell_sys.detach().cpu().to(dtype=torch.float64) periodic_cpu = pbc[sys_idx].detach().to(dtype=torch.bool).cpu() + periodic_bool = bool(torch.all(periodic_cpu).item()) # Only works on CPU and returns numpy arrays i, j, S = neighbor_list_fn.compute( points=positions_cpu, box=cell_cpu, - periodic=periodic_cpu, + periodic=periodic_bool, quantities="ijS", ) i, j = ( @@ -123,6 +129,9 @@ def vesin_nl( # Adjust indices for the global atom indexing edge_idx = edge_idx + offset + shifts = shifts + (wrap_shifts[edge_idx[0]] - wrap_shifts[edge_idx[1]]).to( + dtype=dtype + ) edge_indices.append(edge_idx) shifts_idx_list.append(shifts) @@ -173,7 +182,7 @@ def vesin_nl_ts( positions: torch.Tensor, cell: torch.Tensor, pbc: torch.Tensor, - cutoff: torch.Tensor, + cutoff: float, system_idx: torch.Tensor, self_interaction: bool = False, # noqa: FBT001, FBT002 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -186,7 +195,7 @@ def vesin_nl_ts( positions: Atomic positions tensor [n_atoms, 3] cell: Unit cell vectors [n_systems, 3, 3] or [3, 3] pbc: Boolean tensor [n_systems, 3] or [3] - cutoff: Maximum distance (scalar tensor) for considering atoms as neighbors + cutoff: Maximum distance for considering atoms as neighbors system_idx: Tensor [n_atoms] indicating which system each atom belongs to self_interaction: If True, include self-pairs. Default: False @@ -220,6 +229,9 @@ def vesin_nl_ts( dtype = positions.dtype n_systems = int(system_idx.max().item()) + 1 cell, pbc = normalize_inputs(cell, pbc, n_systems) + wrapped, wrap_shifts = transforms.pbc_wrap_batched_and_get_lattice_shifts( + positions, cell, system_idx, pbc + ) # Process each system's neighbor list separately edge_indices = [] @@ -235,13 +247,13 @@ def vesin_nl_ts( continue # Calculate neighbor list for this system - neighbor_list_fn = VesinNeighborListTorch(cutoff.item(), full_list=True) + neighbor_list_fn = VesinNeighborListTorch(cutoff, full_list=True) # Get the cell for this system cell_sys = cell[sys_idx] # Convert tensors to CPU and float64 properly - positions_cpu = positions[system_mask].cpu().to(dtype=torch.float64) + positions_cpu = wrapped[system_mask].cpu().to(dtype=torch.float64) cell_cpu = cell_sys.cpu().to(dtype=torch.float64) periodic_cpu = pbc[sys_idx].to(dtype=torch.bool).cpu() @@ -258,6 +270,9 @@ def vesin_nl_ts( # Adjust indices for the global atom indexing edge_idx = edge_idx + offset + shifts = shifts + (wrap_shifts[edge_idx[0]] - wrap_shifts[edge_idx[1]]).to( + dtype=dtype + ) edge_indices.append(edge_idx) shifts_idx_list.append(shifts) diff --git a/torch_sim/transforms.py b/torch_sim/transforms.py index d8221571b..f4f51189c 100644 --- a/torch_sim/transforms.py +++ b/torch_sim/transforms.py @@ -1370,6 +1370,36 @@ def build_linked_cell_neighborhood( ) +def sort_neighbors_for_csr( + mapping: torch.Tensor, + system_mapping: torch.Tensor, + shifts_idx: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Group neighbor-list entries by central atom for CSR kernels.""" + if mapping.shape[1] == 0: + return mapping, system_mapping, shifts_idx + original_order = torch.arange(mapping.shape[1], device=mapping.device) + order = torch.argsort(mapping[0] * mapping.shape[1] + original_order) + return mapping[:, order], system_mapping[order], shifts_idx[order] + + +def build_csr_neighbor_list( + mapping: torch.Tensor, + system_mapping: torch.Tensor, + shifts_idx: torch.Tensor, + n_atoms: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Build CSR inputs for kernels that require neighbors grouped by center atom.""" + mapping, _system_mapping, shifts_idx = sort_neighbors_for_csr( + mapping, system_mapping, shifts_idx + ) + neighbor_ptr = torch.zeros(n_atoms + 1, dtype=torch.int32, device=mapping.device) + neighbor_ptr[1:] = ( + torch.bincount(mapping[0], minlength=n_atoms).cumsum(0).to(torch.int32) + ) + return mapping.to(torch.int32), neighbor_ptr, shifts_idx.to(torch.int32) + + def multiplicative_isotropic_cutoff( fn: Callable[..., torch.Tensor], r_onset: float | torch.Tensor,