Skip to content

Commit 5ce90e5

Browse files
v0.6.0 Release Prep (#185)
* Submodule to() tests. * Updated changelog. * Minor fixes to changelog. * Removed version mismatch error. * Removed some extraneous comments.
1 parent 772781e commit 5ce90e5

9 files changed

Lines changed: 302 additions & 44 deletions

File tree

CHANGELOG.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,30 @@
11
## Latest Changes
22

3+
### v0.6.0 (2025-02-23)
4+
OpenEquivariance v0.6.0 brings long-needed improvements to the
5+
PyTorch frontend. We strongly encourage all users to upgrade
6+
to PyTorch 2.10 and OEQ v0.6.0.
7+
8+
**Added**:
9+
- OpenEquivariance triggers a build of the CUDA extension module
10+
at `pip` install time and will use this precompiled extension if
11+
the user has PyTorch >=2.10 installed. If PyTorch <2.10 is installed,
12+
the JIT-compiled extension is used instead.
13+
- PyTorch ABI support for C++ backend, using new features in PyTorch
14+
2.10 to support stable, forward-compatible ahead-of-time
15+
extensions.
16+
- Replaced TorchBind classes with a new kernel cache, which greatly improves
flexibility for automatic mixed precision and AOTI compilation. An inference
test in C++ is included.
19+
- `openequivariance_extjax` now carries a version number synchronized with the
main `openequivariance` package; keep the two packages at matching versions.
21+
22+
**Fixed**:
23+
- `torch.to()` is now called when either `TensorProduct`
24+
or `TensorProductConv` is a submodule of another PyTorch
25+
module.
26+
27+
328
### v0.5.4 (2025-02-01)
429
Improvements to JAX frontend.
530

openequivariance/openequivariance/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import numpy as np
55

66
from pathlib import Path
7+
import warnings
78
from importlib.metadata import version
89

910
from openequivariance.core.e3nn_lite import (
@@ -80,6 +81,15 @@ def torch_ext_so_path():
8081
try:
8182
import openequivariance_extjax
8283
import openequivariance.jax as jax
84+
85+
# TODO-someday: enable
86+
# extjax_version = version("openequivariance_extjax")
87+
# if extjax_version != __version__:
88+
# warnings.warn(
89+
# f"openequivariance_extjax version {extjax_version} does not match "
90+
# f"openequivariance version {__version__}. Ensure both versions match."
91+
# )
92+
8393
except Exception as e:
8494
error = e
8595

openequivariance/openequivariance/_torch/TensorProduct.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,13 @@
22
from openequivariance import TPProblem
33
from openequivariance._torch import extlib
44
import torch
5-
from openequivariance.core.utils import torch_to_oeq_dtype
5+
from openequivariance.core.utils import torch_to_oeq_dtype, dtype_to_enum
66
from openequivariance.benchmark.logging_utils import getLogger
7-
from openequivariance._torch.utils import reorder_torch, string_to_tensor
7+
from openequivariance._torch.utils import (
8+
reorder_torch,
9+
string_to_tensor,
10+
enum_to_torch_dtype,
11+
)
812
from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin
913

1014
import numpy as np
@@ -66,6 +70,27 @@ def to(self, *args, **kwargs):
6670
torch.nn.Module.to(self, *args, **kwargs)
6771
return self
6872

73+
def _apply(self, fn, recurse=True):
    """Intercept ``torch.nn.Module._apply`` so dtype-changing conversions
    (e.g. a parent module's ``.to(torch.float64)``) rebuild this module's
    kernels before the standard parameter/buffer traversal runs.

    The target dtype is discovered by probing ``fn`` with a throwaway
    scalar tensor; if it differs from the problem's current irrep dtype,
    ``self.to`` is invoked to reconfigure the module.
    """
    # Re-entrant call triggered by our own self.to() below — defer
    # straight to the default traversal.
    if getattr(self, "_applying", False):
        return super()._apply(fn, recurse)

    problem: TPProblem = self.input_args["problem"]
    # Normalize the stored dtype to its enum form when a mapping exists.
    dtype_key = problem.irrep_dtype
    dtype_key = dtype_to_enum.get(dtype_key, dtype_key)
    active_dtype = enum_to_torch_dtype[dtype_key]

    # Probe fn with a dummy scalar to learn the dtype it produces.
    probe = fn(torch.tensor(0.0, dtype=active_dtype))

    if probe.dtype != active_dtype:
        # Guard flag prevents infinite recursion while self.to() runs
        # (it re-enters _apply via torch.nn.Module.to).
        self._applying = True
        self.to(probe.dtype)
        self._applying = False

    return super()._apply(fn, recurse)
93+
6994
def __getstate__(self):
    """Pickle support: serialize only the constructor arguments, from
    which the module can be fully rebuilt on unpickling."""
    state = self.input_args
    return state
7196

openequivariance/openequivariance/_torch/TensorProductConv.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@
1616
from openequivariance.core.LoopUnrollConv import LoopUnrollConv
1717
from openequivariance._torch.TensorProduct import TensorProduct
1818
from openequivariance import TPProblem
19-
from openequivariance.core.utils import torch_to_oeq_dtype
19+
from openequivariance.core.utils import torch_to_oeq_dtype, dtype_to_enum
2020
from openequivariance._torch.utils import (
2121
reorder_torch,
2222
string_to_tensor,
23+
enum_to_torch_dtype,
2324
)
2425

2526
from openequivariance.benchmark.logging_utils import getLogger
@@ -109,6 +110,27 @@ def to(self, *args, **kwargs):
109110
torch.nn.Module.to(self, *args, **kwargs)
110111
return self
111112

113+
def _apply(self, fn, recurse=True):
    """Hook into ``torch.nn.Module._apply`` so that a dtype conversion
    requested on a parent module (e.g. ``parent.to(torch.float32)``)
    also reconfigures this conv module's compiled kernels.

    A throwaway scalar is passed through ``fn`` to detect the dtype the
    conversion targets; a mismatch with the problem's irrep dtype
    triggers ``self.to`` before the normal traversal proceeds.
    """
    # If we are already inside our own self.to() call below, fall
    # through to the stock implementation to avoid recursing forever.
    if getattr(self, "_applying", False):
        return super()._apply(fn, recurse)

    problem: TPProblem = self.input_args["problem"]
    dtype_key = problem.irrep_dtype
    if dtype_key in dtype_to_enum:
        # Map a raw dtype to its enum representation when applicable.
        dtype_key = dtype_to_enum[dtype_key]

    active_dtype = enum_to_torch_dtype[dtype_key]
    converted = fn(torch.tensor(0.0, dtype=active_dtype))

    if converted.dtype != active_dtype:
        self._applying = True  # suppress re-entrant interception
        self.to(converted.dtype)
        self._applying = False

    return super()._apply(fn, recurse)
133+
112134
def __getstate__(self):
    """Pickle support: the constructor-argument dict is the complete
    serializable state of this module."""
    return dict.copy(self.input_args) if False else self.input_args
114136

tests/batch_test.py

Lines changed: 87 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@
1919
import torch
2020

2121

22+
@pytest.fixture(params=[np.float32, np.float64], ids=["F32", "F64"], scope="module")
def dtype(request):
    """Module-scoped fixture parameterizing tests over both float widths."""
    requested = request.param
    return requested
25+
26+
2227
class TPCorrectness:
2328
def thresh(self, direction):
2429
return {"fwd": 1e-5, "bwd": 3e-4, "double_bwd": 3e-4}[direction]
@@ -31,18 +36,10 @@ def check_result(self, result, fieldname):
3136
f"{fieldname} observed error={error:.5f} >= {thresh}"
3237
)
3338

34-
@pytest.fixture(params=[np.float32, np.float64], ids=["F32", "F64"], scope="class")
35-
def dtype(self, request):
36-
return request.param
37-
3839
@pytest.fixture(scope="class")
3940
def extra_tp_constructor_args(self):
4041
return {}
4142

42-
@pytest.fixture(scope="class")
43-
def with_jax(self, request):
44-
return request.config.getoption("--jax")
45-
4643
@pytest.fixture(scope="class")
4744
def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax):
4845
cls = oeq.TensorProduct
@@ -274,3 +271,85 @@ def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax):
274271
}
275272
tp.to(switch_map[problem.irrep_dtype])
276273
return tp, tp.config
274+
275+
276+
class TestTorchToSubmodule:
    """Verify TensorProduct works correctly as a submodule when the parent
    module's .to() is invoked (dtype conversion must propagate)."""

    @pytest.fixture(scope="class")
    def parent_module_and_problem(self, dtype, with_jax):
        # This scenario only exists for the PyTorch frontend.
        if with_jax:
            pytest.skip("N/A for JAX")

        problem = mace_problems()[0].clone()
        problem.irrep_dtype, problem.weight_dtype = dtype, dtype

        class ParentModule(torch.nn.Module):
            def __init__(self, problem):
                super().__init__()
                self.tp = oeq.TensorProduct(problem)

            def forward(self, x, y, w):
                return self.tp(x, y, w)

        return ParentModule(problem), problem

    def _problem_dtype(self, problem):
        # Map the numpy irrep dtype to its torch equivalent.
        return torch.float32 if problem.irrep_dtype == np.float32 else torch.float64

    def _make_inputs(self, problem, batch_size, rng, dtype, device):
        # Draw a uniform random tensor of the given shape on the target device.
        def draw(shape):
            return torch.tensor(rng.uniform(size=shape), dtype=dtype, device=device)

        weight_shape = (
            (problem.weight_numel,)
            if problem.shared_weights
            else (batch_size, problem.weight_numel)
        )
        # Evaluation order matters for RNG reproducibility: in1, in2, weights.
        in1 = draw((batch_size, problem.irreps_in1.dim))
        in2 = draw((batch_size, problem.irreps_in2.dim))
        weights = draw(weight_shape)
        return in1, in2, weights

    def test_submodule_dtype_conversion(self, parent_module_and_problem):
        """Test that calling .to() on parent module properly converts TensorProduct submodule"""
        parent, problem = parent_module_and_problem

        rng = np.random.default_rng(12345)
        batch_size, device = 10, "cuda"

        start_dtype = self._problem_dtype(problem)
        in1, in2, weights = self._make_inputs(
            problem, batch_size, rng, start_dtype, device
        )

        output1 = parent(in1, in2, weights)
        assert output1.dtype == in1.dtype, (
            f"Expected output dtype {in1.dtype}, got {output1.dtype}"
        )

        # Flip to the opposite precision and confirm the submodule follows.
        target_dtype = {np.float32: torch.float64, np.float64: torch.float32}[
            problem.irrep_dtype
        ]
        parent.to(target_dtype)

        in1_new, in2_new, weights_new = self._make_inputs(
            problem, batch_size, rng, target_dtype, device
        )

        output2 = parent(in1_new, in2_new, weights_new)
        assert output2.dtype == target_dtype, (
            f"Expected output dtype {target_dtype}, got {output2.dtype}"
        )

tests/conftest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import pytest
23

34
os.environ["JAX_ENABLE_X64"] = "True"
45
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
@@ -12,3 +13,8 @@ def pytest_addoption(parser):
1213
default=False,
1314
help="Test the JAX frontend instead of PyTorch",
1415
)
16+
17+
18+
@pytest.fixture(scope="session")
def with_jax(request):
    """True when the session was started with --jax (test the JAX frontend)."""
    flag = request.config.getoption("--jax")
    return flag

0 commit comments

Comments
 (0)