Skip to content

Commit e40d21f

Browse files
Kahan Summation Introduced (#98)
* First pass at Kahan summation. * Added SMEM capacity exception. * Kahan summation working for the forward pass. * Kahan summation for the backward pass is working. * Forward A kernel is working with Kahan summation. * Kahan summation passes a simple correctness check. * Added a small suite of UVU tests. * First step writing code to test Kahan summation accuracy. * Kahan summation reduces error by about 5x on Kahan summation. * We have a dependency cycle that we need to fix. * Removed some reference cycles. * Accuracy benchmark completed. * Kahan summation complete!
1 parent 73aa9b6 commit e40d21f

11 files changed

Lines changed: 330 additions & 173 deletions

File tree

README.md

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,17 @@ print(torch.norm(Z))
125125
```
126126
**Note**: you don't need Pytorch geometric to use our kernels. When
127127
`deterministic=False`, the `sender` and `receiver` indices can have
128-
arbitrary order.
128+
arbitrary order.
129+
130+
**New:** If you're working in FP32 precision and want
131+
higher accuracy during graph convolution, we offer a Kahan
132+
summation variant of our deterministic algorithm:
133+
134+
```python
135+
tp_conv_kahan = oeq.TensorProductConv(problem, torch_op=True, deterministic=True, kahan=True)
136+
Z = tp_conv_kahan.forward(X, Y[receiver_perm], W[receiver_perm], edge_index[0], edge_index[1], sender_perm)
137+
print(torch.norm(Z))
138+
```
129139

130140
## Installation
131141
We currently support Linux systems only.
@@ -172,6 +182,7 @@ python tests/benchmark.py -o outputs/uvu uvu --plot
172182
python tests/benchmark.py -o outputs/uvw uvw --plot
173183
python tests/benchmark.py -o outputs/roofline roofline --plot
174184
python tests/benchmark.py -o outputs/conv conv --plot --data data/molecular_structures
185+
python tests/benchmark.py -o outputs/kahan_conv kahan_conv --data data/molecular_structures/
175186
```
176187

177188
If your GPU has limited memory, you might want to try

openequivariance/benchmark/ConvBenchmarkSuite.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,22 +24,22 @@ def __init__(self, configs,
2424
num_warmup = 10,
2525
num_iter = 30,
2626
reference_impl=None,
27-
torch_op=True,
2827
test_name=None,
29-
prng_seed = 12345):
28+
prng_seed = 12345,
29+
correctness_threshold = 1e-5):
3030
self.configs = configs
3131
self.num_warmup = num_warmup
3232
self.num_iter = num_iter
3333
self.reference_impl = reference_impl
3434
self.prng_seed = 12345
35-
self.correctness_threshold = 1e-5
36-
self.torch_op = torch_op
35+
self.correctness_threshold = correctness_threshold
3736
self.exp_count = 0
3837
self.test_name = test_name
3938

4039
self.millis_since_epoch = round(time.time() * 1000)
4140

42-
def run(self, graph, implementations, direction, output_folder=None, correctness=True, double_backward_correctness=False, benchmark=True):
41+
def run(self, graph, implementations, direction, output_folder=None,
42+
correctness=True, benchmark=True, high_precision_ref=False):
4343
if output_folder is None:
4444
if oeq._check_package_editable():
4545
output_folder = oeq._editable_install_output_path / f"{self.millis_since_epoch}"
@@ -65,20 +65,15 @@ def run(self, graph, implementations, direction, output_folder=None, correctness
6565
for impl in implementations:
6666
tc_name = f"{config}, {impl.name()}"
6767
logger.info(f'Starting {tc_name}, graph {graph.name}, {direction}')
68-
conv = impl(config, torch_op=self.torch_op)
69-
70-
if double_backward_correctness:
71-
double_backward_correctness = conv.test_correctness_double_backward(self.graph,
72-
thresh=self.correctness_threshold,
73-
prng_seed=self.prng_seed,
74-
reference_implementation=self.reference_impl)
68+
conv = impl(config)
7569

7670
if direction == "forward":
7771
if correctness:
7872
correctness = conv.test_correctness_forward(graph,
7973
thresh=self.correctness_threshold,
8074
prng_seed=self.prng_seed,
81-
reference_implementation=self.reference_impl)
75+
reference_implementation=self.reference_impl,
76+
high_precision_ref=high_precision_ref)
8277

