Skip to content

Commit b0d4445

Browse files
Enables convolution with shared weights (#101)

* Shared weight convolution is working for UVW.
* All tests passing.
* Shared weight tests added for batch as well.
* Modified fallback torch registration to zero out weight gradients.
1 parent 9602816 commit b0d4445

9 files changed

Lines changed: 125 additions & 31 deletions

File tree

openequivariance/benchmark/correctness_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def check_similiarity(name : str, to_check : np.ndarray, ground_truth : np.nda
2121
result["shape_match"] = True
2222
diff_Linf_norm = float(la.norm((ground_truth - to_check).flatten(), ord=np.inf))
2323
result["diff_Linf_norm"] = diff_Linf_norm
24-
result["pass"] = bool(diff_Linf_norm < correctness_threshold)
24+
result["pass"] = bool(diff_Linf_norm < correctness_threshold)
2525

2626
if result["pass"]:
2727
logger.info(f" {bcolors.OKGREEN}{name} correctness check pass. {diff_Linf_norm=:.3e}, {correctness_threshold=} {bcolors.ENDC}")

openequivariance/extension/torch_tp_jit.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ class TorchJITConv : public torch::CustomClassHolder {
218218
Map_t fwd_dict, bwd_dict, dbl_bwd_dict, kernel_dims;
219219
JITConvImpl<JITKernel> internal;
220220
int64_t L3_dim;
221+
int shared_weights;
221222

222223
TorchJITConv(string kernel_plaintext, Map_t fwd_dict_i, Map_t bwd_dict_i, Map_t dbl_bwd_dict_i, Map_t kernel_dims_i) :
223224
fwd_dict(fwd_dict_i.copy()),
@@ -230,7 +231,8 @@ class TorchJITConv : public torch::CustomClassHolder {
230231
to_map(dbl_bwd_dict_i),
231232
to_map(kernel_dims_i)
232233
),
233-
L3_dim(kernel_dims.at("L3_dim")) { }
234+
L3_dim(kernel_dims.at("L3_dim")),
235+
shared_weights(kernel_dims.at("shared_weights")) { }
234236

235237
tuple<tuple<string, string>,
236238
tuple<string, Map_t>,
@@ -341,6 +343,11 @@ tuple<torch::Tensor, torch::Tensor, torch::Tensor> jit_conv_backward(
341343
torch::Tensor cols_contig = cols.contiguous();
342344
torch::Tensor workspace_contig = workspace.contiguous();
343345
torch::Tensor transpose_perm_contig = transpose_perm.contiguous();
346+
347+
if(jit_instance->shared_weights == 1) {
348+
W_grad.zero_();
349+
}
350+
344351
jit_instance->internal.backward(
345352
data_ptr(L1_in_contig), data_ptr(L1_grad),
346353
data_ptr(L2_in_contig), data_ptr(L2_grad),
@@ -388,6 +395,10 @@ tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> jit_conv_doubl
388395
torch::Tensor workspace_contig = workspace.contiguous();
389396
torch::Tensor transpose_perm_contig = transpose_perm.contiguous();
390397

398+
if(jit_instance->shared_weights == 1) {
399+
W_grad.zero_();
400+
}
401+
391402
jit_instance->internal.double_backward(
392403
data_ptr(L1_in_contig), data_ptr(L2_in_contig),
393404
data_ptr(W_contig), data_ptr(L3_grad_contig),

openequivariance/implementations/convolution/ConvolutionBase.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,9 @@ def backward_helper( L1_in : torch.Tensor, L2_in : torch.Tensor,
697697
L2_grad = torch.empty_like(L2_in)
698698
weights_grad = torch.empty_like(weights)
699699

700+
if self.config.shared_weights:
701+
weights_grad[:] = 0.0
702+
700703
self.internal.backward_rawptrs(
701704
L1_in.contiguous().data_ptr(), L1_grad.data_ptr(),
702705
L2_in.contiguous().data_ptr(), L2_grad.data_ptr(),

openequivariance/implementations/convolution/LoopUnrollConv.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ def __init__(self, config, idx_dtype=np.int64,
2121

2222
analysis = filter_and_analyze_problem(config)
2323
self.is_uvw = analysis["is_uvw"]
24-
assert not config.shared_weights, "LoopUnrollConv does not yet support shared weights"
24+
25+
if config.shared_weights:
26+
assert not deterministic, "Deterministic convolution does not support shared weights"
2527

2628
forward_schedule_type = 3
2729
backward_schedule_type = 2
@@ -148,7 +150,8 @@ def generate_double_backward_schedule(warps_per_block):
148150
vars(self.backward_schedule.launch_config),
149151
vars(self.double_backward_schedule.launch_config),
150152
{"L3_dim": self.L3.dim,
151-
"is_uvw": int(self.is_uvw)})
153+
"is_uvw": int(self.is_uvw),
154+
"shared_weights": int(config.shared_weights)})
152155
logger.info("Kernel compiled!")
153156

154157
#with open("scratch.txt", "w") as f:

openequivariance/templates/loop_unroll_batch.cuh

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,11 @@ __global__ void backward(
8181
IRREP_T* l3_shft = L3_grad + i * {{backward_schedule.L3.dim}} + lane_id;
8282

8383
{%- if not tpp.shared_weights %}
84-
WEIGHT_T* w = weights + i * {{tpp.weight_numel}};
85-
WEIGHT_T* wgrad = weights_grad + i * {{tpp.weight_numel}};
84+
WEIGHT_T* w = weights + i * {{tpp.weight_numel}};
85+
WEIGHT_T* wgrad = weights_grad + i * {{tpp.weight_numel}};
8686
{%- else %}
87-
WEIGHT_T* w = weights;
88-
WEIGHT_T* wgrad = weights_grad;
87+
WEIGHT_T* w = weights;
88+
WEIGHT_T* wgrad = weights_grad;
8989
{%- endif %}
9090
WEIGHT_T* weights_shft = w + lane_id;
9191

@@ -128,7 +128,11 @@ __global__ void backward(
128128
{{ store_ir_segments(segment.L2Map, "l2_grad_shft", "L2_grad_smem", "j") }}
129129

130130
{%- if not backward_schedule.stream_weights%}
131-
ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_shft[{{segment.weight_offset}} + j] = weights_grad_smem[j + lane_id];)
131+
{%- if not tpp.shared_weights %}
132+
ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_shft[{{segment.weight_offset}} + j] = weights_grad_smem[j + lane_id];)
133+
{%- else %}
134+
ROW_OPERATION({{segment.problem.weight_numel}}, j, atomicAdd(weights_grad_shft + {{segment.weight_offset}} + j, weights_grad_smem[j + lane_id]);)
135+
{%- endif %}
132136
{%- endif %}
133137
} {%- endfor %}
134138
}
@@ -295,7 +299,11 @@ __global__ void double_backward_B(
295299
{{ store_ir_segments(segment.L2Map, "l2_grad_shft", "L2_grad_smem", "j") }}
296300

297301
{% if not schedule.stream_weights%}
298-
ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_shft[{{segment.weight_offset}} + j] = weights_grad_smem[j + lane_id];)
302+
{%- if not tpp.shared_weights %}
303+
ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_shft[{{segment.weight_offset}} + j] = weights_grad_smem[j + lane_id];)
304+
{%- else %}
305+
ROW_OPERATION({{segment.problem.weight_numel}}, j, atomicAdd(weights_grad_shft + {{segment.weight_offset}} + j, weights_grad_smem[j + lane_id]);)
306+
{%- endif %}
299307
{% endif %}
300308
}
301309
} {%- endfor %}

openequivariance/templates/loop_unroll_conv_atomic.cuh

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,11 @@ __global__ void forward(
6868
IRREP_T* l1 = L1_in + col * {{forward_schedule.L1.dim}} + lane_id;
6969
IRREP_T* l2 = L2_in + i * {{forward_schedule.L2.dim}} + lane_id;
7070
IRREP_T* l3 = L3_out + row * {{forward_schedule.L3.dim}} + lane_id;
71-
WEIGHT_T* w = weights + i * {{tpp.weight_numel}};
71+
{%- if not tpp.shared_weights %}
72+
WEIGHT_T* w = weights + i * {{tpp.weight_numel}};
73+
{%- else %}
74+
WEIGHT_T* w = weights;
75+
{%- endif %}
7276

7377
__syncwarp();
7478
{{ load_ir_segments(segment.L1Map, "l1", "L1_smem", "j") }}
@@ -115,11 +119,11 @@ __global__ void backward(
115119
IRREP_T* l3_shft = L3_grad + row * {{backward_schedule.L3.dim}} + lane_id;
116120

117121
{%- if not tpp.shared_weights %}
118-
WEIGHT_T* w = weights + i * {{tpp.weight_numel}};
119-
WEIGHT_T* wgrad = weights_grad + i * {{tpp.weight_numel}};
122+
WEIGHT_T* w = weights + i * {{tpp.weight_numel}};
123+
WEIGHT_T* wgrad = weights_grad + i * {{tpp.weight_numel}};
120124
{%- else %}
121-
WEIGHT_T* w = weights;
122-
WEIGHT_T* wgrad = weights_grad;
125+
WEIGHT_T* w = weights;
126+
WEIGHT_T* wgrad = weights_grad;
123127
{%- endif %}
124128
WEIGHT_T* weights_shft = w + lane_id;
125129

@@ -155,8 +159,12 @@ __global__ void backward(
155159
{{ store_ir_segments(segment.L1Map, "l1_grad_shft", "L1_grad_smem", "j") }}
156160
{{ store_ir_segments(segment.L2Map, "l2_grad_shft", "L2_grad_smem", "j") }}
157161

158-
{%- if not backward_schedule.stream_weights%}
159-
ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_shft[{{segment.weight_offset}} + j] = weights_grad_smem[j + lane_id];)
162+
{%- if not backward_schedule.stream_weights %}
163+
{%- if not tpp.shared_weights %}
164+
ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_shft[{{segment.weight_offset}} + j] = weights_grad_smem[j + lane_id];)
165+
{%- else %}
166+
ROW_OPERATION({{segment.problem.weight_numel}}, j, atomicAdd(weights_grad_shft + {{segment.weight_offset}} + j, weights_grad_smem[j + lane_id]);)
167+
{%- endif %}
160168
{%- endif %}
161169
} {%- endfor %}
162170
}
@@ -332,8 +340,12 @@ __global__ void double_backward_B(
332340
{{ store_ir_segments(segment.L1Map, "l1_grad_shft", "L1_grad_smem", "j") }}
333341
{{ store_ir_segments(segment.L2Map, "l2_grad_shft", "L2_grad_smem", "j") }}
334342

335-
{% if not schedule.stream_weights%}
336-
ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_shft[{{segment.weight_offset}} + j] = weights_grad_smem[j + lane_id];)
343+
{% if not schedule.stream_weights %}
344+
{%- if not tpp.shared_weights %}
345+
ROW_OPERATION({{segment.problem.weight_numel}}, j, weights_grad_shft[{{segment.weight_offset}} + j] = weights_grad_smem[j + lane_id];)
346+
{%- else %}
347+
ROW_OPERATION({{segment.problem.weight_numel}}, j, atomicAdd(weights_grad_shft + {{segment.weight_offset}} + j, weights_grad_smem[j + lane_id]);)
348+
{%- endif %}
337349
{% endif %}
338350
}
339351
} {%- endfor %}

