
Commit f659115

Linted.
1 parent 7acad2e commit f659115

18 files changed

Lines changed: 339 additions & 173 deletions

openequivariance/openequivariance/__init__.py

Lines changed: 9 additions & 3 deletions
@@ -31,19 +31,23 @@ def _check_package_editable():
 
 _editable_install_output_path = Path(__file__).parent.parent.parent / "outputs"
 
+
 def extension_source_path():
     """
     :returns: Path to the source code of the C++ extension.
     """
     return str(Path(__file__).parent / "extension")
 
+
 if "OEQ_NOTORCH" not in os.environ or os.environ["OEQ_NOTORCH"] != "1":
     import torch
 
-    from openequivariance._torch.TensorProduct import TensorProduct
+    from openequivariance._torch.TensorProduct import TensorProduct
     from openequivariance._torch.TensorProductConv import TensorProductConv
 
-    from openequivariance._torch.extlib import torch_ext_so_path as torch_ext_so_path_internal
+    from openequivariance._torch.extlib import (
+        torch_ext_so_path as torch_ext_so_path_internal,
+    )
     from openequivariance.core.utils import torch_to_oeq_dtype
 
     torch.serialization.add_safe_globals(
@@ -60,6 +64,7 @@ def extension_source_path():
         ]
     )
 
+
 def torch_ext_so_path():
     """
     :returns: Path to a ``.so`` file that must be linked to use OpenEquivariance
@@ -70,6 +75,7 @@ def torch_ext_so_path():
     except NameError:
         return None
 
+
 jax = None
 try:
     import openequivariance_extjax
@@ -85,5 +91,5 @@ def torch_ext_so_path():
     "torch_to_oeq_dtype",
     "_check_package_editable",
     "torch_ext_so_path",
-    "jax"
+    "jax",
 ]
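For context, the gated import above means that ``torch_ext_so_path()`` can legitimately return ``None`` when the Torch extension is skipped or unavailable. A minimal downstream sketch of consuming that export; the ``oeq`` alias and the print statements are illustrative and not part of this commit:

# Sketch only: consuming the gated export. Assumes OpenEquivariance is
# installed; nothing below is added by this commit.
import os

os.environ.setdefault("OEQ_NOTORCH", "0")  # "1" would skip all torch imports

import openequivariance as oeq

so_path = oeq.torch_ext_so_path()
if so_path is None:
    # Either OEQ_NOTORCH=1 or the extension failed to import, so the
    # internal symbol was never bound and the NameError fallback fired.
    print("Torch extension unavailable; link step skipped")
else:
    print(f"Link against: {so_path}")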

openequivariance/openequivariance/_torch/E3NNConv.py

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@
 from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct
 from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixinConv
 
+
 class E3NNConv(ConvolutionBase, NumpyDoubleBackwardMixinConv):
     def __init__(self, config, *, idx_dtype=np.int64, torch_op=True):
         assert torch_op
Lines changed: 53 additions & 31 deletions
@@ -1,75 +1,97 @@
 import torch
 
+
 class NumpyDoubleBackwardMixin:
-    '''
-    Adds a Numpy double backward method to any TensorProduct
+    """
+    Adds a Numpy double backward method to any TensorProduct
     with the forward pass defined in PyTorch and the relevant
-    derivatives registered.
-    '''
-    def double_backward_cpu(self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad):
+    derivatives registered.
+    """
+
+    def double_backward_cpu(
+        self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad
+    ):
         assert self.torch_op
 
-        in1_torch = torch.tensor(in1).to('cuda').requires_grad_(True)
-        in2_torch = torch.tensor(in2).to('cuda').requires_grad_(True)
-        weights_torch = torch.tensor(weights).to('cuda').requires_grad_(True)
-        out_grad_torch = torch.tensor(out_grad).to('cuda').requires_grad_(True)
-        in1_dgrad_torch = torch.tensor(in1_dgrad).to('cuda')
-        in2_dgrad_torch = torch.tensor(in2_dgrad).to('cuda')
-        weights_dgrad_torch = torch.tensor(weights_dgrad).to('cuda')
+        in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True)
+        in2_torch = torch.tensor(in2).to("cuda").requires_grad_(True)
+        weights_torch = torch.tensor(weights).to("cuda").requires_grad_(True)
+        out_grad_torch = torch.tensor(out_grad).to("cuda").requires_grad_(True)
+        in1_dgrad_torch = torch.tensor(in1_dgrad).to("cuda")
+        in2_dgrad_torch = torch.tensor(in2_dgrad).to("cuda")
+        weights_dgrad_torch = torch.tensor(weights_dgrad).to("cuda")
         out_torch = self.forward(in1_torch, in2_torch, weights_torch)
 
         in1_grad, in2_grad, weights_grad = torch.autograd.grad(
             outputs=out_torch,
             inputs=[in1_torch, in2_torch, weights_torch],
             grad_outputs=out_grad_torch,
             create_graph=True,
-            retain_graph=True
+            retain_graph=True,
         )
 
         a, b, c, d = torch.autograd.grad(
             outputs=[in1_grad, in2_grad, weights_grad],
             inputs=[in1_torch, in2_torch, weights_torch, out_grad_torch],
-            grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch]
+            grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch],
        )
 
-        return a.detach().cpu().numpy(), b.detach().cpu().numpy(), c.detach().cpu().numpy(), d.detach().cpu().numpy()
+        return (
+            a.detach().cpu().numpy(),
+            b.detach().cpu().numpy(),
+            c.detach().cpu().numpy(),
+            d.detach().cpu().numpy(),
+        )
 
 
 class NumpyDoubleBackwardMixinConv:
-    '''
+    """
     Similar, but for fused graph convolution.
-    '''
-    def double_backward_cpu(self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, graph):
+    """
+
+    def double_backward_cpu(
+        self, in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, graph
+    ):
         assert self.torch_op
 
-        in1_torch = torch.tensor(in1).to('cuda').requires_grad_(True)
-        in2_torch = torch.tensor(in2).to('cuda').requires_grad_(True)
-        weights_torch = torch.tensor(weights).to('cuda').requires_grad_(True)
-        out_grad_torch = torch.tensor(out_grad).to('cuda').requires_grad_(True)
-        in1_dgrad_torch = torch.tensor(in1_dgrad).to('cuda')
-        in2_dgrad_torch = torch.tensor(in2_dgrad).to('cuda')
-        weights_dgrad_torch = torch.tensor(weights_dgrad).to('cuda')
+        in1_torch = torch.tensor(in1).to("cuda").requires_grad_(True)
+        in2_torch = torch.tensor(in2).to("cuda").requires_grad_(True)
+        weights_torch = torch.tensor(weights).to("cuda").requires_grad_(True)
+        out_grad_torch = torch.tensor(out_grad).to("cuda").requires_grad_(True)
+        in1_dgrad_torch = torch.tensor(in1_dgrad).to("cuda")
+        in2_dgrad_torch = torch.tensor(in2_dgrad).to("cuda")
+        weights_dgrad_torch = torch.tensor(weights_dgrad).to("cuda")
 
         torch_rows = torch.tensor(graph.rows, device="cuda")
         torch_cols = torch.tensor(graph.cols, device="cuda")
         torch_transpose_perm = torch.tensor(graph.transpose_perm, device="cuda")
 
-        out_torch = self.forward(in1_torch, in2_torch, weights_torch, torch_rows, torch_cols, torch_transpose_perm)
+        out_torch = self.forward(
+            in1_torch,
+            in2_torch,
+            weights_torch,
+            torch_rows,
+            torch_cols,
+            torch_transpose_perm,
+        )
 
         in1_grad, in2_grad, weights_grad = torch.autograd.grad(
             outputs=out_torch,
             inputs=[in1_torch, in2_torch, weights_torch],
             grad_outputs=out_grad_torch,
             create_graph=True,
-            retain_graph=True
+            retain_graph=True,
        )
 
         a, b, c, d = torch.autograd.grad(
             outputs=[in1_grad, in2_grad, weights_grad],
             inputs=[in1_torch, in2_torch, weights_torch, out_grad_torch],
-            grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch]
+            grad_outputs=[in1_dgrad_torch, in2_dgrad_torch, weights_dgrad_torch],
        )
 
-        return a.detach().cpu().numpy(), b.detach().cpu().numpy(), c.detach().cpu().numpy(), d.detach().cpu().numpy()
-
-
+        return (
+            a.detach().cpu().numpy(),
+            b.detach().cpu().numpy(),
+            c.detach().cpu().numpy(),
+            d.detach().cpu().numpy(),
+        )
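The methods reformatted above obtain second derivatives by calling torch.autograd.grad twice, the first time with create_graph=True so the gradient computation itself stays differentiable. A self-contained sketch of the same nested-autograd pattern on a toy elementwise product, independent of OpenEquivariance:

# Standalone illustration of the nested torch.autograd.grad pattern used by
# double_backward_cpu; the toy function stands in for the tensor product.
import torch

x = torch.randn(4, requires_grad=True)
w = torch.randn(4, requires_grad=True)
out = (x * w).sum()          # stand-in for self.forward(in1, in2, weights)

out_grad = torch.ones(())    # upstream gradient of the scalar output
x_grad, w_grad = torch.autograd.grad(
    outputs=out,
    inputs=[x, w],
    grad_outputs=out_grad,
    create_graph=True,       # keep the graph so we can differentiate again
)

x_dgrad = torch.randn(4)     # perturbations of the first-order gradients
w_dgrad = torch.randn(4)
a, b = torch.autograd.grad(
    outputs=[x_grad, w_grad],
    inputs=[x, w],
    grad_outputs=[x_dgrad, w_dgrad],
)
print(a, b)  # second-order terms, analogous to the mixin's a, b, c, d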

openequivariance/openequivariance/_torch/TensorProduct.py

Lines changed: 7 additions & 3 deletions
@@ -92,10 +92,14 @@ def __setstate__(self, state):
         self._init_class()
 
     def reorder_weights_from_e3nn(self, weights, has_batch_dim=True):
-        return reorder_torch(self.forward_schedule, weights, "forward", not self.config.shared_weights)
+        return reorder_torch(
+            self.forward_schedule, weights, "forward", not self.config.shared_weights
+        )
 
     def reorder_weights_to_e3nn(self, weights, has_batch_dim=True):
-        return reorder_torch(self.forward_schedule, weights, "backward", not self.config.shared_weights)
+        return reorder_torch(
+            self.forward_schedule, weights, "backward", not self.config.shared_weights
+        )
 
     def forward(
         self, x: torch.Tensor, y: torch.Tensor, W: torch.Tensor
@@ -347,7 +351,7 @@ def name():
         return "LoopUnrollTP"
 
 
-if extlib.TORCH_COMPILE:
+if extlib.TORCH_COMPILE:
     TensorProduct.register_torch_fakes()
     TensorProduct.register_autograd()
     TensorProduct.register_autocast()
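A hedged usage sketch of the two reordering helpers touched in this hunk. Here ``tp`` stands for an already-constructed TensorProduct; its constructor arguments are not shown in this diff, so construction is deliberately omitted:

# Hypothetical round-trip check built only on the method names above.
import torch

def roundtrip_weights(tp, e3nn_weights: torch.Tensor) -> bool:
    # e3nn and OpenEquivariance store trainable weights in different orders;
    # from_e3nn converts into oeq order, to_e3nn converts back.
    oeq_weights = tp.reorder_weights_from_e3nn(e3nn_weights, has_batch_dim=True)
    back = tp.reorder_weights_to_e3nn(oeq_weights, has_batch_dim=True)
    return torch.allclose(e3nn_weights, back)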

openequivariance/openequivariance/_torch/TensorProductConv.py

Lines changed: 7 additions & 2 deletions
@@ -27,6 +27,7 @@
 
 logger = getLogger()
 
+
 class TensorProductConv(torch.nn.Module, LoopUnrollConv, NumpyDoubleBackwardMixinConv):
     r"""
     Given a **symmetric, directed** graph :math:`G = (V, E)`, inputs :math:`x_1...x_{|V|}`,
@@ -420,10 +421,14 @@ def double_backward(ctx, grad_output):
         )
 
     def reorder_weights_from_e3nn(self, weights, has_batch_dim=True):
-        return reorder_torch(self.forward_schedule, weights, "forward", not self.config.shared_weights)
+        return reorder_torch(
+            self.forward_schedule, weights, "forward", not self.config.shared_weights
+        )
 
     def reorder_weights_to_e3nn(self, weights, has_batch_dim=True):
-        return reorder_torch(self.forward_schedule, weights, "backward", not self.config.shared_weights)
+        return reorder_torch(
+            self.forward_schedule, weights, "backward", not self.config.shared_weights
+        )
 
     @staticmethod
     def name():
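The conv variant's double-backward path (see the mixin above) consumes graph.rows, graph.cols, and graph.transpose_perm. A rough sketch of building such edge-index buffers for a toy graph; the sorting and permutation conventions here are assumptions, not taken from this commit:

# Illustrative only: the attribute names match the `graph` object used in
# NumpyDoubleBackwardMixinConv; the ordering convention is a guess.
import numpy as np

# Toy symmetric, directed graph on 3 nodes: every edge appears in both
# directions, as the TensorProductConv docstring requires.
edges = np.array([[0, 1], [1, 0], [1, 2], [2, 1]])

order = np.lexsort((edges[:, 1], edges[:, 0]))  # sort edges by (row, then col)
rows, cols = edges[order, 0], edges[order, 1]

# One plausible meaning of transpose_perm: the permutation that re-sorts the
# edge list by (col, then row), i.e. the ordering of the transposed adjacency.
transpose_perm = np.lexsort((rows, cols))

The mixin then moves these to the GPU, e.g. torch.tensor(graph.rows, device="cuda").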

openequivariance/openequivariance/_torch/utils.py

Lines changed: 41 additions & 38 deletions
@@ -1,50 +1,53 @@
 import torch
 
+
 def reorder_helper(schedule, weights_in, direction, has_batch_dim):
-    assert direction in ["forward", "backward"]
-
-    specs = schedule.weight_reordering_info(weights_in, has_batch_dim)
-    weights_out = torch.zeros_like(weights_in)
-
-    for spec in specs:
-        parent_range = spec["parent_range"]
-        parent_shape = spec["parent_shape"]
-        weights_subrange = spec["weights_subrange"]
-        child_range = spec["child_range"]
-        transpose_perm = spec["transpose_perm"]
-
-        if direction == "forward":
-            reshape_size = spec["reshape_size"]
-
-            sliced_weights = weights_in[parent_range].reshape(parent_shape)[
-                weights_subrange
-            ]
-
-            weights_out[child_range] = sliced_weights.permute(
-                transpose_perm
-            ).reshape(reshape_size)
-
-        elif direction == "backward":
-            transpose_child_shape = spec["transpose_child_shape"]
-            child_shape = spec["child_shape"]
-
-            sliced_weights = (
-                weights_in[child_range]
-                .reshape(transpose_child_shape)
-                .permute(transpose_perm)
-            )
-
-            weights_out[parent_range].reshape(parent_shape)[
-                weights_subrange
-            ] = sliced_weights.flatten().reshape(child_shape)
-
-    return weights_out
+    assert direction in ["forward", "backward"]
+
+    specs = schedule.weight_reordering_info(weights_in, has_batch_dim)
+    weights_out = torch.zeros_like(weights_in)
+
+    for spec in specs:
+        parent_range = spec["parent_range"]
+        parent_shape = spec["parent_shape"]
+        weights_subrange = spec["weights_subrange"]
+        child_range = spec["child_range"]
+        transpose_perm = spec["transpose_perm"]
+
+        if direction == "forward":
+            reshape_size = spec["reshape_size"]
+
+            sliced_weights = weights_in[parent_range].reshape(parent_shape)[
+                weights_subrange
+            ]
+
+            weights_out[child_range] = sliced_weights.permute(transpose_perm).reshape(
+                reshape_size
+            )
+
+        elif direction == "backward":
+            transpose_child_shape = spec["transpose_child_shape"]
+            child_shape = spec["child_shape"]
+
+            sliced_weights = (
+                weights_in[child_range]
+                .reshape(transpose_child_shape)
+                .permute(transpose_perm)
+            )
+
+            weights_out[parent_range].reshape(parent_shape)[weights_subrange] = (
+                sliced_weights.flatten().reshape(child_shape)
+            )
+
+    return weights_out
+
 
 def reorder_numpy_helper(schedule, weights_in, direction, has_batch_dim):
     weights_in = torch.from_numpy(weights_in.copy())
     result = reorder_helper(schedule, weights_in, direction, has_batch_dim)
     return result.detach().cpu().numpy().copy()
 
+
 def reorder_torch(schedule, weights_in, direction, has_batch_dim):
     if isinstance(weights_in, torch.Tensor):
         return reorder_helper(schedule, weights_in, direction, has_batch_dim)
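The core move in reorder_helper is a reshape-permute-reshape on each block of the flat weight vector, applied in one direction or the other. A minimal standalone sketch of that round trip with invented shapes:

# Minimal sketch of the per-block transformation; shapes are made up for
# illustration and do not come from any real schedule.
import torch

flat = torch.arange(24, dtype=torch.float32)   # stand-in weight vector
parent_shape = (4, 6)                          # view of the parent slice
transpose_perm = (1, 0)                        # axis order for the child layout

# "forward" direction: parent layout -> child (oeq) layout
child = flat.reshape(parent_shape).permute(transpose_perm).reshape(-1)

# "backward" direction: view with the permuted shape and permute back.
# For a 2-D block the (1, 0) permutation is its own inverse.
restored = child.reshape(parent_shape[1], parent_shape[0]).permute(1, 0).reshape(-1)

assert torch.equal(restored, flat)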

openequivariance/openequivariance/benchmark/correctness_utils.py

Lines changed: 30 additions & 9 deletions
@@ -6,7 +6,8 @@
 from openequivariance.benchmark.random_buffer_utils import (
     get_random_buffers_forward,
     get_random_buffers_backward,
-    get_random_buffers_double_backward)
+    get_random_buffers_double_backward,
+)
 
 from openequivariance.benchmark.logging_utils import getLogger, bcolors
 import numpy as np
@@ -195,31 +196,51 @@ def correctness_double_backward(
     global torch
     import torch
 
-    in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, _ = \
-        get_random_buffers_double_backward(problem, batch_size=batch_size, prng_seed=prng_seed)
+    in1, in2, out_grad, weights, weights_dgrad, in1_dgrad, in2_dgrad, _ = (
+        get_random_buffers_double_backward(
+            problem, batch_size=batch_size, prng_seed=prng_seed
+        )
+    )
 
     if reference_implementation is None:
         from openequivariance._torch.E3NNTensorProduct import E3NNTensorProduct
+
         reference_implementation = E3NNTensorProduct
 
     result = {"thresh": correctness_threshold, "batch_size": batch_size}
 
     tensors = []
     for _, impl in enumerate([test_implementation, reference_implementation]):
         tp = instantiate_implementation(impl, problem)
-        weights_reordered = tp.reorder_weights_from_e3nn(weights, has_batch_dim=not problem.shared_weights)
-        weights_dgrad_reordered = tp.reorder_weights_from_e3nn(weights_dgrad, has_batch_dim=not problem.shared_weights)
+        weights_reordered = tp.reorder_weights_from_e3nn(
+            weights, has_batch_dim=not problem.shared_weights
+        )
+        weights_dgrad_reordered = tp.reorder_weights_from_e3nn(
+            weights_dgrad, has_batch_dim=not problem.shared_weights
+        )
 
         if impl == CUETensorProduct and problem.shared_weights:
             weights_reordered = weights_reordered[np.newaxis, :]
 
-        in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu(in1, in2, out_grad, weights_reordered, weights_dgrad_reordered, in1_dgrad, in2_dgrad)
+        in1_grad, in2_grad, weights_grad, out_dgrad = tp.double_backward_cpu(
+            in1,
+            in2,
+            out_grad,
+            weights_reordered,
+            weights_dgrad_reordered,
+            in1_dgrad,
+            in2_dgrad,
+        )
         tensors.append(
-            ( out_dgrad,
+            (
+                out_dgrad,
                 in1_grad,
                 in2_grad,
-                tp.reorder_weights_to_e3nn(weights_grad, has_batch_dim=not problem.shared_weights)
-            ))
+                tp.reorder_weights_to_e3nn(
+                    weights_grad, has_batch_dim=not problem.shared_weights
+                ),
+            )
+        )
 
     for name, to_check, ground_truth in [
         ("output_double_grad", tensors[0][0], tensors[1][0]),
