Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/benchmarking/neighborlists.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def _benchmark_backend(
positions: torch.Tensor,
cell: torch.Tensor,
pbc: torch.Tensor,
cutoff: torch.Tensor,
cutoff: float,
system_idx: torch.Tensor,
n_repeats: int,
device: str,
Expand Down Expand Up @@ -291,7 +291,7 @@ def run_benchmark(args: argparse.Namespace) -> dict[str, Any]:
positions, cell, pbc, system_idx = _build_tensors(
structures, dtype=torch_dtype, device=args.device
)
cutoff = torch.tensor(args.cutoff, dtype=torch_dtype, device=args.device)
cutoff = float(args.cutoff)
n_atoms = positions.shape[0]

backends = args.nl_backend if isinstance(args.nl_backend, list) else [args.nl_backend]
Expand Down
2 changes: 1 addition & 1 deletion examples/scripts/7_others.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
pos, cell, pbc = state.positions, state.cell, state.pbc
system_idx = state.system_idx
n_atoms = state.n_atoms
cutoff = torch.tensor(4.0, dtype=pos.dtype)
cutoff = 4.0
self_interaction = False

# Ensure pbc has the correct shape [n_systems, 3]
Expand Down
4 changes: 2 additions & 2 deletions tests/models/test_soft_sphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def test_multispecies_initialization_custom() -> None:
assert torch.allclose(model.sigma_matrix, sigma_matrix)
assert torch.allclose(model.epsilon_matrix, epsilon_matrix)
assert torch.allclose(model.alpha_matrix, alpha_matrix)
assert model.cutoff.item() == 3.0
assert model.cutoff == 3.0


def test_multispecies_matrix_validation() -> None:
Expand Down Expand Up @@ -196,7 +196,7 @@ def test_multispecies_cutoff_default() -> None:
model = SoftSphereMultiModel(
atomic_numbers=torch.tensor([0, 1, 2]), sigma_matrix=sigma_matrix
)
assert model.cutoff.item() == 3.0
assert model.cutoff == 3.0


def test_multispecies_evaluation() -> None:
Expand Down
42 changes: 27 additions & 15 deletions tests/test_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,17 +257,16 @@ def test_neighbor_list_invariant_under_lattice_image_shifts(
assert not torch.allclose(pos_shifted, pos_wrapped, rtol=0.0, atol=1e-12), (
"expected non-trivial lattice shifts along periodic axes"
)
c_tensor = torch.tensor(cutoff, dtype=DTYPE, device=DEVICE)
map_w, sys_w, sh_w = nl_implementation(
cutoff=c_tensor,
cutoff=cutoff,
positions=pos_wrapped,
cell=cell_b,
pbc=pbc_b,
system_idx=batch,
self_interaction=self_interaction,
)
map_s, sys_s, sh_s = nl_implementation(
cutoff=c_tensor,
cutoff=cutoff,
positions=pos_shifted,
cell=cell_b,
pbc=pbc_b,
Expand Down Expand Up @@ -311,7 +310,7 @@ def test_neighbor_list_implementations(
atoms_list, device=DEVICE, dtype=DTYPE
)
mapping, mapping_system, shifts_idx = nl_implementation(
cutoff=torch.tensor(cutoff, dtype=DTYPE, device=DEVICE),
cutoff=cutoff,
positions=pos,
cell=row_vector_cell,
pbc=pbc,
Expand Down Expand Up @@ -377,7 +376,7 @@ def test_nl_pbc_edge_cases(
"""
pos = torch.tensor([[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]], device=DEVICE, dtype=DTYPE)
cell = torch.eye(3, device=DEVICE, dtype=DTYPE) * 2.0
cutoff = torch.tensor(1.5, device=DEVICE, dtype=DTYPE)
cutoff = 1.5
pbc = torch.tensor([pbc_val, pbc_val, pbc_val], device=DEVICE)
system_idx = torch.zeros(2, dtype=torch.long, device=DEVICE)

Expand All @@ -400,7 +399,7 @@ def test_vesin_nl_float32() -> None:
[[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]], device=DEVICE, dtype=torch.float32
)
cell = torch.eye(3, device=DEVICE, dtype=torch.float32) * 2.0
cutoff = torch.tensor(1.5, device=DEVICE, dtype=torch.float32)
cutoff = 1.5
pbc = torch.tensor([True, True, True], device=DEVICE)
system_idx = torch.zeros(2, dtype=torch.long, device=DEVICE)

Expand All @@ -412,12 +411,12 @@ def test_vesin_nl_float32() -> None:

def _minimal_neighbor_list_inputs(
device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, float, torch.Tensor]:
"""Create minimal valid tensor inputs for neighbor-list API smoke checks."""
positions = torch.zeros((1, 3), dtype=torch.float32, device=device)
cell = torch.eye(3, dtype=torch.float32, device=device)
pbc = torch.tensor([False, False, False], dtype=torch.bool, device=device)
cutoff = torch.tensor(1.0, dtype=torch.float32, device=device)
cutoff = 1.0
system_idx = torch.zeros(1, dtype=torch.long, device=device)
return positions, cell, pbc, cutoff, system_idx

Expand Down Expand Up @@ -480,7 +479,7 @@ def test_fallback_when_alchemiops_unavailable(monkeypatch: pytest.MonkeyPatch) -
)
cell = torch.eye(3, device=device, dtype=dtype) * 3.0
pbc = torch.tensor([False, False, False], device=device)
cutoff = torch.tensor(1.5, device=device, dtype=dtype)
cutoff = 1.5
system_idx = torch.zeros(4, dtype=torch.long, device=device)

# Use monkeypatch to temporarily disable alchemiops
Expand Down Expand Up @@ -520,7 +519,7 @@ def test_torchsim_nl_gpu() -> None:
)
cell = torch.eye(3, device=device, dtype=dtype) * 3.0
pbc = torch.tensor([True, True, True], device=device)
cutoff = torch.tensor(1.5, device=device, dtype=dtype)
cutoff = 1.5
system_idx = torch.zeros(2, dtype=torch.long, device=device)

# Should work on GPU regardless of implementation availability
Expand Down Expand Up @@ -557,7 +556,7 @@ def test_torchsim_nl_fallback_when_vesin_unavailable(
)
cell = torch.eye(3, device=device, dtype=dtype) * 3.0
pbc = torch.tensor([False, False, False], device=device)
cutoff = torch.tensor(1.5, device=device, dtype=dtype)
cutoff = 1.5
system_idx = torch.zeros(4, dtype=torch.long, device=device)

# Monkeypatch both availability flags to False
Expand All @@ -582,18 +581,31 @@ def test_torchsim_nl_fallback_when_vesin_unavailable(


def _no_neighbor_inputs() -> tuple[
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
torch.Tensor, torch.Tensor, torch.Tensor, float, torch.Tensor
]:
"""Build a simple no-neighbor system."""
"""Two atoms far apart in a large non-periodic box; cutoff too small for any pair."""
positions = torch.tensor(
[[0.0, 0.0, 0.0], [10.0, 10.0, 10.0]],
device=DEVICE,
dtype=DTYPE,
)
cell = torch.eye(3, device=DEVICE, dtype=DTYPE) * 20.0
pbc = torch.tensor([False, False, False], device=DEVICE)
cutoff = torch.tensor(1.0, device=DEVICE, dtype=DTYPE)
return positions, cell, pbc, cutoff
cutoff = 1.0
system_idx = torch.zeros(2, dtype=torch.long, device=DEVICE)
return positions, cell, pbc, cutoff, system_idx


@pytest.mark.parametrize("nl_implementation", _all_nl_backends())
def test_neighbor_list_empty_when_all_pairs_beyond_cutoff(
nl_implementation: Callable[..., tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
) -> None:
"""Every backend returns a valid empty neighbor list when no pair is within cutoff."""
positions, cell, pbc, cutoff, system_idx = _no_neighbor_inputs()
mapping, sys_map, shifts = nl_implementation(positions, cell, pbc, cutoff, system_idx)
assert mapping.shape == (2, 0)
assert sys_map.numel() == 0
assert shifts.shape == (0, 3)


def test_strict_nl_edge_cases() -> None:
Expand Down
3 changes: 1 addition & 2 deletions torch_sim/autobatching.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,12 +270,11 @@ def determine_max_batch_size(

def _n_edges_scalers(state: SimState, cutoff: float) -> list[float]:
"""Return per-system edge counts from the neighbor list as memory scalers."""
cutoff_tensor = torch.tensor(cutoff, dtype=state.dtype, device=state.device)
_, system_mapping, _ = torchsim_nl(
positions=state.positions,
cell=state.cell,
pbc=state.pbc,
cutoff=cutoff_tensor,
cutoff=cutoff,
system_idx=state.system_idx,
)
return system_mapping.bincount(minlength=state.n_systems).float().tolist()
Expand Down
16 changes: 8 additions & 8 deletions torch_sim/models/dispersion.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import torch

from torch_sim import transforms
from torch_sim._duecredit import dcite
from torch_sim.models.interface import ModelInterface
from torch_sim.neighbors import torchsim_nl
Expand Down Expand Up @@ -141,18 +142,17 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
self.cutoff,
state.system_idx,
)
n_atoms = state.positions.shape[0]
neighbor_ptr = torch.zeros(
n_atoms + 1, dtype=torch.int32, device=state.positions.device
)
neighbor_ptr[1:] = (
torch.bincount(edge_index[0], minlength=n_atoms).cumsum(0).to(torch.int32)
edge_index_int, neighbor_ptr, unit_shifts_int = (
transforms.build_csr_neighbor_list(
edge_index,
_mapping_system,
unit_shifts,
state.positions.shape[0],
)
)
positions_bohr = state.positions * UnitConversion.Ang_to_Bohr
cell_bohr = state.row_vector_cell.contiguous() * UnitConversion.Ang_to_Bohr
numbers = state.atomic_numbers.to(torch.int32)
unit_shifts_int = unit_shifts.to(torch.int32)
edge_index_int = edge_index.to(torch.int32)
d3_out = nvalchemiops_dftd3(
positions=positions_bohr,
numbers=numbers,
Expand Down
69 changes: 37 additions & 32 deletions torch_sim/models/electrostatics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
particle_mesh_ewald,
)

from torch_sim import transforms
from torch_sim._duecredit import dcite
from torch_sim.models.interface import ModelInterface
from torch_sim.neighbors import torchsim_nl
Expand Down Expand Up @@ -47,32 +48,6 @@ def _zero_result(
return results


def _build_csr(
state: SimState,
cutoff: float,
neighbor_list_fn: Callable,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Build a CSR neighbor list and integer unit-shift tensor."""
edge_index, _mapping, unit_shifts = neighbor_list_fn(
state.positions,
state.row_vector_cell,
state.pbc,
cutoff,
state.system_idx,
)
n_atoms = state.positions.shape[0]
dev = state.positions.device
neighbor_ptr = torch.zeros(n_atoms + 1, dtype=torch.int32, device=dev)
neighbor_ptr[1:] = (
torch.bincount(edge_index[0], minlength=n_atoms).cumsum(0).to(torch.int32)
)
return (
edge_index.to(torch.int32),
neighbor_ptr,
unit_shifts.to(torch.int32),
)


class DSFCoulombModel(ModelInterface):
"""Damped Shifted Force electrostatics as a :class:`ModelInterface`.

Expand Down Expand Up @@ -133,8 +108,18 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
raise ValueError("Partial charges are required for DSF Coulomb summation.")

charges = state.partial_charges
edge_index, neighbor_ptr, unit_shifts = _build_csr(
state, self.cutoff, self.neighbor_list_fn
edge_index, _mapping, unit_shifts = self.neighbor_list_fn(
state.positions,
state.row_vector_cell,
state.pbc,
self.cutoff,
state.system_idx,
)
edge_index, neighbor_ptr, unit_shifts = transforms.build_csr_neighbor_list(
edge_index,
_mapping,
unit_shifts,
state.positions.shape[0],
)
cell = state.row_vector_cell.contiguous()
dsf_args: dict = dict(
Expand Down Expand Up @@ -233,8 +218,18 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
state, self._dtype, self._compute_forces, self._compute_stress
)
charges = state.partial_charges
edge_index, neighbor_ptr, unit_shifts = _build_csr(
state, self.cutoff, self.neighbor_list_fn
edge_index, _mapping, unit_shifts = self.neighbor_list_fn(
state.positions,
state.row_vector_cell,
state.pbc,
self.cutoff,
state.system_idx,
)
edge_index, neighbor_ptr, unit_shifts = transforms.build_csr_neighbor_list(
edge_index,
_mapping,
unit_shifts,
state.positions.shape[0],
)
cell = state.row_vector_cell.contiguous()
out = ewald_summation(
Expand Down Expand Up @@ -345,8 +340,18 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
state, self._dtype, self._compute_forces, self._compute_stress
)
charges = state.partial_charges
edge_index, neighbor_ptr, unit_shifts = _build_csr(
state, self.cutoff, self.neighbor_list_fn
edge_index, _mapping, unit_shifts = self.neighbor_list_fn(
state.positions,
state.row_vector_cell,
state.pbc,
self.cutoff,
state.system_idx,
)
edge_index, neighbor_ptr, unit_shifts = transforms.build_csr_neighbor_list(
edge_index,
_mapping,
unit_shifts,
state.positions.shape[0],
)
cell = state.row_vector_cell.contiguous()
batch_idx = state.system_idx.to(torch.int32) if state.n_systems > 1 else None
Expand Down
6 changes: 3 additions & 3 deletions torch_sim/models/pair_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def full_to_half_list(
def _prepare_pairs(
state: SimState,
*,
cutoff: torch.Tensor,
cutoff: float,
neighbor_list_fn: Callable,
reduce_to_half_list: bool,
device: torch.device,
Expand Down Expand Up @@ -344,7 +344,7 @@ def __init__(
self.per_atom_stresses = per_atom_stresses
self.pair_fn = pair_fn
self.neighbor_list_fn = neighbor_list_fn
self.cutoff = torch.tensor(cutoff, dtype=dtype, device=self._device)
self.cutoff = cutoff
self.reduce_to_half_list = reduce_to_half_list
self.retain_graph = retain_graph

Expand Down Expand Up @@ -529,7 +529,7 @@ def __init__(
self.per_atom_stresses = per_atom_stresses
self.force_fn = force_fn
self.neighbor_list_fn = neighbor_list_fn
self.cutoff = torch.tensor(cutoff, dtype=dtype, device=self._device)
self.cutoff = cutoff
self.reduce_to_half_list = reduce_to_half_list

def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]:
Expand Down
4 changes: 2 additions & 2 deletions torch_sim/neighbors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def torchsim_nl(
positions: torch.Tensor,
cell: torch.Tensor,
pbc: torch.Tensor,
cutoff: torch.Tensor,
cutoff: float,
system_idx: torch.Tensor,
self_interaction: bool = False, # noqa: FBT001, FBT002
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
Expand All @@ -63,7 +63,7 @@ def torchsim_nl(
positions: Atomic positions tensor [n_atoms, 3]
cell: Unit cell vectors [n_systems, 3, 3] or [3, 3]
pbc: Boolean tensor [n_systems, 3] or [3]
cutoff: Maximum distance (scalar tensor) for considering atoms as neighbors
cutoff: Maximum distance for considering atoms as neighbors
system_idx: Tensor [n_atoms] indicating which system each atom belongs to
self_interaction: If True, include self-pairs. Default: False

Expand Down
Loading
Loading