Commit 5a5080a

Even more refactoring.

1 parent aa9bbf6 commit 5a5080a

20 files changed: 267 additions & 432 deletions

openequivariance/openequivariance/_torch/CUETensorProduct.py

Lines changed: 2 additions & 54 deletions
@@ -6,13 +6,12 @@

 from openequivariance.core.TensorProductBase import TensorProductBase
 from openequivariance.core.e3nn_lite import TPProblem
-from openequivariance.benchmark.logging import getLogger
-from openequivariance.benchmark.tpp_creation_utils import (
+from openequivariance.core.logging import getLogger
+from openequivariance.benchmark.problems import (
     ChannelwiseTPP,
     FullyConnectedTPProblem,
     SingleInstruction,
 )
-from openequivariance.core.utils import count_cg_non_zero

 os.environ["CUEQUIVARIANCE_OPS_USE_JIT"] = "1"

@@ -235,57 +234,6 @@ def benchmark_backward(
             kernel_names=self.kernel_names,
         )

-    # Copied over from loop unroller to match arithmetic intensity on roofline plots
-    def calculate_flops_forward(self, batch_size: int) -> dict:
-        if self.is_uvw:
-            return super().calculate_flops_forward(batch_size)
-        else:
-            tpp = self.config
-            flop_count = {
-                "CG_decomposition": 0,
-                "linear_combination": 0,
-                "outer_products": 0,
-            }
-            for ins in tpp.instructions:
-                l1, l2, l3 = (
-                    tpp.irreps_in1[ins.i_in1].ir.l,
-                    tpp.irreps_in2[ins.i_in2].ir.l,
-                    tpp.irreps_out[ins.i_out].ir.l,
-                )
-                flop_count["CG_decomposition"] += count_cg_non_zero(l1, l2, l3) * (
-                    ins.path_shape[0] * ins.path_shape[1]
-                )
-                flop_count["linear_combination"] += (
-                    (2 * l3 + 1) * np.prod(ins.path_shape) if ins.has_weight else 0
-                )
-
-            flop_count["CG_decomposition"] *= 3 * batch_size
-            flop_count["linear_combination"] *= (
-                batch_size  # Weights do not require FMA here
-            )
-            flop_count["total"] = sum(flop_count.values())
-            return flop_count
-
-    def calculate_flops_backward(self, batch_size: int) -> dict:
-        if self.is_uvw:
-            return super().calculate_flops_backward(batch_size)
-        else:
-            tpp = self.config
-            flop_count = {"backward": 0}
-            for ins in tpp.instructions:
-                l1, l2, l3 = (
-                    tpp.irreps_in1[ins.i_in1].ir.l,
-                    tpp.irreps_in2[ins.i_in2].ir.l,
-                    tpp.irreps_out[ins.i_out].ir.l,
-                )
-                flop_count["backward"] += count_cg_non_zero(l1, l2, l3) * (
-                    ins.path_shape[0] * ins.path_shape[1]
-                )
-
-            flop_count["backward"] *= 9 * batch_size
-            flop_count["total"] = sum(flop_count.values())
-            return flop_count
-

     @staticmethod
     def name():
         return "CUETensorProduct"

openequivariance/openequivariance/_torch/E3NNTensorProduct.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@

 from openequivariance.core.TensorProductBase import TensorProductBase
 from openequivariance.core.e3nn_lite import TPProblem
-from openequivariance.benchmark.logging import getLogger
+from openequivariance.core.logging import getLogger
 from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixin

 TORCH_COMPILE_AUTOTUNING_DIR = pathlib.Path("triton_autotuning")
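This file and several below make the same mechanical change: getLogger (plus bcolors, where used) now comes from openequivariance.core.logging instead of openequivariance.benchmark.logging, so core and _torch modules no longer import from the benchmark package. A sketch of the resulting call-site pattern, assuming getLogger keeps returning a standard logging.Logger as the unchanged call sites suggest:

from openequivariance.core.logging import getLogger

logger = getLogger()
logger.info("core no longer depends on the benchmark package for logging")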

openequivariance/openequivariance/_torch/TensorProduct.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 from openequivariance._torch import extlib
 import torch
 from openequivariance.core.utils import torch_to_oeq_dtype, dtype_to_enum
-from openequivariance.benchmark.logging import getLogger
+from openequivariance.core.logging import getLogger
 from openequivariance._torch.utils import (
     reorder_torch,
     string_to_tensor,

openequivariance/openequivariance/_torch/TensorProductConv.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@
     enum_to_torch_dtype,
 )

-from openequivariance.benchmark.logging import getLogger
+from openequivariance.core.logging import getLogger
 from openequivariance._torch.NPDoubleBackwardMixin import NumpyDoubleBackwardMixinConv

 logger = getLogger()

openequivariance/openequivariance/_torch/extlib/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@

 import torch

-from openequivariance.benchmark.logging import getLogger
+from openequivariance.core.logging import getLogger

 oeq_root = str(Path(__file__).parent.parent.parent)

openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
     correctness_double_backward_conv,
     correctness_forward_conv,
 )
-from openequivariance.benchmark.logging import getLogger
+from openequivariance.core.logging import getLogger
 from openequivariance.core.ConvolutionBase import CoordGraph
 from openequivariance.benchmark.benchmark_utils import NpEncoder

openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@
 from openequivariance._torch.extlib import DeviceProp
 from openequivariance.core.TensorProductBase import TensorProductBase

-from openequivariance.benchmark.logging import getLogger, bcolors
+from openequivariance.core.logging import getLogger, bcolors
 from openequivariance.core.e3nn_lite import TPProblem
 from openequivariance.benchmark.correctness import (
     correctness_forward,

openequivariance/openequivariance/benchmark/benchmark_utils.py

Lines changed: 19 additions & 66 deletions
@@ -1,21 +1,22 @@
 import json
 import numpy as np

-from openequivariance.benchmark.random_buffer_utils import (
+from openequivariance.benchmark.test_buffers import (
     get_random_buffers_forward,
     get_random_buffers_backward,
     get_random_buffers_double_backward,
 )
-from openequivariance.benchmark.perf_metrics_utils import (
-    calculate_minimum_flops_forward,
-    calculate_minimum_memory_streamed_forward,
-    calculate_minimum_memory_streamed_backward,
+from openequivariance.benchmark.metrics import (
+    flops_forward,
+    flops_backward,
+    memory_streamed_forward,
+    memory_streamed_backward,
 )
 from openequivariance.core.utils import calculate_total_nnz
 from openequivariance.core.TensorProductBase import TensorProductBase
 from openequivariance.core.e3nn_lite import TPProblem
 from openequivariance._torch.CUETensorProduct import CUETensorProduct
-from openequivariance.benchmark.logging import getLogger, bcolors
+from openequivariance.core.logging import getLogger, bcolors

 logger = getLogger()

@@ -110,24 +111,12 @@ def benchmark_forward(
     time_millis = np.full(shape=num_iter, fill_value=-1)

     # FLOPS
-    try:
-        flops = tp.calculate_flops_forward(batch_size=batch_size)
-    except NotImplementedError:
-        logger.warning(
-            "Actual flop count not calculated, so minimum values are being used"
-        )
-        flops = calculate_minimum_flops_forward(problem, batch_size=batch_size)
+    flops = flops_forward(problem, batch_size=batch_size)

     # DATA
-    try:
-        memory_streamed = tp.calculate_memory_streamed_backward(batch_size=batch_size)
-    except NotImplementedError:
-        logger.warning(
-            "Actual memory streamed not calculated, so minimum values are being used"
-        )
-        memory_streamed = calculate_minimum_memory_streamed_forward(
-            problem, batch_size=batch_size
-        )
+    memory_streamed = memory_streamed_forward(
+        problem, batch_size=batch_size
+    )

     result |= calculate_performance_statistics(
         problem=problem,
@@ -181,29 +170,11 @@ def benchmark_backward(
     )
     time_millis = np.full(shape=num_iter, fill_value=-1)

-    try:
-        flops = tp.calculate_flops_backward(batch_size=batch_size)
-    except NotImplementedError:
-        try:
-            flops = calculate_minimum_flops_forward(tpp=problem, batch_size=batch_size)
-            logger.warning(
-                "Actual flops was not calculated, so minimum values are being used"
-            )
-        except NotImplementedError:
-            logger.warning(
-                "Minimum Backwards flops calculations are not implemented, -1 is a placeholder"
-            )
-            flops = {"total": -1}
+    flops = flops_backward(tpp=problem, batch_size=batch_size)

-    try:
-        memory_streamed = tp.calculate_memory_streamed_backward(batch_size=batch_size)
-    except NotImplementedError:
-        logger.warning(
-            "Actual memory streamed was not calculated, so minimum values are being"
-        )
-        memory_streamed = calculate_minimum_memory_streamed_backward(
-            tpp=problem, batch_size=batch_size
-        )
+    memory_streamed = memory_streamed_backward(
+        tpp=problem, batch_size=batch_size
+    )

     result |= calculate_performance_statistics(
         problem=problem,
@@ -258,29 +229,11 @@ def benchmark_double_backward(
     )
     time_millis = np.full(shape=num_iter, fill_value=-1)

-    try:
-        flops = tp.calculate_flops_backward(batch_size=batch_size)
-    except NotImplementedError:
-        try:
-            flops = calculate_minimum_flops_forward(tpp=problem, batch_size=batch_size)
-            logger.warning(
-                "Actual flops was not calculated, so minimum values are being used"
-            )
-        except NotImplementedError:
-            logger.warning(
-                "Minimum Backwards flops calculations are not implemented, -1 is a placeholder"
-            )
-            flops = {"total": -1}
+    flops = flops_backward(tpp=problem, batch_size=batch_size)

-    try:
-        memory_streamed = tp.calculate_memory_streamed_backward(batch_size=batch_size)
-    except NotImplementedError:
-        logger.warning(
-            "Actual memory streamed was not calculated, so minimum values are being"
-        )
-        memory_streamed = calculate_minimum_memory_streamed_backward(
-            tpp=problem, batch_size=batch_size
-        )
+    memory_streamed = memory_streamed_backward(
+        tpp=problem, batch_size=batch_size
+    )

     result |= calculate_performance_statistics(
         problem=problem,
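The net effect in benchmark_utils.py: FLOP and byte estimates are now pure functions of the TPProblem, so the per-implementation try/except probing, the fallback warnings, and the {"total": -1} placeholder all disappear. A condensed sketch of the new metric path for a roofline point, assuming the memory dictionaries expose a "total" key alongside the per-source entries (the FLOP dictionaries do, per metrics.py below):

def roofline_point(problem, batch_size):
    # Estimates depend only on the problem description, never on the kernel.
    flops = flops_backward(tpp=problem, batch_size=batch_size)
    mem = memory_streamed_backward(tpp=problem, batch_size=batch_size)
    # Arithmetic intensity (FLOPs per byte); derived here for illustration,
    # not computed in this diff.
    return flops["total"], flops["total"] / mem["total"]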

openequivariance/openequivariance/benchmark/correctness.py

Lines changed: 2 additions & 2 deletions
@@ -5,8 +5,8 @@
 import numpy.linalg as la

 from openequivariance._torch.CUETensorProduct import CUETensorProduct
-from openequivariance.benchmark.logging import bcolors, getLogger
-from openequivariance.benchmark.random_buffer_utils import (
+from openequivariance.core.logging import bcolors, getLogger
+from openequivariance.benchmark.test_buffers import (
     get_random_buffers_backward_conv,
     get_random_buffers_backward,
     get_random_buffers_double_backward_conv,
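Here the rename openequivariance.benchmark.random_buffer_utils -> openequivariance.benchmark.test_buffers surfaces again; only the module path changes, the helper names are untouched. A call-site sketch in which the argument and return shapes are assumptions for illustration (this diff shows only the import):

from openequivariance.benchmark.test_buffers import get_random_buffers_forward

# Hypothetical usage; the commit renames the module, not the functions.
buffers = get_random_buffers_forward(problem, batch_size=128)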

openequivariance/openequivariance/benchmark/perf_metrics_utils.py renamed to openequivariance/openequivariance/benchmark/metrics.py

Lines changed: 30 additions & 28 deletions
@@ -1,18 +1,15 @@
-import math
-
 from openequivariance.core.utils import (
     count_cg_non_zero,
-    sparse_outer_product_work,
 )

-from openequivariance.core.e3nn_lite import TPProblem, wigner_3j
-from openequivariance.benchmark.logging import getLogger
+from openequivariance.core.e3nn_lite import TPProblem
+from openequivariance.core.logging import getLogger
 import numpy as np

 logger = getLogger()


-def calculate_minimum_memory_streamed_forward(
+def memory_streamed_forward(
     tpp: TPProblem, batch_size: int
 ) -> dict[str, int]:
     """
@@ -31,7 +28,7 @@ def calculate_minimum_memory_streamed_forward(
     return data_size


-def calculate_minimum_memory_streamed_backward(tpp: TPProblem, batch_size: int) -> dict:
+def memory_streamed_backward(tpp: TPProblem, batch_size: int) -> dict:
     """
     This represents an absolute minimum amount of memory that could be streamed on an ideal machine
     It returns the number of bytes streamed total and from each source
@@ -51,46 +48,51 @@ def calculate_minimum_memory_streamed_backward(tpp: TPProblem, batch_size: int)
     return data_size


-def calculate_minimum_flops_forward(tpp: TPProblem, batch_size: int) -> dict:
+def flops_forward(tpp: TPProblem, batch_size: int) -> dict:
     """
-    This is not actually calcuating the minimum value.
-    Ideally you might share the outer product values between two inputs across multiple inputs.
-    This is assuming that you form those values and reuse them once per CG decomp.
+    Default FLOP estimate aligned with LoopUnrollTP's forward FLOP accounting.
     """
-    logger.warning("Minimum flops Calculation is not the true minimum")
-    flops_count = {}
-    flops_count["outer_products"] = 0
-    flops_count["CG_decomposition"] = 0
-    flops_count["linear_combination"] = 0
+    flops_count = {"CG_decomposition": 0, "linear_combination": 0, "outer_products": 0}
+
     for ins in tpp.instructions:  # type : Instruction
         l1, l2, l3 = (
             tpp.irreps_in1[ins.i_in1].ir.l,
             tpp.irreps_in2[ins.i_in2].ir.l,
             tpp.irreps_out[ins.i_out].ir.l,
         )

-        flops_count["outer_products"] += sparse_outer_product_work(
-            wigner_3j(l1, l2, l3)
-        )
         flops_count["CG_decomposition"] += count_cg_non_zero(l1, l2, l3) * (
             ins.path_shape[0] * ins.path_shape[1]
         )
         flops_count["linear_combination"] += (
-            (2 * l3 + 1) * math.prod(ins.path_shape) if ins.has_weight else 0
+            (2 * l3 + 1) * np.prod(ins.path_shape) if ins.has_weight else 0
         )

-    flops_count["outer_products"] *= batch_size
-    flops_count["CG_decomposition"] *= 2 * batch_size
-    flops_count["linear_combination"] *= 2 * batch_size
+    flops_count["CG_decomposition"] *= 3 * batch_size
+    flops_count["linear_combination"] *= (
+        batch_size  # Weights do not require FMA here
+    )

     flops_count["total"] = sum(flops_count.values())
     return flops_count


-def calculate_minimum_flops_backward(tpp: TPProblem, batch_size: int) -> dict:
+def flops_backward(tpp: TPProblem, batch_size: int) -> dict:
     """
-    This is not actually calcuating the minumum value.
-    Ideally you might share the outer product values between two inputs across multiple inputs.
-    This is assuming that you form those values and reuse them once per CG decomp.
+    Default FLOP estimate aligned with LoopUnrollTP's backward FLOP accounting.
     """
-    raise NotImplementedError("this needs to be implemented properly")
+    flops_count = {"backward": 0}
+
+    for ins in tpp.instructions:  # type : Instruction
+        l1, l2, l3 = (
+            tpp.irreps_in1[ins.i_in1].ir.l,
+            tpp.irreps_in2[ins.i_in2].ir.l,
+            tpp.irreps_out[ins.i_out].ir.l,
+        )
+        flops_count["backward"] += count_cg_non_zero(l1, l2, l3) * (
+            ins.path_shape[0] * ins.path_shape[1]
+        )
+
+    flops_count["backward"] *= 9 * batch_size
+    flops_count["total"] = sum(flops_count.values())
+    return flops_count
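A quick usage sketch of the renamed estimators, assuming a TPProblem instance problem built elsewhere (construction is not part of this diff):

from openequivariance.benchmark.metrics import flops_forward, flops_backward

fwd = flops_forward(problem, batch_size=1000)
bwd = flops_backward(tpp=problem, batch_size=1000)

# From the code above: the backward CG term is scaled by 9 * batch_size versus
# 3 * batch_size forward, so bwd["backward"] == 3 * fwd["CG_decomposition"].
print(fwd["total"], bwd["total"])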
