@@ -3,6 +3,7 @@
 from openequivariance.benchmark.random_buffer_utils import (
     get_random_buffers_forward_conv,
     get_random_buffers_backward_conv,
+    get_random_buffers_double_backward_conv,
 )
 
 from openequivariance.benchmark.logging_utils import getLogger, bcolors
@@ -13,7 +14,6 @@
 
 logger = getLogger()
 
-
 def flops_data_per_tp(config, direction):
     """
     Assumes all interactions are "uvu" for now
@@ -549,20 +549,13 @@ def test_correctness_double_backward(
         reference_implementation=None,
         high_precision_ref=False,
     ):
-        global torch
-        import torch
-
-        assert self.torch_op
-        buffers = get_random_buffers_backward_conv(
-            self.config, graph.node_count, graph.nnz, prng_seed
-        )
-
-        rng = np.random.default_rng(seed=prng_seed * 2)
-        dummy_grad_value = rng.standard_normal(1)[0]
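+        # Sample all buffers for the double-backward check, including second-order gradient seeds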
+        buffers = get_random_buffers_double_backward_conv(
+            self.config, graph.node_count, graph.nnz, prng_seed
+        )
 
         if reference_implementation is None:
             from openequivariance.impl_torch.E3NNConv import E3NNConv
-
             reference_implementation = E3NNConv
 
         reference_problem = self.config
@@ -576,63 +568,41 @@ def test_correctness_double_backward(
         result = {"thresh": thresh}
         tensors = []
         for i, tp in enumerate([self, reference_tp]):
-            in1, in2, out_grad, weights, _, _, _ = [buf.copy() for buf in buffers]
+            buffers_copy = [buf.copy() for buf in buffers]
 
             if i == 1 and high_precision_ref:
-                in1, in2, out_grad, weights, _, _, _ = [
+                buffers_copy = [
                     np.array(el, dtype=np.float64) for el in buffers
                 ]
 
-            in1_torch = torch.tensor(in1, device="cuda", requires_grad=True)
-            in2_torch = torch.tensor(in2, device="cuda", requires_grad=True)
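+            # Unpack inputs, output gradient, weights, and second-order gradient seeds (last buffer unused)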
+            in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, _ = buffers_copy
 
             weights_reordered = tp.reorder_weights_from_e3nn(
                 weights, not self.config.shared_weights
             )
-
-            weights_torch = torch.tensor(
-                weights_reordered, device="cuda", requires_grad=True
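+            # Reorder the weight-gradient seed to the internal layout, matching the weights above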
+            weights_dgrad_reordered = tp.reorder_weights_from_e3nn(
+                weights_dgrad, not self.config.shared_weights
             )
 
-            torch_rows = torch.tensor(graph.rows, device="cuda")
-            torch_cols = torch.tensor(graph.cols, device="cuda")
-            torch_transpose_perm = torch.tensor(graph.transpose_perm, device="cuda")
-
-            fwd_args = [in1_torch, in2_torch, weights_torch, torch_rows, torch_cols]
-            if tp.deterministic:
-                fwd_args.append(torch_transpose_perm)
-
-            out_torch = tp.forward(*fwd_args)
-            out_grad_torch = torch.tensor(out_grad, device="cuda", requires_grad=True)
-
-            in1_grad, in2_grad, w_grad = torch.autograd.grad(
-                outputs=[out_torch],
-                inputs=[in1_torch, in2_torch, weights_torch],
-                grad_outputs=[out_grad_torch],
-                create_graph=True,
-            )
-
-            dummy = torch.norm(in1_grad) + torch.norm(in2_grad) + torch.norm(w_grad)
-            dummy_grad = torch.tensor(
-                float(dummy_grad_value), device="cuda", requires_grad=True
-            )
-            dummy.backward(
-                dummy_grad, inputs=[out_grad_torch, in1_torch, in2_torch, weights_torch]
-            )
-
-            weights_grad = weights_torch.grad.detach().cpu().numpy()
-            weights_grad = tp.reorder_weights_to_e3nn(
-                weights_grad, not self.config.shared_weights
-            )
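+            # double_backward_cpu returns the first-order gradients and the gradient w.r.t. out_grad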
+            in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu(
+                in1, in2, out_grad, weights_reordered,
+                weights_dgrad_reordered, in1_dgrad, in2_dgrad, graph,
+            )
 
             tensors.append(
-                (
-                    out_grad_torch.grad.detach().cpu().numpy().copy(),
-                    in1_torch.grad.detach().cpu().numpy().copy(),
-                    in2_torch.grad.detach().cpu().numpy().copy(),
-                    weights_grad.copy(),
-                )
-            )
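+                # Tuple order matches the (output_grad, in1_grad, in2_grad, weights_grad) checks below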
+                (
+                    out_dgrad,
+                    in1_grad,
+                    in2_grad,
+                    tp.reorder_weights_to_e3nn(
+                        weights_grad, has_batch_dim=not self.config.shared_weights
+                    ),
+                )
+            )
 
         for name, to_check, ground_truth in [
             ("output_grad", tensors[0][0], tensors[1][0]),