Skip to content

Commit 26b746b

Browse files
committed
Almost there.
1 parent 45f236a commit 26b746b

1 file changed

Lines changed: 21 additions & 29 deletions

File tree

tests/batch_test.py

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
correctness_double_backward,
99
correctness_forward,
1010
)
11+
from openequivariance.benchmark.test_buffers import get_random_buffers_forward
1112
from openequivariance.benchmark.problems import (
1213
diffdock_problems,
1314
e3nn_torch_tetris_poly_problems,
@@ -274,9 +275,10 @@ def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax):
274275

275276

276277
class TestIrMul(TPCorrectness):
277-
'''
278-
Tests both the ir_mul layout and the transpose_irreps functions.
279-
'''
278+
"""
279+
Tests both the ir_mul layout and the transpose_irreps functions.
280+
"""
281+
280282
tpps = mace_problems() + [
281283
oeq.TPProblem(
282284
"5x5e",
@@ -323,6 +325,7 @@ def tp_and_problem(self, request, problem, extra_tp_constructor_args, with_jax):
323325
tp = tp_base_cls(problem, **extra_tp_constructor_args)
324326
return tp, problem
325327
else:
328+
326329
class TransposeWrapperTensorProduct(tp_base_cls):
327330
def forward(self, x, y, W):
328331
x_t = transpose_irreps(
@@ -370,40 +373,29 @@ def forward(self, x, y, w):
370373
def _problem_dtype(self, problem):
371374
return torch.float32 if problem.irrep_dtype == np.float32 else torch.float64
372375

373-
def _make_inputs(self, problem, batch_size, rng, dtype, device):
374-
in1 = torch.tensor(
375-
rng.uniform(size=(batch_size, problem.irreps_in1.dim)),
376-
dtype=dtype,
377-
device=device,
378-
)
379-
in2 = torch.tensor(
380-
rng.uniform(size=(batch_size, problem.irreps_in2.dim)),
381-
dtype=dtype,
382-
device=device,
383-
)
384-
weights_size = (
385-
(problem.weight_numel,)
386-
if problem.shared_weights
387-
else (batch_size, problem.weight_numel)
388-
)
389-
weights = torch.tensor(
390-
rng.uniform(size=weights_size),
391-
dtype=dtype,
392-
device=device,
376+
def _make_inputs(self, problem, batch_size, dtype, device, prng_seed=12345):
377+
dtype_map = {torch.float32: np.float32, torch.float64: np.float64}
378+
buffer_problem = problem.clone()
379+
buffer_problem.irrep_dtype = dtype_map[dtype]
380+
buffer_problem.weight_dtype = dtype_map[dtype]
381+
382+
in1_np, in2_np, weights_np, _ = get_random_buffers_forward(
383+
buffer_problem, batch_size=batch_size, prng_seed=prng_seed
393384
)
394-
return in1, in2, weights
385+
386+
return [
387+
torch.tensor(arr, dtype=dtype, device=device)
388+
for arr in [in1_np, in2_np, weights_np]
389+
]
395390

396391
def test_submodule_dtype_conversion(self, parent_module_and_problem):
397392
"""Test that calling .to() on parent module properly converts TensorProduct submodule"""
398393
parent, problem = parent_module_and_problem
399394

400395
batch_size = 10
401-
rng = np.random.default_rng(12345)
402396
device = "cuda"
403397
input_dtype = self._problem_dtype(problem)
404-
in1, in2, weights = self._make_inputs(
405-
problem, batch_size, rng, input_dtype, device
406-
)
398+
in1, in2, weights = self._make_inputs(problem, batch_size, input_dtype, device)
407399

408400
output1 = parent(in1, in2, weights)
409401
assert output1.dtype == in1.dtype, (
@@ -418,7 +410,7 @@ def test_submodule_dtype_conversion(self, parent_module_and_problem):
418410
parent.to(target_dtype)
419411

420412
in1_new, in2_new, weights_new = self._make_inputs(
421-
problem, batch_size, rng, target_dtype, device
413+
problem, batch_size, target_dtype, device, prng_seed=23456
422414
)
423415

424416
output2 = parent(in1_new, in2_new, weights_new)

0 commit comments

Comments
 (0)