Skip to content

Commit 5440d55

Browse files
committed
Compacted diff further.
1 parent 9544e14 commit 5440d55

1 file changed

Lines changed: 14 additions & 15 deletions

File tree

tests/batch_test.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,10 @@ def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax):
274274

275275

276276
class TestIrMul(TPCorrectness):
277+
'''
278+
Tests both the ir_mul layout and the transpose_irreps functions
279+
via a wrapper.
280+
'''
277281
tpps = mace_problems() + [
278282
oeq.TPProblem(
279283
"5x5e",
@@ -306,25 +310,20 @@ def problem(self, request, dtype):
306310
def tp_and_problem(self, request, problem, extra_tp_constructor_args, with_jax):
307311
mode = request.param
308312

309-
if mode == "native":
310-
cls = oeq.TensorProduct
311-
if with_jax:
312-
import openequivariance.jax.TensorProduct as jax_tp
313+
if with_jax:
314+
import openequivariance.jax.TensorProduct as jax_tp
315+
from openequivariance.jax import transpose_irreps
313316

314-
cls = jax_tp
315-
tp = cls(problem, **extra_tp_constructor_args)
316-
return tp, problem
317+
tp_base_cls = jax_tp
317318
else:
318-
if with_jax:
319-
import openequivariance.jax.TensorProduct as jax_tp
320-
from openequivariance.jax import transpose_irreps
319+
from openequivariance._torch.utils import transpose_irreps
321320

322-
tp_base_cls = jax_tp
323-
else:
324-
from openequivariance._torch.utils import transpose_irreps
325-
326-
tp_base_cls = oeq.TensorProduct
321+
tp_base_cls = oeq.TensorProduct
327322

323+
if mode == "native":
324+
tp = tp_base_cls(problem, **extra_tp_constructor_args)
325+
return tp, problem
326+
else:
328327
class TransposeWrapperTensorProduct(tp_base_cls):
329328
def forward(self, x, y, W):
330329
x_t = transpose_irreps(

0 commit comments

Comments
 (0)