Skip to content

Commit 9544e14

Browse files
committed
Wrote a compact test for the transpose functions.
1 parent b5866a2 commit 9544e14

1 file changed

Lines changed: 47 additions & 3 deletions

File tree

tests/batch_test.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -273,8 +273,8 @@ def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax):
273273
return tp, tp.config
274274

275275

276-
class TestIrMulLayoutMACE(TPCorrectness):
277-
production_model_tpps = mace_problems() + [
276+
class TestIrMul(TPCorrectness):
277+
tpps = mace_problems() + [
278278
oeq.TPProblem(
279279
"5x5e",
280280
"1x3e",
@@ -295,13 +295,57 @@ class TestIrMulLayoutMACE(TPCorrectness):
295295
),
296296
]
297297

298-
@pytest.fixture(params=production_model_tpps, ids=lambda x: x.label, scope="class")
298+
@pytest.fixture(params=tpps, ids=lambda x: x.label, scope="class")
299299
def problem(self, request, dtype):
300300
problem = request.param.clone()
301301
problem.irrep_dtype, problem.weight_dtype = dtype, dtype
302302
problem.layout = "ir_mul"
303303
return problem
304304

305+
@pytest.fixture(params=["native", "transpose_wrapper"], scope="class")
306+
def tp_and_problem(self, request, problem, extra_tp_constructor_args, with_jax):
307+
mode = request.param
308+
309+
if mode == "native":
310+
cls = oeq.TensorProduct
311+
if with_jax:
312+
import openequivariance.jax.TensorProduct as jax_tp
313+
314+
cls = jax_tp
315+
tp = cls(problem, **extra_tp_constructor_args)
316+
return tp, problem
317+
else:
318+
if with_jax:
319+
import openequivariance.jax.TensorProduct as jax_tp
320+
from openequivariance.jax import transpose_irreps
321+
322+
tp_base_cls = jax_tp
323+
else:
324+
from openequivariance._torch.utils import transpose_irreps
325+
326+
tp_base_cls = oeq.TensorProduct
327+
328+
class TransposeWrapperTensorProduct(tp_base_cls):
329+
def forward(self, x, y, W):
330+
x_t = transpose_irreps(
331+
x, self.config.irreps_in1, "ir_mul", "mul_ir"
332+
)
333+
y_t = transpose_irreps(
334+
y, self.config.irreps_in2, "ir_mul", "mul_ir"
335+
)
336+
out_mul_ir = super().forward(x_t, y_t, W)
337+
return transpose_irreps(
338+
out_mul_ir, self.config.irreps_out, "mul_ir", "ir_mul"
339+
)
340+
341+
wrapped_problem = problem.clone()
342+
wrapped_problem.layout = "mul_ir"
343+
tp = TransposeWrapperTensorProduct(
344+
wrapped_problem, **extra_tp_constructor_args
345+
)
346+
tp.config.layout = "ir_mul"
347+
return tp, problem
348+
305349

306350
class TestTorchToSubmodule:
307351
"""Test that TensorProduct works correctly as a submodule when parent's .to() is called"""

0 commit comments

Comments
 (0)