Skip to content

Commit d1f20a1

Browse files
Authored by Austin Glover (asglover) and co-authors
Ruff for Linting and Formatting (#104)
* adding pre-commit and linting / formatting via ruff
* linting + remove transpose permutation in forward
* readme tutorial formatted. Can exclude if preferred
* ruff format and lint
* only test deterministic if shared_weights != true
* revert Tensor -> tensor change (Tensor is a type, tensor the constructor)
* add ruff and pre-commit to the CI requirements .txt (exact version to promote caching)
* skips and looser thresholds
* save failed tensors
* remove tensor saving
* make E741 (lowercase L) lint a package level preference
* add pre-commit github action
* lint and format

---------

Co-authored-by: Austin Glover <austin_glover@berkeley.com>
1 parent 9935d8d commit d1f20a1

50 files changed

Lines changed: 5245 additions & 2831 deletions

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

.github/workflows/pre-commit.yaml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
name: Pre-Commit Checks
2+
3+
on:
4+
pull_request:
5+
push:
6+
branches: [main]
7+
8+
jobs:
9+
pre-commit:
10+
runs-on: ubuntu-latest
11+
steps:
12+
- uses: actions/checkout@v4
13+
- uses: pre-commit/action@v3.0.1
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
numpy==2.2.5
22
torch==2.7.0 --index-url https://download.pytorch.org/whl/cu128
33
pytest==8.3.5
4-
ninja==1.11.1.4
4+
ninja==1.11.1.4
5+
ruff==0.11.11
6+
pre-commit==4.2.0

.pre-commit-config.yaml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
repos:
2+
- repo: https://github.com/astral-sh/ruff-pre-commit
3+
# Ruff version.
4+
rev: v0.11.11
5+
hooks:
6+
# Run the linter.
7+
- id: ruff-check
8+
# Run the formatter.
9+
- id: ruff-format

examples/readme_tutorial.py

Lines changed: 42 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,29 @@
1-
# Examples from the README
1+
# ruff: noqa: E402
2+
# Examples from the README
23
import logging
34
from openequivariance.benchmark.logging_utils import getLogger
5+
46
logger = getLogger()
57
logger.setLevel(logging.ERROR)
68

7-
# UVU Tensor Product
9+
# UVU Tensor Product
810
# ===============================
911
import torch
1012
import e3nn.o3 as o3
1113

12-
gen = torch.Generator(device='cuda')
14+
gen = torch.Generator(device="cuda")
1315

1416
batch_size = 1000
15-
X_ir, Y_ir, Z_ir = o3.Irreps("1x2e"), o3.Irreps("1x3e"), o3.Irreps("1x2e")
16-
X = torch.rand(batch_size, X_ir.dim, device='cuda', generator=gen)
17-
Y = torch.rand(batch_size, Y_ir.dim, device='cuda', generator=gen)
17+
X_ir, Y_ir, Z_ir = o3.Irreps("1x2e"), o3.Irreps("1x3e"), o3.Irreps("1x2e")
18+
X = torch.rand(batch_size, X_ir.dim, device="cuda", generator=gen)
19+
Y = torch.rand(batch_size, Y_ir.dim, device="cuda", generator=gen)
1820

19-
instructions=[(0, 0, 0, "uvu", True)]
21+
instructions = [(0, 0, 0, "uvu", True)]
2022

21-
tp_e3nn = o3.TensorProduct(X_ir, Y_ir, Z_ir, instructions,
22-
shared_weights=False, internal_weights=False).to('cuda')
23-
W = torch.rand(batch_size, tp_e3nn.weight_numel, device='cuda', generator=gen)
23+
tp_e3nn = o3.TensorProduct(
24+
X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False
25+
).to("cuda")
26+
W = torch.rand(batch_size, tp_e3nn.weight_numel, device="cuda", generator=gen)
2427

2528
Z = tp_e3nn(X, Y, W)
2629
print(torch.norm(Z))
@@ -29,10 +32,12 @@
2932
# ===============================
3033
import openequivariance as oeq
3134

32-
problem = oeq.TPProblem(X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False)
35+
problem = oeq.TPProblem(
36+
X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False
37+
)
3338
tp_fast = oeq.TensorProduct(problem, torch_op=True)
3439

35-
Z = tp_fast(X, Y, W) # Reuse X, Y, W from earlier
40+
Z = tp_fast(X, Y, W) # Reuse X, Y, W from earlier
3641
print(torch.norm(Z))
3742
# ===============================
3843

@@ -44,26 +49,35 @@
4449

4550
# Receiver, sender indices for message passing GNN
4651
edge_index = EdgeIndex(
47-
[[0, 1, 1, 2], # Receiver
48-
[1, 0, 2, 1]], # Sender
49-
device='cuda',
50-
dtype=torch.long)
51-
52-
X = torch.rand(node_ct, X_ir.dim, device='cuda', generator=gen)
53-
Y = torch.rand(nonzero_ct, Y_ir.dim, device='cuda', generator=gen)
54-
W = torch.rand(nonzero_ct, problem.weight_numel, device='cuda', generator=gen)
55-
56-
tp_conv = oeq.TensorProductConv(problem, torch_op=True, deterministic=False) # Reuse problem from earlier
57-
Z = tp_conv.forward(X, Y, W, edge_index[0], edge_index[1]) # Z has shape [node_ct, z_ir.dim]
52+
[
53+
[0, 1, 1, 2], # Receiver
54+
[1, 0, 2, 1], # Sender
55+
],
56+
device="cuda",
57+
dtype=torch.long,
58+
)
59+
60+
X = torch.rand(node_ct, X_ir.dim, device="cuda", generator=gen)
61+
Y = torch.rand(nonzero_ct, Y_ir.dim, device="cuda", generator=gen)
62+
W = torch.rand(nonzero_ct, problem.weight_numel, device="cuda", generator=gen)
63+
64+
tp_conv = oeq.TensorProductConv(
65+
problem, torch_op=True, deterministic=False
66+
) # Reuse problem from earlier
67+
Z = tp_conv.forward(
68+
X, Y, W, edge_index[0], edge_index[1]
69+
) # Z has shape [node_ct, z_ir.dim]
5870
print(torch.norm(Z))
5971
# ===============================
6072

