Skip to content

Commit aa9bbf6

Browse files
committed
Ruff.
1 parent 9de117e commit aa9bbf6

6 files changed

Lines changed: 102 additions & 104 deletions

File tree

openequivariance/openequivariance/benchmark/correctness.py

Lines changed: 49 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
)
1717
from openequivariance.core.e3nn_lite import TPProblem
1818
from openequivariance.core.TensorProductBase import TensorProductBase
19-
from openequivariance.core.utils import IrrepLayoutUtils
19+
from openequivariance.core.utils import transpose_irrep_layout
2020

2121
logger = getLogger()
2222

@@ -87,29 +87,31 @@ def correctness_forward(
8787
outputs = []
8888

8989
for i, impl in enumerate([test_implementation, reference_implementation]):
90-
is_test_impl = (i == 0)
90+
is_test_impl = i == 0
9191
tp = instantiate_implementation(impl, problem)
9292
uses_cue = impl == CUETensorProduct or isinstance(tp, CUETensorProduct)
93-
run_in1, run_in2, run_weights, run_out = [ buf.copy() for buf in (in1, in2, weights, out) ]
93+
run_in1, run_in2, run_weights, run_out = [
94+
buf.copy() for buf in (in1, in2, weights, out)
95+
]
9496

9597
if problem.shared_weights and uses_cue:
9698
run_weights = run_weights[np.newaxis, :]
9799

98-
# Transpose inputs, if necessary, for the test implementation
100+
# Transpose inputs, if necessary, for the test implementation
99101
if is_test_impl:
100102
run_in1, run_in2 = [
101-
IrrepLayoutUtils.transpose_irrep_layout(
102-
arr, irreps, "mul_ir", tp.config.layout
103-
) for arr, irreps in zip(
104-
(run_in1, run_in2),
105-
(problem.irreps_in1, problem.irreps_in2)
103+
transpose_irrep_layout(arr, irreps, "mul_ir", tp.config.layout)
104+
for arr, irreps in zip(
105+
(run_in1, run_in2), (problem.irreps_in1, problem.irreps_in2)
106106
)
107107
]
108108

109-
tp.forward_cpu(L1_in=run_in1, L2_in=run_in2, L3_out=run_out, weights=run_weights)
109+
tp.forward_cpu(
110+
L1_in=run_in1, L2_in=run_in2, L3_out=run_out, weights=run_weights
111+
)
110112

111113
if is_test_impl:
112-
run_out = IrrepLayoutUtils.transpose_irrep_layout(
114+
run_out = transpose_irrep_layout(
113115
run_out, problem.irreps_out, tp.config.layout, "mul_ir"
114116
)
115117

@@ -147,7 +149,15 @@ def correctness_backward(
147149
is_test_impl = i == 0
148150
tp = instantiate_implementation(impl, problem)
149151

150-
run_in1, run_in2, run_L3_grad, run_weights, run_weights_grad, run_in1_grad, run_in2_grad = [
152+
(
153+
run_in1,
154+
run_in2,
155+
run_L3_grad,
156+
run_weights,
157+
run_weights_grad,
158+
run_in1_grad,
159+
run_in2_grad,
160+
) = [
151161
buf.copy()
152162
for buf in (in1, in2, out_grad, weights, weights_grad, in1_grad, in2_grad)
153163
]
@@ -159,9 +169,7 @@ def correctness_backward(
159169

160170
if is_test_impl:
161171
run_in1, run_in2, run_L3_grad = [
162-
IrrepLayoutUtils.transpose_irrep_layout(
163-
arr, irreps, "mul_ir", tp.config.layout
164-
)
172+
transpose_irrep_layout(arr, irreps, "mul_ir", tp.config.layout)
165173
for arr, irreps in zip(
166174
(run_in1, run_in2, run_L3_grad),
167175
(problem.irreps_in1, problem.irreps_in2, problem.irreps_out),
@@ -180,9 +188,7 @@ def correctness_backward(
180188

181189
if is_test_impl:
182190
run_in1_grad, run_in2_grad = [
183-
IrrepLayoutUtils.transpose_irrep_layout(
184-
arr, irreps, tp.config.layout, "mul_ir"
185-
)
191+
transpose_irrep_layout(arr, irreps, tp.config.layout, "mul_ir")
186192
for arr, irreps in zip(
187193
(run_in1_grad, run_in2_grad),
188194
(problem.irreps_in1, problem.irreps_in2),
@@ -254,9 +260,7 @@ def correctness_double_backward(
254260

255261
if is_test_impl:
256262
db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad = [
257-
IrrepLayoutUtils.transpose_irrep_layout(
258-
arr, irreps, "mul_ir", tp.config.layout
259-
)
263+
transpose_irrep_layout(arr, irreps, "mul_ir", tp.config.layout)
260264
for arr, irreps in zip(
261265
(db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad),
262266
(
@@ -281,9 +285,7 @@ def correctness_double_backward(
281285

282286
if is_test_impl:
283287
out_dgrad, in1_grad, in2_grad = [
284-
IrrepLayoutUtils.transpose_irrep_layout(
285-
arr, irreps, tp.config.layout, "mul_ir"
286-
)
288+
transpose_irrep_layout(arr, irreps, tp.config.layout, "mul_ir")
287289
for arr, irreps in zip(
288290
(out_dgrad, in1_grad, in2_grad),
289291
(problem.irreps_out, problem.irreps_in1, problem.irreps_in2),
@@ -359,9 +361,7 @@ def correctness_forward_conv(
359361

360362
if is_test_impl:
361363
run_in1, run_in2 = [
362-
IrrepLayoutUtils.transpose_irrep_layout(
363-
arr, irreps, "mul_ir", conv.config.layout
364-
)
364+
transpose_irrep_layout(arr, irreps, "mul_ir", conv.config.layout)
365365
for arr, irreps in zip(
366366
(run_in1, run_in2),
367367
(conv.config.irreps_in1, conv.config.irreps_in2),
@@ -375,7 +375,7 @@ def correctness_forward_conv(
375375
graph=graph,
376376
)
377377

378-
run_out = IrrepLayoutUtils.transpose_irrep_layout(
378+
run_out = transpose_irrep_layout(
379379
run_out, conv.config.irreps_out, conv.config.layout, "mul_ir"
380380
)
381381
else:
@@ -410,13 +410,9 @@ def correctness_forward_conv(
410410

411411
for _ in range(num_trials):
412412
repeated_run = out.copy()
413-
rep_in1, rep_in2, rep_weights = [
414-
buf.copy() for buf in (in1, in2, weights)
415-
]
413+
rep_in1, rep_in2, rep_weights = [buf.copy() for buf in (in1, in2, weights)]
416414
rep_in1, rep_in2 = [
417-
IrrepLayoutUtils.transpose_irrep_layout(
418-
arr, irreps, "mul_ir", conv.config.layout
419-
)
415+
transpose_irrep_layout(arr, irreps, "mul_ir", conv.config.layout)
420416
for arr, irreps in zip(
421417
(rep_in1, rep_in2),
422418
(conv.config.irreps_in1, conv.config.irreps_in2),
@@ -430,7 +426,7 @@ def correctness_forward_conv(
430426
graph=graph,
431427
)
432428

433-
repeated_run = IrrepLayoutUtils.transpose_irrep_layout(
429+
repeated_run = transpose_irrep_layout(
434430
repeated_run, conv.config.irreps_out, conv.config.layout, "mul_ir"
435431
)
436432

@@ -471,9 +467,15 @@ def correctness_backward_conv(
471467
is_test_impl = i == 0
472468
tp = impl if is_test_impl else impl(reference_problem)
473469

474-
run_in1, run_in2, run_out_grad, run_weights, run_weights_grad, run_in1_grad, run_in2_grad = [
475-
buf.copy() for buf in buffers
476-
]
470+
(
471+
run_in1,
472+
run_in2,
473+
run_out_grad,
474+
run_weights,
475+
run_weights_grad,
476+
run_in1_grad,
477+
run_in2_grad,
478+
) = [buf.copy() for buf in buffers]
477479

478480
if not is_test_impl and high_precision_ref:
479481
(
@@ -488,12 +490,14 @@ def correctness_backward_conv(
488490

489491
if is_test_impl:
490492
run_in1, run_in2, run_out_grad = [
491-
IrrepLayoutUtils.transpose_irrep_layout(
492-
arr, irreps, "mul_ir", conv.config.layout
493-
)
493+
transpose_irrep_layout(arr, irreps, "mul_ir", conv.config.layout)
494494
for arr, irreps in zip(
495495
(run_in1, run_in2, run_out_grad),
496-
(conv.config.irreps_in1, conv.config.irreps_in2, conv.config.irreps_out),
496+
(
497+
conv.config.irreps_in1,
498+
conv.config.irreps_in2,
499+
conv.config.irreps_out,
500+
),
497501
)
498502
]
499503

@@ -510,9 +514,7 @@ def correctness_backward_conv(
510514

511515
if is_test_impl:
512516
run_in1_grad, run_in2_grad = [
513-
IrrepLayoutUtils.transpose_irrep_layout(
514-
arr, irreps, conv.config.layout, "mul_ir"
515-
)
517+
transpose_irrep_layout(arr, irreps, conv.config.layout, "mul_ir")
516518
for arr, irreps in zip(
517519
(run_in1_grad, run_in2_grad),
518520
(conv.config.irreps_in1, conv.config.irreps_in2),
@@ -581,9 +583,7 @@ def correctness_double_backward_conv(
581583
]
582584
if is_test_impl:
583585
db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad = [
584-
IrrepLayoutUtils.transpose_irrep_layout(
585-
arr, irreps, "mul_ir", tp.config.layout
586-
)
586+
transpose_irrep_layout(arr, irreps, "mul_ir", tp.config.layout)
587587
for arr, irreps in zip(
588588
(db_in1, db_in2, db_out_grad, db_in1_dgrad, db_in2_dgrad),
589589
(
@@ -609,9 +609,7 @@ def correctness_double_backward_conv(
609609

610610
if is_test_impl:
611611
out_dgrad, in1_grad, in2_grad = [
612-
IrrepLayoutUtils.transpose_irrep_layout(
613-
arr, irreps, tp.config.layout, "mul_ir"
614-
)
612+
transpose_irrep_layout(arr, irreps, tp.config.layout, "mul_ir")
615613
for arr, irreps in zip(
616614
(out_dgrad, in1_grad, in2_grad),
617615
(tp.config.irreps_out, tp.config.irreps_in1, tp.config.irreps_in2),

openequivariance/openequivariance/core/ConvolutionBase.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,10 @@
55
from openequivariance.benchmark.logging import bcolors, getLogger
66
from openequivariance.benchmark.random_buffer_utils import (
77
get_random_buffers_backward_conv,
8-
get_random_buffers_double_backward_conv,
98
get_random_buffers_forward_conv,
109
)
1110
from openequivariance.core.e3nn_lite import wigner_3j
12-
from openequivariance.core.utils import IrrepLayoutUtils, benchmark
11+
from openequivariance.core.utils import benchmark
1312

1413
logger = getLogger()
1514

openequivariance/openequivariance/core/utils.py

Lines changed: 47 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -209,50 +209,52 @@ def hash_str_64(s: str) -> int:
209209
return int.from_bytes(hashlib.sha256(s.encode()).digest()[:7], "big")
210210

211211

212-
class IrrepLayoutUtils:
213-
@staticmethod
214-
def transpose_irrep_layout(
215-
array: np.ndarray,
216-
irreps: Irreps,
217-
src_layout: str,
218-
dst_layout: str,
219-
) -> np.ndarray:
220-
"""
221-
Transpose irrep-packed feature arrays between `mul_ir` and `ir_mul` layouts.
222-
223-
Expected input shape is `[..., irreps.dim]`. A new array is returned.
224-
If `src_layout == dst_layout`, this returns a copy.
225-
"""
226-
if src_layout not in ("mul_ir", "ir_mul"):
227-
raise ValueError(f"Unsupported src_layout: {src_layout}")
228-
if dst_layout not in ("mul_ir", "ir_mul"):
229-
raise ValueError(f"Unsupported dst_layout: {dst_layout}")
230-
231-
x = np.asarray(array)
232-
out = np.empty_like(x)
233-
234-
if src_layout == dst_layout:
235-
out[...] = x
236-
return out
237-
238-
slices = irreps.slices()
239-
for ir_idx, mul_ir in enumerate(irreps):
240-
mul = mul_ir.mul
241-
dim = mul_ir.ir.dim
242-
seg = slices[ir_idx]
243-
block = x[..., seg.start : seg.stop]
244-
245-
if src_layout == "ir_mul" and dst_layout == "mul_ir":
246-
out[..., seg.start : seg.stop] = block.reshape(
247-
*block.shape[:-1], dim, mul
248-
).swapaxes(-1, -2).reshape(*block.shape[:-1], mul * dim)
249-
elif src_layout == "mul_ir" and dst_layout == "ir_mul":
250-
out[..., seg.start : seg.stop] = block.reshape(
251-
*block.shape[:-1], mul, dim
252-
).swapaxes(-1, -2).reshape(*block.shape[:-1], dim * mul)
253-
else:
254-
raise ValueError(
255-
f"Unsupported layout transpose: {src_layout} -> {dst_layout}"
256-
)
212+
def transpose_irrep_layout(
213+
array: np.ndarray,
214+
irreps: Irreps,
215+
src_layout: str,
216+
dst_layout: str,
217+
) -> np.ndarray:
218+
"""
219+
Transpose irrep-packed feature arrays between `mul_ir` and `ir_mul` layouts.
220+
221+
Expected input shape is `[..., irreps.dim]`. A new array is returned.
222+
If `src_layout == dst_layout`, this returns a copy.
223+
"""
224+
if src_layout not in ("mul_ir", "ir_mul"):
225+
raise ValueError(f"Unsupported src_layout: {src_layout}")
226+
if dst_layout not in ("mul_ir", "ir_mul"):
227+
raise ValueError(f"Unsupported dst_layout: {dst_layout}")
257228

229+
x = np.asarray(array)
230+
out = np.empty_like(x)
231+
232+
if src_layout == dst_layout:
233+
out[...] = x
258234
return out
235+
236+
slices = irreps.slices()
237+
for ir_idx, mul_ir in enumerate(irreps):
238+
mul = mul_ir.mul
239+
dim = mul_ir.ir.dim
240+
seg = slices[ir_idx]
241+
block = x[..., seg.start : seg.stop]
242+
243+
if src_layout == "ir_mul" and dst_layout == "mul_ir":
244+
out[..., seg.start : seg.stop] = (
245+
block.reshape(*block.shape[:-1], dim, mul)
246+
.swapaxes(-1, -2)
247+
.reshape(*block.shape[:-1], mul * dim)
248+
)
249+
elif src_layout == "mul_ir" and dst_layout == "ir_mul":
250+
out[..., seg.start : seg.stop] = (
251+
block.reshape(*block.shape[:-1], mul, dim)
252+
.swapaxes(-1, -2)
253+
.reshape(*block.shape[:-1], dim * mul)
254+
)
255+
else:
256+
raise ValueError(
257+
f"Unsupported layout transpose: {src_layout} -> {dst_layout}"
258+
)
259+
260+
return out

tests/batch_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -272,9 +272,9 @@ def tp_and_problem(self, problem, extra_tp_constructor_args, with_jax):
272272
tp.to(switch_map[problem.irrep_dtype])
273273
return tp, tp.config
274274

275+
275276
class TestIrMulLayoutMACE(TPCorrectness):
276-
production_model_tpps = mace_problems() + \
277-
[
277+
production_model_tpps = mace_problems() + [
278278
oeq.TPProblem(
279279
"5x5e",
280280
"1x3e",
@@ -293,7 +293,7 @@ class TestIrMulLayoutMACE(TPCorrectness):
293293
internal_weights=False,
294294
label="ir_mul_repr_13x1x13_l535",
295295
),
296-
]
296+
]
297297

298298
@pytest.fixture(params=production_model_tpps, ids=lambda x: x.label, scope="class")
299299
def problem(self, request, dtype):

tests/conv_test.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,8 +293,7 @@ def conv_object(self, request, problem, extra_conv_constructor_args):
293293

294294

295295
class TestIrMulLayout(ConvCorrectness):
296-
production_model_tpps = mace_problems() + \
297-
[
296+
production_model_tpps = mace_problems() + [
298297
oeq.TPProblem(
299298
"5x5e",
300299
"1x3e",

tests/input_validation_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,4 +154,4 @@ def test_ir_mul_rejects_uvw_problem(dtype):
154154
)
155155

156156
with pytest.raises(AssertionError, match="layout='ir_mul'"):
157-
TensorProduct(problem)
157+
TensorProduct(problem)

0 commit comments

Comments
 (0)