openequivariance/templates/loop_unroll_tp.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,9 @@ __device__ __forceinline__ void forward_loop_unroll_{{id}}(IRREP_T* __restrict__
204204
scratch1[{{i % num_scratch_reg}}] = l3_grad[{{coord3}}] * {{value}};
205205

206206
{%- if double_bwd %}
207-
weight_grad += scratch1[{{i % num_scratch_reg}}] * l2_original[{{coord2}}] * l1_vec[{{coord1}}];
207+
weight_grad += scratch1[{{i % num_scratch_reg}}] * l2_original[{{coord2}}] * l1_vec[{{coord1}}];
208208
{%- else %}
209-
weight_grad += scratch1[{{i % num_scratch_reg}}] * l2_vec[{{coord2}}] * l1_vec[{{coord1}}];
209+
weight_grad += scratch1[{{i % num_scratch_reg}}] * l2_vec[{{coord2}}] * l1_vec[{{coord1}}];
210210
{%- endif %}
211211

212212
scratch2[{{i % num_scratch_reg}}] = scratch1[{{i % num_scratch_reg}}] * weight;

tests/batch_test.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,13 @@
88
from itertools import chain, product
99

1010
class TPCorrectness:
11+
def thresh(self, direction):
12+
return {
13+
"fwd": 1e-5,
14+
"bwd": 3e-4,
15+
"double_bwd": 3e-4
16+
}[direction]
17+
1118
def check_result(self, result, fieldname):
1219
with check:
1320
error = result[fieldname]["diff_Linf_norm"]
@@ -30,7 +37,7 @@ def test_tp_fwd(self, tp_and_problem):
3037
test_implementation=tp,
3138
reference_implementation=None,
3239
batch_size=1000,
33-
correctness_threshold=1e-5,
40+
correctness_threshold=self.thresh("fwd"),
3441
prng_seed=12345)
3542