6173
# ===============================
62-
_, sender_perm = edge_index.sort_by("col") # Sort by sender index
63-
edge_index, receiver_perm = edge_index.sort_by("row") # Sort by receiver index
74+
_, sender_perm = edge_index.sort_by("col") # Sort by sender index
75+
edge_index, receiver_perm = edge_index.sort_by("row") # Sort by receiver index
6476

6577
# Now we can use the faster deterministic algorithm
66-
tp_conv = oeq.TensorProductConv(problem, torch_op=True, deterministic=True)
67-
Z = tp_conv.forward(X, Y[receiver_perm], W[receiver_perm], edge_index[0], edge_index[1], sender_perm)
78+
tp_conv = oeq.TensorProductConv(problem, torch_op=True, deterministic=True)
79+
Z = tp_conv.forward(
80+
X, Y[receiver_perm], W[receiver_perm], edge_index[0], edge_index[1], sender_perm
81+
)
6882
print(torch.norm(Z))
69-
# ===============================
83+
# ===============================

io/cif_to_graph.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,44 @@
11
import pickle
22
import numpy as np
33
from sklearn.neighbors import radius_neighbors_graph
4-
from scipy.io import mmwrite
4+
55

66
def cif_to_molecular_graph(cif_file, cp, radii):
7-
with open(f'../data/cif_files/{cif_file}', 'r') as f:
7+
with open(f"../data/cif_files/{cif_file}", "r") as f:
88
print("Started reading file...")
99
lines = f.readlines()
1010
print("Finished reading file!")
1111

1212
coords = []
1313
for line in lines:
14-
if line.startswith('ATOM'):
14+
if line.startswith("ATOM"):
1515
parts = line.split()
16-
coords.append([float(parts[cp[0]]), float(parts[cp[1]]), float(parts[cp[2]])])
16+
coords.append(
17+
[float(parts[cp[0]]), float(parts[cp[1]]), float(parts[cp[2]])]
18+
)
1719

1820
coords = np.array(coords)
1921

2022
for radius in radii:
2123
print(f"Starting radius neighbors calculation, r={radius}")
22-
A = radius_neighbors_graph(coords, radius, mode='connectivity',
23-
include_self=False)
24-
print(f"Finished radius neighbors calculation, found {A.nnz} nonzeros.")
25-
24+
A = radius_neighbors_graph(
25+
coords, radius, mode="connectivity", include_self=False
26+
)
27+
print(f"Finished radius neighbors calculation, found {A.nnz} nonzeros.")
28+
2629
# mmwrite(f'../data/molecular_structures/{cif_file.split(".")[0]}.mtx', A)
2730

2831
coo_mat = A.tocoo()
29-
result = {
30-
'row': coo_mat.row,
31-
'col': coo_mat.col,
32-
'coords': coords
33-
}
32+
result = {"row": coo_mat.row, "col": coo_mat.col, "coords": coords}
3433

