 from openequivariance.benchmark.random_buffer_utils import (
     get_random_buffers_forward,
     get_random_buffers_backward,
-)
+    get_random_buffers_double_backward,
+)
+
 from openequivariance.benchmark.logging_utils import getLogger, bcolors
 import numpy as np
 import numpy.linalg as la
@@ -194,68 +196,29 @@ def correctness_double_backward(
     global torch
     import torch

-    in1, in2, out_grad, weights, _, _, _ = get_random_buffers_backward(
-        problem, batch_size, prng_seed
-    )
-    rng = np.random.default_rng(seed=prng_seed * 2)
-    dummy_grad = rng.standard_normal(1)[0]
+    in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, _ = (
+        get_random_buffers_double_backward(problem, batch_size=batch_size, prng_seed=prng_seed)
+    )

     if reference_implementation is None:
         from openequivariance.impl_torch.E3NNTensorProduct import E3NNTensorProduct
-
         reference_implementation = E3NNTensorProduct

     result = {"thresh": correctness_threshold, "batch_size": batch_size}

     tensors = []
-    for i, impl in enumerate([test_implementation, reference_implementation]):
+    for impl in [test_implementation, reference_implementation]:
         tp = instantiate_implementation(impl, problem)

         if impl == CUETensorProduct and problem.shared_weights:
             weights = weights[np.newaxis, :]

-        weights_reordered = tp.reorder_weights_from_e3nn(
-            weights, not tp.config.shared_weights
-        )
-
-        in1_torch = torch.tensor(in1, device="cuda", requires_grad=True)
-        in2_torch = torch.tensor(in2, device="cuda", requires_grad=True)
-        weights_torch = torch.tensor(
-            weights_reordered, device="cuda", requires_grad=True
-        )
-
-        out_torch = tp.forward(in1_torch, in2_torch, weights_torch)
-        out_grad = out_torch.clone().detach().to(device="cuda").requires_grad_(True)
-
-        in1_grad, in2_grad, w_grad = torch.autograd.grad(
-            outputs=[out_torch],
-            inputs=[in1_torch, in2_torch, weights_torch],
-            grad_outputs=[out_grad],
-            create_graph=True,
-        )
-
-        dummy = torch.norm(in1_grad) + torch.norm(in2_grad) + torch.norm(w_grad)
-        dummy_grad = torch.tensor(float(dummy_grad), device="cuda", requires_grad=True)
-
-        dummy.backward(
-            dummy_grad,
-            retain_graph=True,
-            inputs=[out_grad, in1_torch, in2_torch, weights_torch],
-        )
-
-        weights_grad = weights_torch.grad.detach().cpu().numpy()
-        weights_grad = tp.reorder_weights_to_e3nn(
-            weights_grad, not tp.config.shared_weights
-        )
-
+        in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu(
+            in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad
+        )
         tensors.append(
-            (
-                out_grad.grad.detach().cpu().numpy(),
-                in1_torch.grad.detach().cpu().numpy(),
-                in2_torch.grad.detach().cpu().numpy(),
-                weights_grad,
-            )
-        )
+            (out_dgrad, in1_grad, in2_grad, weights_grad)
+        )

     for name, to_check, ground_truth in [
         ("output_double_grad", tensors[0][0], tensors[1][0]),
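For readers comparing against the deleted reference path: the removed lines obtained second-order gradients by differentiating the first-order gradients of tp.forward a second time. Below is a minimal sketch of the equivalent check with the new explicit double-grad seeds, assuming tp.forward is autograd-differentiable; this illustrates the double-backward technique and is not the double_backward_cpu implementation (the function name is hypothetical):

    import torch

    def double_backward_reference(tp, in1, in2, out_grad, weights,
                                  weights_dgrad, in1_dgrad, in2_dgrad):
        # Promote the numpy buffers to tensors that participate in autograd.
        in1_t = torch.tensor(in1, device="cuda", requires_grad=True)
        in2_t = torch.tensor(in2, device="cuda", requires_grad=True)
        w_t = torch.tensor(weights, device="cuda", requires_grad=True)
        og_t = torch.tensor(out_grad, device="cuda", requires_grad=True)

        out = tp.forward(in1_t, in2_t, w_t)

        # First-order gradients; create_graph=True keeps them differentiable.
        g_in1, g_in2, g_w = torch.autograd.grad(
            outputs=[out], inputs=[in1_t, in2_t, w_t],
            grad_outputs=[og_t], create_graph=True,
        )

        # Contract each first-order gradient with its double-grad seed and
        # differentiate the resulting scalar once more, now also with
        # respect to out_grad.
        contraction = (
            (g_in1 * torch.tensor(in1_dgrad, device="cuda")).sum()
            + (g_in2 * torch.tensor(in2_dgrad, device="cuda")).sum()
            + (g_w * torch.tensor(weights_dgrad, device="cuda")).sum()
        )
        in1_g, in2_g, w_g, out_dg = torch.autograd.grad(
            outputs=[contraction], inputs=[in1_t, in2_t, w_t, og_t],
        )
        return tuple(t.detach().cpu().numpy()
                     for t in (in1_g, in2_g, w_g, out_dg))

The return order matches the unpacking at the double_backward_cpu call site, so the comparison loop at the end of the hunk can check the test and reference tuples entry by entry against result["thresh"].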