@@ -273,8 +273,8 @@ def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax):
273273 return tp , tp .config
274274
275275
276- class TestIrMulLayoutMACE (TPCorrectness ):
277- production_model_tpps = mace_problems () + [
276+ class TestIrMul (TPCorrectness ):
277+ tpps = mace_problems () + [
278278 oeq .TPProblem (
279279 "5x5e" ,
280280 "1x3e" ,
@@ -295,13 +295,57 @@ class TestIrMulLayoutMACE(TPCorrectness):
295295 ),
296296 ]
297297
298- @pytest .fixture (params = production_model_tpps , ids = lambda x : x .label , scope = "class" )
298+ @pytest .fixture (params = tpps , ids = lambda x : x .label , scope = "class" )
299299 def problem (self , request , dtype ):
300300 problem = request .param .clone ()
301301 problem .irrep_dtype , problem .weight_dtype = dtype , dtype
302302 problem .layout = "ir_mul"
303303 return problem
304304
305+ @pytest .fixture (params = ["native" , "transpose_wrapper" ], scope = "class" )
306+ def tp_and_problem (self , request , problem , extra_tp_constructor_args , with_jax ):
307+ mode = request .param
308+
309+ if mode == "native" :
310+ cls = oeq .TensorProduct
311+ if with_jax :
312+ import openequivariance .jax .TensorProduct as jax_tp
313+
314+ cls = jax_tp
315+ tp = cls (problem , ** extra_tp_constructor_args )
316+ return tp , problem
317+ else :
318+ if with_jax :
319+ import openequivariance .jax .TensorProduct as jax_tp
320+ from openequivariance .jax import transpose_irreps
321+
322+ tp_base_cls = jax_tp
323+ else :
324+ from openequivariance ._torch .utils import transpose_irreps
325+
326+ tp_base_cls = oeq .TensorProduct
327+
328+ class TransposeWrapperTensorProduct (tp_base_cls ):
329+ def forward (self , x , y , W ):
330+ x_t = transpose_irreps (
331+ x , self .config .irreps_in1 , "ir_mul" , "mul_ir"
332+ )
333+ y_t = transpose_irreps (
334+ y , self .config .irreps_in2 , "ir_mul" , "mul_ir"
335+ )
336+ out_mul_ir = super ().forward (x_t , y_t , W )
337+ return transpose_irreps (
338+ out_mul_ir , self .config .irreps_out , "mul_ir" , "ir_mul"
339+ )
340+
341+ wrapped_problem = problem .clone ()
342+ wrapped_problem .layout = "mul_ir"
343+ tp = TransposeWrapperTensorProduct (
344+ wrapped_problem , ** extra_tp_constructor_args
345+ )
346+ tp .config .layout = "ir_mul"
347+ return tp , problem
348+
305349
306350class TestTorchToSubmodule :
307351 """Test that TensorProduct works correctly as a submodule when parent's .to() is called"""
0 commit comments