3643
self.check_result(result, "output")
@@ -42,7 +49,7 @@ def test_tp_bwd(self, tp_and_problem):
4249
test_implementation=tp,
4350
reference_implementation=None,
4451
batch_size=1000,
45-
correctness_threshold=3e-4,
52+
correctness_threshold=self.thresh("bwd"),
4653
prng_seed=12345)
4754

4855
self.check_result(result, "weight_grad")
@@ -56,7 +63,7 @@ def test_tp_double_bwd(self, tp_and_problem):
5663
test_implementation=tp,
5764
reference_implementation = None,
5865
batch_size = 200,
59-
correctness_threshold = 3e-4,
66+
correctness_threshold=self.thresh("double_bwd"),
6067
prng_seed = 12345)
6168

6269
self.check_result(result, "output_double_grad")
@@ -129,4 +136,23 @@ def problem(self, request, dtype):
129136
return oeq.TPProblem(f"{m[0]}x{i[0]}e", f"{m[1]}x{i[1]}e", f"{m[2]}x{i[2]}e",
130137
instructions, shared_weights=False,
131138
internal_weights=False,
132-
irrep_dtype=dtype, weight_dtype=dtype)
139+
irrep_dtype=dtype, weight_dtype=dtype)
140+
141+
142+
class TestSharedWeights(TPCorrectness):
143+
from openequivariance.benchmark.benchmark_configs import mace_problems, diffdock_configs
144+
problems = [mace_problems[0], diffdock_configs[0]]
145+
146+
def thresh(self, direction):
147+
return {
148+
"fwd": 1e-5,
149+
"bwd": 5e-4, # Expect higher errors for shared weights
150+
"double_bwd": 5e-4
151+
}[direction]
152+
153+
@pytest.fixture(params=problems, ids = lambda x : x.label, scope="class")
154+
def problem(self, request, dtype):
155+
problem = request.param
156+
problem.irrep_dtype, problem.weight_dtype = dtype, dtype
157+
problem.shared_weights = True
158+
return problem

