88 correctness_double_backward ,
99 correctness_forward ,
1010)
11+ from openequivariance .benchmark .test_buffers import get_random_buffers_forward
1112from openequivariance .benchmark .problems import (
1213 diffdock_problems ,
1314 e3nn_torch_tetris_poly_problems ,
@@ -274,9 +275,10 @@ def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax):
274275
275276
276277class TestIrMul (TPCorrectness ):
277- '''
278- Tests both the ir_mul layout and the transpose_irreps functions.
279- '''
278+ """
279+ Tests both the ir_mul layout and the transpose_irreps functions.
280+ """
281+
280282 tpps = mace_problems () + [
281283 oeq .TPProblem (
282284 "5x5e" ,
@@ -323,6 +325,7 @@ def tp_and_problem(self, request, problem, extra_tp_constructor_args, with_jax):
323325 tp = tp_base_cls (problem , ** extra_tp_constructor_args )
324326 return tp , problem
325327 else :
328+
326329 class TransposeWrapperTensorProduct (tp_base_cls ):
327330 def forward (self , x , y , W ):
328331 x_t = transpose_irreps (
@@ -370,40 +373,29 @@ def forward(self, x, y, w):
370373 def _problem_dtype (self , problem ):
371374 return torch .float32 if problem .irrep_dtype == np .float32 else torch .float64
372375
373- def _make_inputs (self , problem , batch_size , rng , dtype , device ):
374- in1 = torch .tensor (
375- rng .uniform (size = (batch_size , problem .irreps_in1 .dim )),
376- dtype = dtype ,
377- device = device ,
378- )
379- in2 = torch .tensor (
380- rng .uniform (size = (batch_size , problem .irreps_in2 .dim )),
381- dtype = dtype ,
382- device = device ,
383- )
384- weights_size = (
385- (problem .weight_numel ,)
386- if problem .shared_weights
387- else (batch_size , problem .weight_numel )
388- )
389- weights = torch .tensor (
390- rng .uniform (size = weights_size ),
391- dtype = dtype ,
392- device = device ,
376+ def _make_inputs (self , problem , batch_size , dtype , device , prng_seed = 12345 ):
377+ dtype_map = {torch .float32 : np .float32 , torch .float64 : np .float64 }
378+ buffer_problem = problem .clone ()
379+ buffer_problem .irrep_dtype = dtype_map [dtype ]
380+ buffer_problem .weight_dtype = dtype_map [dtype ]
381+
382+ in1_np , in2_np , weights_np , _ = get_random_buffers_forward (
383+ buffer_problem , batch_size = batch_size , prng_seed = prng_seed
393384 )
394- return in1 , in2 , weights
385+
386+ return [
387+ torch .tensor (arr , dtype = dtype , device = device )
388+ for arr in [in1_np , in2_np , weights_np ]
389+ ]
395390
396391 def test_submodule_dtype_conversion (self , parent_module_and_problem ):
397392 """Test that calling .to() on parent module properly converts TensorProduct submodule"""
398393 parent , problem = parent_module_and_problem
399394
400395 batch_size = 10
401- rng = np .random .default_rng (12345 )
402396 device = "cuda"
403397 input_dtype = self ._problem_dtype (problem )
404- in1 , in2 , weights = self ._make_inputs (
405- problem , batch_size , rng , input_dtype , device
406- )
398+ in1 , in2 , weights = self ._make_inputs (problem , batch_size , input_dtype , device )
407399
408400 output1 = parent (in1 , in2 , weights )
409401 assert output1 .dtype == in1 .dtype , (
@@ -418,7 +410,7 @@ def test_submodule_dtype_conversion(self, parent_module_and_problem):
418410 parent .to (target_dtype )
419411
420412 in1_new , in2_new , weights_new = self ._make_inputs (
421- problem , batch_size , rng , target_dtype , device
413+ problem , batch_size , target_dtype , device , prng_seed = 23456
422414 )
423415
424416 output2 = parent (in1_new , in2_new , weights_new )
0 commit comments