8378
if benchmark:
8479
benchmark = conv.benchmark_forward(self.num_warmup,
@@ -90,23 +85,33 @@ def run(self, graph, implementations, direction, output_folder=None, correctness
9085
correctness = conv.test_correctness_backward(graph,
9186
thresh=self.correctness_threshold,
9287
prng_seed=self.prng_seed,
93-
reference_implementation=self.reference_impl)
88+
reference_implementation=self.reference_impl,
89+
high_precision_ref=high_precision_ref)
9490

9591
if benchmark:
9692
benchmark = conv.benchmark_backward(self.num_warmup,
9793
self.num_iter, graph, prng_seed=12345)
94+
95+
if direction == "double_backward":
96+
if correctness:
97+
correctness = conv.test_correctness_double_backward(self.graph,
98+
thresh=self.correctness_threshold,
99+
prng_seed=self.prng_seed,
100+
reference_implementation=self.reference_impl,
101+
high_precision_ref=high_precision_ref)
102+
103+
assert not benchmark
98104

99105
result = {
100106
"config": str(config),
101107
"irrep_dtype": str(config.irrep_dtype),
102108
"weight_dtype": str(config.weight_dtype),
103-
"torch_overhead_included": self.torch_op,
109+
"torch_overhead_included": conv.torch_op,
104110
"direction": direction,
105111
"graph": graph.name,
106112
"name": impl.name(),
107113
"correctness": correctness,
108-
"benchmark": benchmark,
109-
"double_backward_correctness": double_backward_correctness
114+
"benchmark": benchmark
110115
}
111116

112117
fname = pathlib.Path(f"{output_folder}/{self.exp_count}_{impl.name()}_{graph.name}.json")

openequivariance/implementations/ComputationSchedule.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@
55
from openequivariance.implementations.TensorProductBase import *
66
logger = getLogger()
77

8+
class SMEMCapacityException(Exception):
9+
def __init__(self, message):
10+
self.message = message
11+
super().__init__(self.message)
12+
813
class IrrepMapping:
914
'''
1015
Maps irreps from a source to a destination set.
@@ -104,7 +109,7 @@ def create_schedule_case2(instructions, memory_per_warp, calculate_smem, directi
104109
segments.append((cL1, cL2, cL3, cinst))
105110
cL3, cinst = set(), []
106111
else:
107-
raise Exception(f"{direction.title()} scheduling failed, memory allocation too small to accomodate segment!")
112+
raise SMEMCapacityException(f"{direction.title()} scheduling failed, memory allocation too small to accomodate segment!")
108113
else:
109114
cL3.add(w)
110115
cinst.append(inst_idx)
@@ -130,7 +135,7 @@ def create_schedule_case3(instructions, memory_per_warp, calculate_smem, directi
130135
segments.append((cL1, cL2, cL3, cinst))
131136
cL1, cL2, cL3, cinst = set(), set(), set(), []
132137
else:
133-
raise Exception(f"{direction.title()} scheduling failed, memory allocation too small to accomodate segment!")
138+
raise SMEMCapacityException(f"{direction.title()} scheduling failed, memory allocation too small to accomodate segment!")
134139
else:
135140
cL1.add(u)
136141
cL2.add(v)
@@ -245,10 +250,15 @@ def __init__(self,
245250
weight_dtype,
246251
include_scratch=False,
247252
stream_weights=False,
248-
schedule_type=2):
253+
schedule_type=2,
254+
kahan=False):
249255
'''
250256
smem_limit: size of available shared memory in bytes
251257
'''
258+
self.kahan = kahan
259+
if kahan:
260+
assert irrep_dtype == weight_dtype == np.float32
261+
252262
# Note: does not work with variances for irreps; easy to add that in
253263
self.total_warps = warps_per_block * block_count
254264

@@ -288,10 +298,16 @@ def calculate_forward_smem(L1_set, L2_set, L3_set, inst_idxs):
288298
"L1": {"size": sum([self.L1[el].dim for el in L1_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr},
289299
"L2": {"size": sum([self.L2[el].dim for el in L2_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr},
290300
"L3": {"size": sum([self.L3[el].dim for el in L3_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr},
301+
"L3_kahan": {"size": 0, "dtype": self.irrep_dtype_cstr},
291302
"weights": {"size": 0, "dtype": self.weight_dtype_cstr},
292303
"scratch": {"size": 0, "dtype": self.weight_dtype_cstr}
293304
}
294305

306+
if kahan:
307+
smem["L3_kahan"]["size"] = smem["L3"]["size"]
308+
else:
309+
smem.pop("L3_kahan")
310+
295311
weights_smem = 0
296312
for inst_idx in inst_idxs:
297313
inst = self.new_instructions[inst_idx]
@@ -325,6 +341,7 @@ def calculate_backward_smem(L1_set, L2_set, L3_set, inst_idxs,
325341
smem = {
326342
"L1": {"size": sum([self.L1[el].dim for el in L1_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr},
327343
"L1_grad": {"size": sum([self.L1[el].dim for el in L1_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr},
344+
"L1_kahan": {"size": 0, "dtype": self.irrep_dtype_cstr},
328345
"L2": {"size": sum([self.L2[el].dim for el in L2_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr},
329346
"L2_grad": {"size": sum([self.L2[el].dim for el in L2_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr},
330347
"L3_grad": {"size": sum([self.L3[el].dim for el in L3_set]) * irrep_itemsize, "dtype": self.irrep_dtype_cstr},
@@ -333,6 +350,11 @@ def calculate_backward_smem(L1_set, L2_set, L3_set, inst_idxs,
333350
"scratch": {"size": 0, "dtype": self.weight_dtype_cstr}
334351
}
335352

353+
if kahan:
354+
smem["L1_kahan"]["size"] = smem["L1"]["size"]
355+
else:
356+
smem.pop("L1_kahan")
357+
336358
if L2_dgrad:
337359
smem["L2_dgrad"] = {"size": smem["L2"]["size"], "dtype": self.irrep_dtype_cstr}
338360

@@ -376,11 +398,11 @@ def calculate_backward_smem(L1_set, L2_set, L3_set, inst_idxs,
376398
schedule2_succeeded = False
377399
try:
378400
if schedule_type != 2:
379-
raise Exception("Asked for schedule case 3.")
401+
raise SMEMCapacityException("Asked for schedule case 3.")
380402
self.segments = create_schedule_case2(self.new_instructions, self.memory_per_warp, calculate_smem, direction)
381403
logger.info(f"{direction.title()} case 2 scheduling succeeded with {len(self.segments)} segments.")
382404
schedule2_succeeded = True
383-
except Exception as e:
405+
except SMEMCapacityException as e:
384406
self.segments = create_schedule_case3(self.new_instructions, self.memory_per_warp, calculate_smem, direction)
385407
logger.info(f"{direction.title()} case 3 scheduling succeeded with {len(self.segments)} segments.")
386408

0 commit comments

Comments
 (0)