
Commit 427fdcb

Ran ruff.
1 parent a1b6248 commit 427fdcb

13 files changed: 184 additions & 97 deletions


docs/conf.py

Lines changed: 6 additions & 1 deletion
@@ -34,5 +34,10 @@

 sys.path.insert(0, str(Path("..").resolve()))

-autodoc_mock_imports = ["torch", "openequivariance.impl_torch.extlib", "jinja2", "numpy"]
+autodoc_mock_imports = [
+    "torch",
+    "openequivariance.impl_torch.extlib",
+    "jinja2",
+    "numpy",
+]
 autodoc_typehints = "description"
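
For context on the setting being reformatted: autodoc_mock_imports tells Sphinx autodoc to substitute import-time mocks for the listed modules, so the API docs can build on machines without torch or the compiled extension. A minimal sketch of the mechanism, using Sphinx's mock helper (an internal Sphinx detail, shown here purely for illustration):

# Sketch: Sphinx installs mock modules while importing documented code.
from sphinx.ext.autodoc.mock import mock

with mock(["torch", "openequivariance.impl_torch.extlib", "jinja2", "numpy"]):
    import torch  # resolves to a mock object, not the real package
    torch.cuda.is_available  # arbitrary attribute access succeeds on the mock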

openequivariance/openequivariance/__init__.py

Lines changed: 2 additions & 8 deletions
@@ -17,6 +17,7 @@
 from openequivariance.impl_torch.TensorProductConv import (
     TensorProductConv,
 )
+from openequivariance.impl_torch.extlib import torch_ext_so_path
 from openequivariance.core.utils import torch_to_oeq_dtype

 __version__ = None

@@ -37,20 +38,13 @@ def _check_package_editable():
 _editable_install_output_path = Path(__file__).parent.parent.parent / "outputs"


-def torch_ext_so_path():
-    """
-    :returns: Path to a ``.so`` file that must be linked to use OpenEquivariance
-        from the PyTorch C++ Interface.
-    """
-    return openequivariance.impl_torch.extlib.torch_module.__file__
-
-
 def extension_source_path():
     """
     :returns: Path to the source code of the C++ extension.
     """
     return str(Path(__file__).parent / "extension")

+
 torch.serialization.add_safe_globals(
     [
         TensorProduct,
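
The hunk above moves torch_ext_so_path into openequivariance.impl_torch.extlib and re-imports it at package level, so the public call site is unchanged. A hedged usage sketch of one way to consume that path from Python (the removed docstring's primary use case is linking from the PyTorch C++ interface; loading via torch.ops is an assumption for illustration):

# Sketch: locate and load the compiled extension shared object.
import torch
import openequivariance as oeq

so_path = oeq.torch_ext_so_path()  # still exported at package level
torch.ops.load_library(so_path)    # registers the custom ops with PyTorch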

openequivariance/openequivariance/core/ConvolutionBase.py

Lines changed: 1 addition & 0 deletions
@@ -562,6 +562,7 @@ def test_correctness_double_backward(

         if reference_implementation is None:
             from openequivariance.impl_torch.E3NNConv import E3NNConv
+
             reference_implementation = E3NNConv

         reference_problem = self.config
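
Only a blank line is added here, but the surrounding pattern is worth a note: the reference implementation defaults lazily, with the e3nn-backed E3NNConv imported inside the function so that dependency is only pulled in when no explicit reference is passed. The shape of the pattern (names below are illustrative, not this module's API):

# Sketch: lazy default argument with a deferred heavy import.
def test_correctness(reference_implementation=None):
    if reference_implementation is None:
        from some_heavy_backend import ReferenceImpl  # hypothetical module

        reference_implementation = ReferenceImpl
    # ... run the check against reference_implementation ...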

openequivariance/openequivariance/core/LoopUnrollConv.py

Lines changed: 3 additions & 1 deletion
@@ -10,11 +10,13 @@
 from openequivariance.templates.jinja_utils import get_jinja_environment
 from openequivariance.core.utils import filter_and_analyze_problem

+
 class LoopUnrollConv(ConvolutionBase):
     def __init__(
         self,
         config,
-        dp, postprocess_kernel,
+        dp,
+        postprocess_kernel,
         *,
         idx_dtype: type[np.generic] = np.int64,
         torch_op: bool = False,
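
The bare * in this signature (untouched by the reformat) makes idx_dtype and torch_op keyword-only; ruff merely splits the positional parameters one per line. A quick illustration with a stripped-down, hypothetical class:

import numpy as np

class Demo:
    def __init__(self, config, dp, postprocess_kernel, *,
                 idx_dtype: type[np.generic] = np.int64, torch_op: bool = False):
        self.idx_dtype = idx_dtype

Demo("cfg", "dp", "pk", idx_dtype=np.int32)  # OK: passed by keyword
# Demo("cfg", "dp", "pk", np.int32)          # TypeError: too many positional arguments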

openequivariance/openequivariance/core/LoopUnrollTP.py

Lines changed: 3 additions & 4 deletions
@@ -10,6 +10,7 @@
     count_cg_non_zero,
 )

+
 class LoopUnrollTP(TensorProductBase):
     def __init__(self, config, dp, postprocess_kernel, torch_op):
         super().__init__(config, torch_op=torch_op)

@@ -99,14 +100,12 @@ def generate_double_backward_schedule(warps_per_block):
             "opt_level": 3,
             "irrep_dtype": dtype_to_enum[self.config.irrep_dtype],
             "weight_dtype": dtype_to_enum[self.config.weight_dtype],
-
-            # Not relevant, included for compatibility with convolution
+            # Not relevant, included for compatibility with convolution
             "workspace_size": 0,
             "deterministic": 1,
-            "idx_dtype": 0
+            "idx_dtype": 0,
         }

-
     def calculate_flops_forward(self, batch_size: int) -> dict:
         if self.is_uvw:
             return super().calculate_flops_forward(batch_size)
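
The comma added after "idx_dtype": 0 is not purely cosmetic: ruff format, like Black, treats a "magic trailing comma" as a request to keep a collection expanded one element per line. For example:

# Without a trailing comma, a literal that fits may be collapsed:
attrs = {"workspace_size": 0, "deterministic": 1}
# With a trailing comma after the last item, it stays exploded:
attrs = {
    "workspace_size": 0,
    "deterministic": 1,
}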

openequivariance/openequivariance/core/utils.py

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@
 import hashlib
 from openequivariance.impl_torch.extlib import GPUTimer

+
 def sparse_outer_product_work(cg: np.ndarray) -> int:
     return np.sum(np.max(cg != 0, axis=2))

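As in several hunks above, the only change is a second blank line before a top-level definition, matching the PEP 8 convention that ruff format applies:

# Sketch: ruff format normalizes vertical spacing around top-level defs.
import hashlib


def first():  # two blank lines separate each top-level definition
    ...


def second():
    ...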
Lines changed: 29 additions & 22 deletions
@@ -1,54 +1,62 @@
-import numpy as np
-
 import jax
-
 from functools import partial
 from openequivariance.impl_jax import extlib
-import hashlib
-from openequivariance.core.e3nn_lite import TPProblem, Irreps
+from openequivariance.core.e3nn_lite import TPProblem
 from openequivariance.core.LoopUnrollTP import LoopUnrollTP
 from openequivariance.core.utils import hash_attributes
-import jax.numpy as jnp

-@partial(jax.custom_vjp, nondiff_argnums=(3,4,5))
+
+@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
 def forward(X, Y, W, L3_dim, irrep_dtype, attrs):
-    forward_call = jax.ffi.ffi_call("tp_forward",
-        jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype))
+    forward_call = jax.ffi.ffi_call(
+        "tp_forward", jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype)
+    )
     return forward_call(X, Y, W, **attrs)

+
 def forward_with_inputs(X, Y, W, L3_dim, irrep_dtype, attrs):
     return forward(X, Y, W, L3_dim, irrep_dtype, attrs), (X, Y, W)

-@partial(jax.custom_vjp, nondiff_argnums=(4,5))
+
+@partial(jax.custom_vjp, nondiff_argnums=(4, 5))
 def backward(X, Y, W, dZ, irrep_dtype, attrs):
-    backward_call = jax.ffi.ffi_call("tp_backward",
+    backward_call = jax.ffi.ffi_call(
+        "tp_backward",
         (
             jax.ShapeDtypeStruct(X.shape, irrep_dtype),
             jax.ShapeDtypeStruct(Y.shape, irrep_dtype),
             jax.ShapeDtypeStruct(W.shape, irrep_dtype),
-        ))
+        ),
+    )

     return backward_call(X, Y, W, dZ, **attrs)

+
 def backward_with_inputs(X, Y, W, dZ, irrep_dtype, attrs):
     return backward(X, Y, W, dZ, irrep_dtype, attrs), (X, Y, W, dZ)

+
 def double_backward(irrep_dtype, attrs, inputs, derivatives):
-    double_backward_call = jax.ffi.ffi_call("tp_double_backward",
+    double_backward_call = jax.ffi.ffi_call(
+        "tp_double_backward",
         (
             jax.ShapeDtypeStruct(inputs[0].shape, irrep_dtype),
             jax.ShapeDtypeStruct(inputs[1].shape, irrep_dtype),
             jax.ShapeDtypeStruct(inputs[2].shape, irrep_dtype),
             jax.ShapeDtypeStruct(inputs[3].shape, irrep_dtype),
-        ))
+        ),
+    )
     return double_backward_call(*inputs, *derivatives, **attrs)

+
 def backward_autograd(L3_dim, irrep_dtype, attrs, inputs, dZ):
-    return backward(inputs[0], inputs[1], inputs[2], dZ, irrep_dtype, attrs)
+    return backward(inputs[0], inputs[1], inputs[2], dZ, irrep_dtype, attrs)
+

 forward.defvjp(forward_with_inputs, backward_autograd)
 backward.defvjp(backward_with_inputs, double_backward)

+
 class TensorProduct(LoopUnrollTP):
     def __init__(self, config: TPProblem):
         dp = extlib.DeviceProp(0)

@@ -59,18 +67,17 @@ def __init__(self, config: TPProblem):
             "forward_config": vars(self.forward_schedule.launch_config),
             "backward_config": vars(self.backward_schedule.launch_config),
             "double_backward_config": vars(self.double_backward_schedule.launch_config),
-            "kernel_prop": self.kernelProp
+            "kernel_prop": self.kernelProp,
         }
         hash_attributes(self.attrs)
-
+
         self.weight_numel = config.weight_numel
         self.L3_dim = self.config.irreps_out.dim

     def forward(self, X: jax.ndarray, Y: jax.ndarray, W: jax.ndarray) -> jax.ndarray:
         return forward(X, Y, W, self.L3_dim, self.config.irrep_dtype, self.attrs)

-    def __call__(self,
-        X: jax.numpy.ndarray,
-        Y: jax.numpy.ndarray,
-        W: jax.numpy.ndarray) -> jax.numpy.ndarray:
-        return self.forward(X, Y, W)
+    def __call__(
+        self, X: jax.numpy.ndarray, Y: jax.numpy.ndarray, W: jax.numpy.ndarray
+    ) -> jax.numpy.ndarray:
+        return self.forward(X, Y, W)
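
The structure worth noting in this file: forward registers backward as its VJP, and backward is itself a jax.custom_vjp function whose VJP is double_backward, which is what lets JAX differentiate the tensor product twice through the FFI kernels. A self-contained sketch of that chaining on a toy function (all names below are illustrative, not this module's API):

import jax


@jax.custom_vjp
def f(x):
    return x * x


def f_fwd(x):
    return f(x), (x,)


@jax.custom_vjp
def f_grad(x, g):
    return 2.0 * x * g  # VJP of f: d(x*x)/dx contracted with cotangent g


def f_grad_fwd(x, g):
    return f_grad(x, g), (x, g)


def f_grad_bwd(res, dg):
    x, g = res
    # VJP of f_grad itself: this is what enables double backward.
    return (2.0 * g * dg, 2.0 * x * dg)


f_grad.defvjp(f_grad_fwd, f_grad_bwd)


def f_bwd(res, g):
    (x,) = res
    return (f_grad(x, g),)


f.defvjp(f_fwd, f_bwd)

print(jax.grad(f)(3.0))            # 6.0
print(jax.grad(jax.grad(f))(3.0))  # 2.0, computed via f_grad's custom VJP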