35-
with open(f'../data/molecular_structures/{cif_file.split(".")[0]}_radius{radius}.pickle', 'wb') as handle:
34+
with open(
35+
f"../data/molecular_structures/{cif_file.split('.')[0]}_radius{radius}.pickle",
36+
"wb",
37+
) as handle:
3638
pickle.dump(result, handle, protocol=pickle.HIGHEST_PROTOCOL)
3739

3840

39-
if __name__=='__main__':
40-
#cif_to_molecular_graph('hiv_capsid.cif', (10, 11, 12), radii=[2.0, 2.5, 3.0, 3.5])
41-
cif_to_molecular_graph('covid_spike.cif', (10, 11, 12), radii=[2.0, 2.5, 3.0, 3.5])
42-
cif_to_molecular_graph('1drf.cif', (10, 11, 12), radii=[6.0])
41+
if __name__ == "__main__":
42+
# cif_to_molecular_graph('hiv_capsid.cif', (10, 11, 12), radii=[2.0, 2.5, 3.0, 3.5])
43+
cif_to_molecular_graph("covid_spike.cif", (10, 11, 12), radii=[2.0, 2.5, 3.0, 3.5])
44+
cif_to_molecular_graph("1drf.cif", (10, 11, 12), radii=[6.0])

io/load_nequip_configs.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,29 @@
1-
'''
1+
"""
22
This script parse the repository of
33
Nequip input files at
44
https://github.com/mir-group/nequip-input-files.
55
We extract the node / edge hidden features representations.
6-
'''
6+
"""
7+
8+
import os
9+
import yaml
710

8-
import os, yaml
911

1012
def process_nequip_configs():
1113
nequip_files = []
12-
for root, dirs, files in os.walk('../data/nequip-input-files'):
14+
for root, dirs, files in os.walk("../data/nequip-input-files"):
1315
for file in files:
14-
if file.endswith('.yaml'):
16+
if file.endswith(".yaml"):
1517
nequip_files.append(os.path.join(root, file))
16-
18+
1719
irrep_pairs = []
1820
configs = []
1921
for file in nequip_files:
20-
with open(file, 'r') as f:
22+
with open(file, "r") as f:
2123
data = yaml.unsafe_load(f)
2224
filename = os.path.splitext(os.path.basename(file))[0]
23-
feature_irreps_hidden = data['feature_irreps_hidden']
24-
irreps_edge_sh = data['irreps_edge_sh']
25+
feature_irreps_hidden = data["feature_irreps_hidden"]
26+
irreps_edge_sh = data["irreps_edge_sh"]
2527
if (feature_irreps_hidden, irreps_edge_sh) not in irrep_pairs:
2628
irrep_pairs.append((feature_irreps_hidden, irreps_edge_sh))
2729
configs.append((feature_irreps_hidden, irreps_edge_sh, filename))
@@ -30,5 +32,5 @@ def process_nequip_configs():
3032
print(config)
3133

3234

33-
if __name__ == '__main__':
34-
process_nequip_configs()
35+
if __name__ == "__main__":
36+
process_nequip_configs()

openequivariance/__init__.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,32 @@
1+
# ruff: noqa: F401
12
import openequivariance.extlib
23
from pathlib import Path
34
from importlib.metadata import version
45

56
from openequivariance.implementations.e3nn_lite import TPProblem, Irreps
6-
from openequivariance.implementations.TensorProduct import TensorProduct
7-
from openequivariance.implementations.convolution.TensorProductConv import TensorProductConv
7+
from openequivariance.implementations.TensorProduct import TensorProduct
8+
from openequivariance.implementations.convolution.TensorProductConv import (
9+
TensorProductConv,
10+
)
811
from openequivariance.implementations.utils import torch_to_oeq_dtype
912

13+
__all__ = [
14+
"TPProblem",
15+
"Irreps",
16+
"TensorProduct",
17+
"TensorProductConv",
18+
"torch_to_oeq_dtype",
19+
]
20+
1021
__version__ = version("openequivariance")
1122

23+
1224
def _check_package_editable():
1325
import json
1426
from importlib.metadata import Distribution
27+
1528
direct_url = Distribution.from_name("openequivariance").read_text("direct_url.json")
1629
return json.loads(direct_url).get("dir_info", {}).get("editable", False)
1730

18-
_editable_install_output_path = Path(__file__).parent.parent / "outputs"
31+
32+
_editable_install_output_path = Path(__file__).parent.parent / "outputs"

0 commit comments

Comments (0)