@@ -274,6 +274,10 @@ def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax):
274274
275275
276276class TestIrMul (TPCorrectness ):
277+ '''
278+ Tests both the ir_mul layout and the transpose_irreps functions
279+ via a wrapper.
280+ '''
277281 tpps = mace_problems () + [
278282 oeq .TPProblem (
279283 "5x5e" ,
@@ -306,25 +310,20 @@ def problem(self, request, dtype):
306310 def tp_and_problem (self , request , problem , extra_tp_constructor_args , with_jax ):
307311 mode = request .param
308312
309- if mode == "native" :
310- cls = oeq .TensorProduct
311- if with_jax :
312- import openequivariance .jax .TensorProduct as jax_tp
313+ if with_jax :
314+ import openequivariance .jax .TensorProduct as jax_tp
315+ from openequivariance .jax import transpose_irreps
313316
314- cls = jax_tp
315- tp = cls (problem , ** extra_tp_constructor_args )
316- return tp , problem
317+ tp_base_cls = jax_tp
317318 else :
318- if with_jax :
319- import openequivariance .jax .TensorProduct as jax_tp
320- from openequivariance .jax import transpose_irreps
319+ from openequivariance ._torch .utils import transpose_irreps
321320
322- tp_base_cls = jax_tp
323- else :
324- from openequivariance ._torch .utils import transpose_irreps
325-
326- tp_base_cls = oeq .TensorProduct
321+ tp_base_cls = oeq .TensorProduct
327322
323+ if mode == "native" :
324+ tp = tp_base_cls (problem , ** extra_tp_constructor_args )
325+ return tp , problem
326+ else :
328327 class TransposeWrapperTensorProduct (tp_base_cls ):
329328 def forward (self , x , y , W ):
330329 x_t = transpose_irreps (
0 commit comments