@@ -272,9 +272,9 @@ def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax):
272272 tp .to (switch_map [problem .irrep_dtype ])
273273 return tp , tp .config
274274
275-
276- def ir_mul_representative_uvu_problems ():
277- return [
275+ class TestIrMulLayoutMACE ( TPCorrectness ):
276+ production_model_tpps = mace_problems () + \
277+ [
278278 oeq .TPProblem (
279279 "5x5e" ,
280280 "1x3e" ,
@@ -293,11 +293,7 @@ def ir_mul_representative_uvu_problems():
293293 internal_weights = False ,
294294 label = "ir_mul_repr_13x1x13_l535" ,
295295 ),
296- ]
297-
298-
299- class TestIrMulLayoutMACE (TPCorrectness ):
300- production_model_tpps = mace_problems () + ir_mul_representative_uvu_problems ()
296+ ]
301297
302298 @pytest .fixture (params = production_model_tpps , ids = lambda x : x .label , scope = "class" )
303299 def problem (self , request , dtype ):
@@ -307,23 +303,6 @@ def problem(self, request, dtype):
307303 return problem
308304
309305
310- def test_ir_mul_rejects_uvw_problem (dtype ):
311- problem = oeq .TPProblem (
312- "5x5e" ,
313- "1x3e" ,
314- "5x5e" ,
315- [(0 , 0 , 0 , "uvw" , True )],
316- shared_weights = False ,
317- internal_weights = False ,
318- irrep_dtype = dtype ,
319- weight_dtype = dtype ,
320- layout = "ir_mul" ,
321- )
322-
323- with pytest .raises (AssertionError , match = "layout='ir_mul'" ):
324- oeq .TensorProduct (problem )
325-
326-
327306class TestTorchToSubmodule :
328307 """Test that TensorProduct works correctly as a submodule when parent's .to() is called"""
329308
0 commit comments