1919import torch
2020
2121
@pytest.fixture(params=[np.float32, np.float64], ids=["F32", "F64"], scope="module")
def dtype(request):
    """Module-scoped fixture parametrizing tests over both supported numpy float dtypes."""
    return request.param
25+
26+
2227class TPCorrectness :
2328 def thresh (self , direction ):
2429 return {"fwd" : 1e-5 , "bwd" : 3e-4 , "double_bwd" : 3e-4 }[direction ]
@@ -31,18 +36,10 @@ def check_result(self, result, fieldname):
3136 f"{ fieldname } observed error={ error :.5f} >= { thresh } "
3237 )
3338
34- @pytest .fixture (params = [np .float32 , np .float64 ], ids = ["F32" , "F64" ], scope = "class" )
35- def dtype (self , request ):
36- return request .param
37-
3839 @pytest .fixture (scope = "class" )
3940 def extra_tp_constructor_args (self ):
4041 return {}
4142
42- @pytest .fixture (scope = "class" )
43- def with_jax (self , request ):
44- return request .config .getoption ("--jax" )
45-
4643 @pytest .fixture (scope = "class" )
4744 def tp_and_problem (self , problem , extra_tp_constructor_args , with_jax ):
4845 cls = oeq .TensorProduct
@@ -274,3 +271,85 @@ def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax):
274271 }
275272 tp .to (switch_map [problem .irrep_dtype ])
276273 return tp , tp .config
274+
275+
class TestTorchToSubmodule:
    """Test that TensorProduct works correctly as a submodule when parent's .to() is called"""

    @pytest.fixture(scope="class")
    def parent_module_and_problem(self, dtype, with_jax):
        """Build a torch.nn.Module wrapping a TensorProduct for the requested dtype.

        Skipped under JAX, since .to()-based dtype conversion is a torch concept.
        """
        if with_jax:
            pytest.skip("N/A for JAX")

        problem = mace_problems()[0].clone()
        problem.irrep_dtype, problem.weight_dtype = dtype, dtype

        class Wrapper(torch.nn.Module):
            # Minimal parent module: the TensorProduct is its only submodule,
            # so parent.to(...) must recurse into it for the tests to pass.
            def __init__(self, prob):
                super().__init__()
                self.tp = oeq.TensorProduct(prob)

            def forward(self, a, b, w):
                return self.tp(a, b, w)

        return Wrapper(problem), problem

    def _problem_dtype(self, problem):
        """Map the problem's numpy irrep dtype to the matching torch dtype."""
        if problem.irrep_dtype == np.float32:
            return torch.float32
        return torch.float64

    def _make_inputs(self, problem, batch_size, rng, dtype, device):
        """Draw random in1/in2/weights tensors of the shapes the problem expects."""

        def sample(shape):
            return torch.tensor(rng.uniform(size=shape), dtype=dtype, device=device)

        # Draw in the same order every time so a seeded rng is reproducible.
        in1 = sample((batch_size, problem.irreps_in1.dim))
        in2 = sample((batch_size, problem.irreps_in2.dim))
        if problem.shared_weights:
            weights = sample((problem.weight_numel,))
        else:
            weights = sample((batch_size, problem.weight_numel))
        return in1, in2, weights

    def test_submodule_dtype_conversion(self, parent_module_and_problem):
        """Test that calling .to() on parent module properly converts TensorProduct submodule"""
        parent, problem = parent_module_and_problem

        batch_size, device = 10, "cuda"
        rng = np.random.default_rng(12345)

        # Sanity check at the original precision before converting.
        start_dtype = self._problem_dtype(problem)
        in1, in2, weights = self._make_inputs(
            problem, batch_size, rng, start_dtype, device
        )
        out = parent(in1, in2, weights)
        assert out.dtype == in1.dtype, (
            f"Expected output dtype {in1.dtype}, got {out.dtype}"
        )

        # Flip to the opposite precision via the *parent's* .to() and make sure
        # the wrapped TensorProduct followed along.
        target_dtype = (
            torch.float64 if problem.irrep_dtype == np.float32 else torch.float32
        )
        parent.to(target_dtype)

        in1, in2, weights = self._make_inputs(
            problem, batch_size, rng, target_dtype, device
        )
        out = parent(in1, in2, weights)
        assert out.dtype == target_dtype, (
            f"Expected output dtype {target_dtype}, got {out.dtype}"
        )
0 commit comments