Commit 4dc31dc

Correctness double backward works for existing code; need to extend to JAX.
1 parent 58b7957 commit 4dc31dc

2 files changed

Lines changed: 13 additions & 50 deletions

openequivariance/openequivariance/benchmark/correctness_utils.py

Lines changed: 11 additions & 49 deletions
@@ -6,7 +6,8 @@
 from openequivariance.benchmark.random_buffer_utils import (
     get_random_buffers_forward,
     get_random_buffers_backward,
-)
+    get_random_buffers_double_backward)
+
 from openequivariance.benchmark.logging_utils import getLogger, bcolors
 import numpy as np
 import numpy.linalg as la

@@ -194,68 +195,29 @@ def correctness_double_backward(
     global torch
     import torch

-    in1, in2, out_grad, weights, _, _, _ = get_random_buffers_backward(
-        problem, batch_size, prng_seed
-    )
-    rng = np.random.default_rng(seed=prng_seed * 2)
-    dummy_grad = rng.standard_normal(1)[0]
+    in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, _ = \
+        get_random_buffers_double_backward(problem, batch_size=batch_size, prng_seed=prng_seed)

     if reference_implementation is None:
         from openequivariance.impl_torch.E3NNTensorProduct import E3NNTensorProduct
-
         reference_implementation = E3NNTensorProduct

     result = {"thresh": correctness_threshold, "batch_size": batch_size}

     tensors = []
-    for i, impl in enumerate([test_implementation, reference_implementation]):
+    for _, impl in enumerate([test_implementation, reference_implementation]):
         tp = instantiate_implementation(impl, problem)

         if impl == CUETensorProduct and problem.shared_weights:
             weights = weights[np.newaxis, :]

-        weights_reordered = tp.reorder_weights_from_e3nn(
-            weights, not tp.config.shared_weights
-        )
-
-        in1_torch = torch.tensor(in1, device="cuda", requires_grad=True)
-        in2_torch = torch.tensor(in2, device="cuda", requires_grad=True)
-        weights_torch = torch.tensor(
-            weights_reordered, device="cuda", requires_grad=True
-        )
-
-        out_torch = tp.forward(in1_torch, in2_torch, weights_torch)
-        out_grad = out_torch.clone().detach().to(device="cuda").requires_grad_(True)
-
-        in1_grad, in2_grad, w_grad = torch.autograd.grad(
-            outputs=[out_torch],
-            inputs=[in1_torch, in2_torch, weights_torch],
-            grad_outputs=[out_grad],
-            create_graph=True,
-        )
-
-        dummy = torch.norm(in1_grad) + torch.norm(in2_grad) + torch.norm(w_grad)
-        dummy_grad = torch.tensor(float(dummy_grad), device="cuda", requires_grad=True)
-
-        dummy.backward(
-            dummy_grad,
-            retain_graph=True,
-            inputs=[out_grad, in1_torch, in2_torch, weights_torch],
-        )
-
-        weights_grad = weights_torch.grad.detach().cpu().numpy()
-        weights_grad = tp.reorder_weights_to_e3nn(
-            weights_grad, not tp.config.shared_weights
-        )
-
+        in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu(in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad)
         tensors.append(
-            (
-                out_grad.grad.detach().cpu().numpy(),
-                in1_torch.grad.detach().cpu().numpy(),
-                in2_torch.grad.detach().cpu().numpy(),
-                weights_grad,
-            )
-        )
+            (out_dgrad,
+             in1_grad,
+             in2_grad,
+             weights_grad)
+        )

     for name, to_check, ground_truth in [
         ("output_double_grad", tensors[0][0], tensors[1][0]),
openequivariance/openequivariance/impl_torch/E3NNTensorProduct.py

Lines changed: 2 additions & 1 deletion
@@ -12,13 +12,14 @@
 from openequivariance.core.TensorProductBase import TensorProductBase
 from openequivariance.core.e3nn_lite import TPProblem
 from openequivariance.benchmark.logging_utils import getLogger
+from openequivariance.impl_torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin

 TORCH_COMPILE_AUTOTUNING_DIR = pathlib.Path("triton_autotuning")

 logger = getLogger()


-class E3NNTensorProduct(TensorProductBase):
+class E3NNTensorProduct(TensorProductBase, NumpyDoubleBackwardMixin):
     def __init__(self, config: TPProblem, torch_op=True):
         super().__init__(config, torch_op=torch_op)
         assert self.torch_op

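NumpyDoubleBackwardMixin itself is not part of this commit, so only its name and the double_backward_cpu call site in correctness_utils.py constrain its shape. The following is a hypothetical sketch of such a mixin, delegating to the same torch autograd pattern the commit removes from the harness; the body is an assumption, not the repository's implementation.

import torch

class NumpyDoubleBackwardMixin:
    # Hypothetical sketch: the signature mirrors the call site in
    # correctness_utils.py; the body is assumed, not the actual code.
    def double_backward_cpu(self, in1, in2, out_grad, weights,
                            weights_dgrad, in1_dgrad, in2_dgrad):
        def leaf(a):
            # NumPy in, CPU torch leaf tensor out.
            return torch.tensor(a, requires_grad=True)

        in1_t, in2_t, w_t = leaf(in1), leaf(in2), leaf(weights)
        og_t = leaf(out_grad)

        # self.forward is provided by the concrete tensor-product class.
        out = self.forward(in1_t, in2_t, w_t)
        g1, g2, gw = torch.autograd.grad(
            outputs=[out],
            inputs=[in1_t, in2_t, w_t],
            grad_outputs=[og_t],
            create_graph=True,
        )

        # Contract the first-order gradients with the caller-supplied
        # double-backward seeds, then differentiate the scalar once more.
        s = (
            (g1 * torch.as_tensor(in1_dgrad)).sum()
            + (g2 * torch.as_tensor(in2_dgrad)).sum()
            + (gw * torch.as_tensor(weights_dgrad)).sum()
        )
        s.backward(inputs=[in1_t, in2_t, w_t, og_t])

        def to_np(v):
            return v.grad.detach().cpu().numpy()

        # Return order matches the unpacking at the call site:
        # in1_grad, in2_grad, weights_grad, out_dgrad.
        return to_np(in1_t), to_np(in2_t), to_np(w_t), to_np(og_t)

Keeping this path on the CPU, with NumPy arrays at the boundaries, would give every backend a single framework-neutral reference to check against, which is presumably why the commit message flags JAX as the next target.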