tests/conv_test.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,20 @@
77
from itertools import chain, product
88

99
class ConvCorrectness:
10+
def thresh(self, direction):
11+
return {
12+
"fwd": 1e-5,
13+
"bwd": 3e-4,
14+
"double_bwd": 3e-4
15+
}[direction]
16+
17+
1018
def check_result(self, result, fieldname):
1119
with check:
1220
error = result[fieldname]["diff_Linf_norm"]
1321
thresh = result["thresh"]
14-
assert result[fieldname]["pass"], f"{fieldname} observed error={error:.2f} >= {thresh}"
15-
22+
assert result[fieldname]["pass"], f"{fieldname} observed error={error:.5f} >= {thresh}"
23+
1624
@pytest.fixture(params=[np.float32, np.float64], ids=['F32', 'F64'], scope='class')
1725
def dtype(self, request):
1826
return request.param
@@ -48,7 +56,7 @@ def test_tp_fwd(self, conv_object, graph):
4856
return
4957

5058
result = conv_object.test_correctness_forward(graph,
51-
thresh=3e-05,
59+
thresh=self.thresh("fwd"),
5260
prng_seed=12345,
5361
reference_implementation=None)
5462

@@ -60,7 +68,7 @@ def test_tp_bwd(self, conv_object, graph):
6068
return
6169

6270
result = conv_object.test_correctness_backward(graph,
63-
thresh=3e-04,
71+
thresh=self.thresh("bwd"),
6472
prng_seed=12345,
6573
reference_implementation=None)
6674

@@ -74,7 +82,7 @@ def test_tp_double_bwd(self, conv_object, graph):
7482
return
7583

7684
result = conv_object.test_correctness_double_backward(graph,
77-
thresh=3e-04,
85+
thresh=self.thresh("double_bwd"),
7886
prng_seed=12345,
7987
reference_implementation=None)
8088

@@ -140,4 +148,27 @@ def problem(self, request, dtype):
140148
return oeq.TPProblem(f"{m[0]}x{i[0]}e", f"{m[1]}x{i[1]}e", f"{m[2]}x{i[2]}e",
141149
instructions, shared_weights=False,
142150
internal_weights=False,
143-
irrep_dtype=dtype, weight_dtype=dtype)
151+
irrep_dtype=dtype, weight_dtype=dtype)
152+
153+
154+
class TestAtomicSharedWeights(ConvCorrectness):
155+
from openequivariance.benchmark.benchmark_configs import mace_problems, diffdock_configs
156+
problems = [mace_problems[0], diffdock_configs[0]]
157+
158+
def thresh(self, direction):
159+
return {
160+
"fwd": 1e-5,
161+
"bwd": 5e-2, # Expect higher errors for shared weights
162+
"double_bwd": 5e-2
163+
}[direction]
164+
165+
@pytest.fixture(params=problems, ids = lambda x : x.label, scope="class")
166+
def problem(self, request, dtype):
167+
problem = request.param
168+
problem.irrep_dtype, problem.weight_dtype = dtype, dtype
169+
problem.shared_weights = True
170+
return problem
171+
172+
@pytest.fixture(scope='class')
173+
def conv_object(self, request, problem):
174+
return oeq.TensorProductConv(problem, deterministic=False)

0 commit comments

Comments (0)