
Commit 1f9c53e

Accelerated double backward pass for convolution. (#95)
* Making progress on double backward pass.
* More progress on double backward setup.
* Custom double backward pass is compiling.
* Wrote the A kernel.
* Atomic double backward pass accelerated.
* More convolution double backward tests passing.
* Deterministic double backward A kernel written.
* Double backward deterministic is compiling, but we need to check the output.
* Double backward A kernel is working.
* Gradient of L1 has an error, need to fix.
* All tests passing.
* Updated conv test fixture to download graph again.
1 parent 6be2be5 commit 1f9c53e

9 files changed

Lines changed: 577 additions & 35 deletions


openequivariance/extension/convolution.hpp

Lines changed: 51 additions & 2 deletions
@@ -97,41 +97,50 @@ class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionI
    JIT_IMPL jit;
    KernelLaunchConfig forward_config;
    KernelLaunchConfig backward_config;
+   KernelLaunchConfig double_backward_config;
    bool is_uvw;

    JITConvImpl(
        std::string jit_kernel,
        KernelLaunchConfig forward_config_i,
        KernelLaunchConfig backward_config_i,
+       KernelLaunchConfig double_backward_config_i,
        bool is_uvw_i) :
            jit(jit_kernel),
            forward_config(forward_config_i),
            backward_config(backward_config_i),
+           double_backward_config(double_backward_config_i),
            is_uvw(is_uvw_i) {

-       vector<string> kernels = {"forward", "backward", "fixup_forward", "fixup_backward"};
+       vector<string> kernels = {"forward", "backward", "fixup_forward", "fixup_backward", "double_backward_A", "double_backward_B", "fixup_double_backwardB"};

        int opt_level = 3;
#ifdef HIP_BACKEND
        if(is_uvw) {
            opt_level = 1;
        }
#endif
-       jit.compile(kernels, {{}, {}, {}, {}}, opt_level);
+       jit.compile(kernels, {{}, {}, {}, {}, {}, {}, {}}, opt_level);

        if(forward_config.smem > 0) {
            jit.set_max_smem(0, forward_config.smem);
+           jit.set_max_smem(4, forward_config.smem);
        }

        if(backward_config.smem > 0) {
            jit.set_max_smem(1, backward_config.smem);
        }
+
+       if(double_backward_config.smem > 0) {
+           jit.set_max_smem(5, double_backward_config.smem);
+       }
    }

    JITConvImpl(
        std::string jit_kernel,
        std::unordered_map<string, int64_t> fwd_dict,
        std::unordered_map<string, int64_t> bwd_dict,
+       std::unordered_map<string, int64_t> dbl_bwd_dict,
        std::unordered_map<string, int64_t> kernel_dims
    ) : JITConvImpl(
        jit_kernel,
@@ -145,6 +154,11 @@ class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionI
            bwd_dict["num_threads"],
            bwd_dict["smem"]
        ),
+       KernelLaunchConfig(
+           dbl_bwd_dict["num_blocks"],
+           dbl_bwd_dict["num_threads"],
+           dbl_bwd_dict["smem"]
+       ),
        kernel_dims["is_uvw"] == 1) { }

    void exec_conv(
@@ -201,5 +215,40 @@ class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionI
        }
    }

+   void double_backward(
+           void* L1_in, void* L2_in, void* W, void* L3_grad,
+           void* L1_dgrad, void* L2_dgrad, void* w_dgrad,
+           void* L1_grad, void* L2_grad, void* W_grad, void* L3_dgrad,
+           void* rows, void* cols,
+           uint64_t nnz, uint64_t node_count,
+           void* wspace, void* transpose_perm) {
+
+       ConvData conv_data = {rows, cols, nnz, node_count};
+       void* args[] = {
+           &L1_in, &L2_in, &W, &L3_grad, &L1_dgrad, &L2_dgrad, &w_dgrad,
+           &L1_grad, &L2_grad, &W_grad, &L3_dgrad, &conv_data, &wspace, &transpose_perm
+       };
+
+       jit.execute(4, args, forward_config);
+       if(reinterpret_cast<uint64_t>(wspace) != 0) {
+           void *fixup_args[] = {&wspace, &L3_dgrad};
+           KernelLaunchConfig fixup_config;
+           fixup_config.num_blocks = forward_config.num_blocks;
+           fixup_config.num_threads = forward_config.num_threads;
+           fixup_config.smem = 0;
+           jit.execute(2, fixup_args, fixup_config);
+       }
+
+       jit.execute(5, args, double_backward_config);
+       if(reinterpret_cast<uint64_t>(wspace) != 0) {
+           void *fixup_args[] = {&wspace, &L1_grad};
+           KernelLaunchConfig fixup_config;
+           fixup_config.num_blocks = double_backward_config.num_blocks;
+           fixup_config.num_threads = double_backward_config.num_threads;
+           fixup_config.smem = 0;
+           jit.execute(6, fixup_args, fixup_config);
+       }
+   }
+
    ~JITConvImpl() = default;
};
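Note: the integer passed to jit.execute selects a kernel by its position in the kernels vector above. As a minimal sketch of the launch order in double_backward (Python pseudocode; the index map is read off the vector, while the execute callable and the simplified fixup arguments are stand-ins, not the real API), the deterministic-mode fixups fire only when a workspace pointer is supplied:

# Index of each kernel, taken from the order of the `kernels` vector above.
KERNEL_INDEX = {
    "forward": 0,
    "backward": 1,
    "fixup_forward": 2,
    "fixup_backward": 3,
    "double_backward_A": 4,
    "double_backward_B": 5,
    "fixup_double_backwardB": 6,
}

def double_backward_dispatch(execute, args, forward_config, double_backward_config,
                             has_workspace):
    """Sketch of the launch order in JITConvImpl::double_backward.

    `execute(kernel_index, args, config)` stands in for jit.execute; the real
    code also builds an smem=0 fixup config and passes (workspace, output)
    arguments to the fixup kernels.
    """
    # Kernel A accumulates into L3_dgrad and reuses the forward launch config.
    execute(KERNEL_INDEX["double_backward_A"], args, forward_config)
    if has_workspace:  # deterministic path: fold per-warp partials into L3_dgrad
        execute(KERNEL_INDEX["fixup_forward"], args, forward_config)

    # Kernel B runs under its own launch config; its deterministic fixup folds
    # per-warp partial results into L1_grad.
    execute(KERNEL_INDEX["double_backward_B"], args, double_backward_config)
    if has_workspace:
        execute(KERNEL_INDEX["fixup_double_backwardB"], args, double_backward_config)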

openequivariance/extension/generic_module.cpp

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@ PYBIND11_MODULE(generic_module, m) {
        .def("backward_rawptrs", &ConvolutionImpl::backward_rawptrs);
    py::class_<JITConvImpl<JITKernel>, ConvolutionImpl>(m, "JITConvImpl")
        .def(py::init< std::string,
+           std::unordered_map<string, int64_t>,
            std::unordered_map<string, int64_t>,
            std::unordered_map<string, int64_t>,
            std::unordered_map<string, int64_t>>());

openequivariance/extension/torch_tp_jit.cpp

Lines changed: 66 additions & 9 deletions
@@ -215,25 +215,32 @@ tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> jit_tp_double_

class TorchJITConv : public torch::CustomClassHolder {
public:
-   Map_t fwd_dict, bwd_dict, kernel_dims;
+   Map_t fwd_dict, bwd_dict, dbl_bwd_dict, kernel_dims;
    JITConvImpl<JITKernel> internal;
    int64_t L3_dim;

-   TorchJITConv(string kernel_plaintext, Map_t fwd_dict_i, Map_t bwd_dict_i, Map_t kernel_dims_i) :
+   TorchJITConv(string kernel_plaintext, Map_t fwd_dict_i, Map_t bwd_dict_i, Map_t dbl_bwd_dict_i, Map_t kernel_dims_i) :
        fwd_dict(fwd_dict_i.copy()),
        bwd_dict(bwd_dict_i.copy()),
+       dbl_bwd_dict(bwd_dict_i.copy()),
        kernel_dims(kernel_dims_i.copy()),
        internal(kernel_plaintext,
            to_map(fwd_dict_i),
            to_map(bwd_dict_i),
+           to_map(dbl_bwd_dict_i),
            to_map(kernel_dims_i)
        ),
        L3_dim(kernel_dims.at("L3_dim")) { }

-   tuple<tuple<string, string>, tuple<string, Map_t>, tuple<string, Map_t>, tuple<string, Map_t>> __obj_flatten__() {
+   tuple<tuple<string, string>,
+       tuple<string, Map_t>,
+       tuple<string, Map_t>,
+       tuple<string, Map_t>,
+       tuple<string, Map_t>> __obj_flatten__() {
        return tuple(tuple("kernel_plaintext", internal.jit.kernel_plaintext),
            tuple("fwd_config", fwd_dict),
            tuple("bwd_config", bwd_dict),
+           tuple("dbl_bwd_config", dbl_bwd_dict),
            tuple("kernel_dims", kernel_dims));
    }

@@ -347,6 +354,55 @@ tuple<torch::Tensor, torch::Tensor, torch::Tensor> jit_conv_backward(
    return tuple(L1_grad, L2_grad, W_grad);
}

+tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> jit_conv_double_backward(
+       const c10::intrusive_ptr<TorchJITConv> &jit_instance,
+       const torch::Tensor &L1_in,
+       const torch::Tensor &L2_in,
+       const torch::Tensor &W,
+       const torch::Tensor &L3_grad,
+       const torch::Tensor &L1_dgrad,
+       const torch::Tensor &L2_dgrad,
+       const torch::Tensor &W_dgrad,
+       const torch::Tensor &rows,
+       const torch::Tensor &cols,
+       const torch::Tensor &workspace,
+       const torch::Tensor &transpose_perm) {
+
+   int64_t nnz = rows.sizes()[0];
+   int64_t node_count = L1_in.sizes()[0];
+   torch::Tensor L1_grad = torch::zeros(L1_in.sizes(), L1_in.options());
+   torch::Tensor L2_grad = torch::empty(L2_in.sizes(), L2_in.options());
+   torch::Tensor W_grad = torch::empty(W.sizes(), W.options());
+   torch::Tensor L3_dgrad = torch::zeros(L3_grad.sizes(), L3_grad.options());
+
+   torch::Tensor L1_in_contig = L1_in.contiguous();
+   torch::Tensor L2_in_contig = L2_in.contiguous();
+   torch::Tensor W_contig = W.contiguous();
+   torch::Tensor L3_grad_contig = L3_grad.contiguous();
+   torch::Tensor L1_dgrad_contig = L1_dgrad.contiguous();
+   torch::Tensor L2_dgrad_contig = L2_dgrad.contiguous();
+   torch::Tensor W_dgrad_contig = W_dgrad.contiguous();
+
+   torch::Tensor rows_contig = rows.contiguous();
+   torch::Tensor cols_contig = cols.contiguous();
+   torch::Tensor workspace_contig = workspace.contiguous();
+   torch::Tensor transpose_perm_contig = transpose_perm.contiguous();
+
+   jit_instance->internal.double_backward(
+       data_ptr(L1_in_contig), data_ptr(L2_in_contig),
+       data_ptr(W_contig), data_ptr(L3_grad_contig),
+       data_ptr(L1_dgrad_contig), data_ptr(L2_dgrad_contig),
+       data_ptr(W_dgrad_contig),
+       data_ptr(L1_grad), data_ptr(L2_grad),
+       data_ptr(W_grad), data_ptr(L3_dgrad),
+       data_ptr(rows_contig), data_ptr(cols_contig),
+       nnz, node_count,
+       data_ptr(workspace_contig), data_ptr(transpose_perm_contig)
+   );
+
+   return tuple(L1_grad, L2_grad, W_grad, L3_dgrad);
+}
+
// ===========================================================

TORCH_LIBRARY_FRAGMENT(torch_tp_jit, m) {
@@ -376,7 +432,7 @@ TORCH_LIBRARY_FRAGMENT(torch_tp_jit, m) {

    m.class_<TorchJITConv>("TorchJITConv")
-       .def(torch::init<string, Map_t, Map_t, Map_t>())
+       .def(torch::init<string, Map_t, Map_t, Map_t, Map_t>())
        .def("__obj_flatten__", &TorchJITConv::__obj_flatten__)
        .def("exec_conv_rawptrs", &TorchJITConv::exec_conv_rawptrs)
        .def("backward_rawptrs", &TorchJITConv::backward_rawptrs)
@@ -386,27 +442,28 @@ TORCH_LIBRARY_FRAGMENT(torch_tp_jit, m) {
        .def_pickle(
            // __getstate__
            [](const c10::intrusive_ptr<TorchJITConv>& self)
-               -> tuple<string, Map_t, Map_t, Map_t> {
-               return tuple(self->internal.jit.kernel_plaintext, self->fwd_dict, self->bwd_dict, self->kernel_dims);
+               -> tuple<string, Map_t, Map_t, Map_t, Map_t> {
+               return tuple(self->internal.jit.kernel_plaintext, self->fwd_dict, self->bwd_dict, self->dbl_bwd_dict, self->kernel_dims);
            },
            // __setstate__
-           [](tuple<string, Map_t, Map_t, Map_t> state)
+           [](tuple<string, Map_t, Map_t, Map_t, Map_t> state)
                -> c10::intrusive_ptr<TorchJITConv> {
-               return c10::make_intrusive<TorchJITConv>(get<0>(state), get<1>(state), get<2>(state), get<3>(state));
+               return c10::make_intrusive<TorchJITConv>(get<0>(state), get<1>(state), get<2>(state), get<3>(state), get<4>(state));
            });

    m.def("jit_conv_forward(__torch__.torch.classes.torch_tp_jit.TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> Tensor");
    m.def("jit_conv_backward(__torch__.torch.classes.torch_tp_jit.TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor)");
+   m.def("jit_conv_double_backward(__torch__.torch.classes.torch_tp_jit.TorchJITConv jit, Tensor L1_in, Tensor L2_in, Tensor W, Tensor L3_grad, Tensor L1_dgrad, Tensor L2_dgrad, Tensor W_dgrad, Tensor rows, Tensor cols, Tensor workspace, Tensor transpose_perm) -> (Tensor, Tensor, Tensor, Tensor)");
};

-
TORCH_LIBRARY_IMPL(torch_tp_jit, CUDA, m) {
    m.impl("jit_tp_forward", &jit_tp_forward);
    m.impl("jit_tp_backward", &jit_tp_backward);
    m.impl("jit_tp_double_backward", &jit_tp_double_backward);

    m.impl("jit_conv_forward", &jit_conv_forward);
    m.impl("jit_conv_backward", &jit_conv_backward);
+   m.impl("jit_conv_double_backward", &jit_conv_double_backward);
};

PYBIND11_MODULE(torch_tp_jit, m) {}
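Note: jit_conv_double_backward zero-initializes L1_grad and L3_dgrad (the outputs the kernels appear to accumulate into) and leaves L2_grad and W_grad as plain empty allocations. A minimal usage sketch, assuming the compiled torch_tp_jit extension is already loaded and `jit` is a torch.classes.torch_tp_jit.TorchJITConv instance (the wrapper function itself is hypothetical):

import torch

def call_jit_conv_double_backward(jit, L1_in, L2_in, W, L3_grad,
                                  L1_dgrad, L2_dgrad, W_dgrad,
                                  rows, cols, workspace, transpose_perm):
    # Returns (L1_grad, L2_grad, W_grad, L3_dgrad), matching the schema string above.
    return torch.ops.torch_tp_jit.jit_conv_double_backward(
        jit, L1_in, L2_in, W, L3_grad,
        L1_dgrad, L2_dgrad, W_dgrad,
        rows, cols, workspace, transpose_perm)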

openequivariance/implementations/ComputationSchedule.py

Lines changed: 1 addition & 1 deletion
@@ -276,7 +276,7 @@ def __init__(self,

        smem_limit -= 1
        self.memory_per_warp = smem_limit // warps_per_block
-       self.memory_per_warp -= self.memory_per_warp % 4
+       self.memory_per_warp -= self.memory_per_warp % 8

    # =====================================================================
    # Shared memory partitioning functions
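Note: the only change here tightens the per-warp shared-memory rounding from a 4-byte to an 8-byte multiple, presumably so 8-byte accesses in the new double-backward schedule stay aligned within each warp's partition. A small sketch of the arithmetic with illustrative numbers (not values from the repository):

def per_warp_smem(smem_limit: int, warps_per_block: int, alignment: int) -> int:
    # Mirror of the logic above: split the block's shared memory across warps,
    # then round each warp's share down to a multiple of `alignment` bytes.
    per_warp = smem_limit // warps_per_block
    return per_warp - (per_warp % alignment)

# Illustrative numbers only.
print(per_warp_smem(49151, 6, 4))  # 8188 (old behaviour: multiple of 4)
print(per_warp_smem(49151, 6, 8))  # 8184 (new behaviour: multiple of 8)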

openequivariance/implementations/convolution/LoopUnrollConv.py

Lines changed: 36 additions & 17 deletions
@@ -49,8 +49,21 @@ def generate_backward_schedule(warps_per_block):
                warp_size=dp.warpsize,
                include_scratch=self.is_uvw,
                stream_weights=self.is_uvw)
+
+       def generate_double_backward_schedule(warps_per_block):
+           self.double_backward_schedule = ComputationSchedule(self.config,
+               smem_limit=dp.maxSharedMemPerBlock,
+               warps_per_block=warps_per_block,
+               warp_size=dp.warpsize,
+               block_count=dp.multiprocessorCount,
+               direction = "double_backward",
+               irrep_dtype = config.irrep_dtype,
+               weight_dtype = config.weight_dtype,
+               include_scratch=self.is_uvw,
+               stream_weights=self.is_uvw,
+               schedule_type=3)

-       scheduler_generators = [generate_forward_schedule, generate_backward_schedule]
+       scheduler_generators = [generate_forward_schedule, generate_backward_schedule, generate_double_backward_schedule]

        for generate_schedule in scheduler_generators:
            warp_count = 6
@@ -72,35 +85,47 @@ def generate_backward_schedule(warps_per_block):
                for key in segment.L1Map.storeback_procedure:
                    segment.L1Map.storeback_procedure[key] = "atomic_accumulate"

+           for segment in self.double_backward_schedule.segments:
+               for key in segment.L1Map.storeback_procedure:
+                   segment.L1Map.storeback_procedure[key] = "atomic_accumulate"
+
+
        idx_type_map = {np.int32: "int", np.int64: "long"}

        if self.torch_op:
            self.setup_torch_module()

        self.forward_workspace_offset = None
        self.backward_workspace_offset = None
+       self.double_backwardB_offset = None

        workspace_size = 1
        if deterministic:
            destination_index_bytes = 32 # Add extra to account for padding
            workspace_size = max(
                (self.forward_schedule.L3.dim * np.dtype(config.irrep_dtype).itemsize + destination_index_bytes) * self.forward_schedule.total_warps,
-               (self.backward_schedule.L1.dim * np.dtype(config.irrep_dtype).itemsize + destination_index_bytes) * self.backward_schedule.total_warps)
+               (self.backward_schedule.L1.dim * np.dtype(config.irrep_dtype).itemsize + destination_index_bytes) * self.backward_schedule.total_warps,
+               (self.double_backward_schedule.L1.dim * np.dtype(config.irrep_dtype).itemsize + destination_index_bytes) * self.double_backward_schedule.total_warps
+           )

            self.forward_workspace_offset = self.forward_schedule.L3.dim * np.dtype(config.irrep_dtype).itemsize * self.forward_schedule.total_warps
            self.backward_workspace_offset = self.backward_schedule.L1.dim * np.dtype(config.irrep_dtype).itemsize * self.backward_schedule.total_warps
+           self.double_backwardB_offset = self.double_backward_schedule.L1.dim * np.dtype(config.irrep_dtype).itemsize * self.double_backward_schedule.total_warps

            self.forward_workspace_offset = (self.forward_workspace_offset + 7) // 8 * 8
            self.backward_workspace_offset = (self.backward_workspace_offset + 7) // 8 * 8
+           self.double_backwardB_offset = (self.double_backwardB_offset + 7) // 8 * 8

        self.allocate_workspace(workspace_size)

        self.jit_kernel = template.render(
            forward_schedule=self.forward_schedule,
            backward_schedule=self.backward_schedule,
+           double_backward_schedule=self.double_backward_schedule,
            idx_type=idx_type_map[idx_dtype],
            forward_workspace_offset=self.forward_workspace_offset,
-           backward_workspace_offset=self.backward_workspace_offset)
+           backward_workspace_offset=self.backward_workspace_offset,
+           double_backwardB_offset=self.double_backwardB_offset)
        self.jit_kernel = postprocess_kernel(self.jit_kernel)

        if self.torch_op and extlib.TORCH_COMPILE:
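Note: in deterministic mode the single workspace buffer now has to cover the largest of the three schedules' per-warp staging requirements, and each pass's offset is rounded up to an 8-byte boundary before being baked into the kernel template. A short sketch of that sizing arithmetic (the dims, warp counts, and dtype below are placeholders, not values from the repository):

import numpy as np

def staging_bytes(dim, total_warps, dtype, destination_index_bytes=32):
    # Per-warp staging region: one row of `dim` irrep values plus padded
    # destination-index storage, as in the max(...) expression above.
    return (dim * np.dtype(dtype).itemsize + destination_index_bytes) * total_warps

def round_up_to_8(n):
    # Offsets into the shared workspace are kept 8-byte aligned.
    return (n + 7) // 8 * 8

# Placeholder values only.
irrep_dtype = np.float32
workspace_size = max(
    staging_bytes(128, 6144, irrep_dtype),  # forward pass: L3.dim
    staging_bytes(96, 6144, irrep_dtype),   # backward pass: L1.dim
    staging_bytes(96, 4096, irrep_dtype),   # double backward pass: L1.dim
)
double_backwardB_offset = round_up_to_8(96 * np.dtype(irrep_dtype).itemsize * 4096)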
@@ -115,6 +140,7 @@ def generate_backward_schedule(warps_per_block):
            self.internal = internal_cls(self.jit_kernel,
                vars(self.forward_schedule.launch_config),
                vars(self.backward_schedule.launch_config),
+               vars(self.double_backward_schedule.launch_config),
                {"L3_dim": self.L3.dim,
                 "is_uvw": int(self.is_uvw)})
            logger.info("Kernel compiled!")
@@ -137,9 +163,10 @@ def register_torch_fakes(cls):
    class TorchJITConv:
        def __init__(self, kernel_plaintext: str,
                fwd_config: dict[str, int],
-               bwd_config: dict[str, int],
+               bwd_config: dict[str, int],
+               dbl_bwd_config: dict[str, int],
                kernel_dims: dict[str, int]) -> None:
-           self.kernel_plaintext, self.fwd_config, self.bwd_config, self.kernel_dims = kernel_plaintext, fwd_config, bwd_config, kernel_dims
+           self.kernel_plaintext, self.fwd_config, self.bwd_config, self.dbl_bwd_config, self.kernel_dims = kernel_plaintext, fwd_config, bwd_config, dbl_bwd_config, kernel_dims

        @classmethod
        def __obj_unflatten__(cls, flattened_product):
@@ -149,7 +176,7 @@ def __len__(self):
            return 0

        def __setstate__(self, state):
-           self.kernel_plaintext, self.fwd_config, self.bwd_config, self.kernel_dims = state
+           self.kernel_plaintext, self.fwd_config, self.bwd_config, self.dbl_bwd_config, self.kernel_dims = state

        @torch.library.register_fake("torch_tp_jit::jit_conv_forward")
        def fake_forward(jit, L1_in, L2_in, W, rows, cols, workspace_buffer, sender_perm):
@@ -163,6 +190,7 @@ def fake_backward(jit, L1_in, L2_in, W, L3_grad, rows, cols, workspace_buffer, s
    def register_autograd(cls):
        forward_op = torch.ops.torch_tp_jit.jit_conv_forward
        backward_op = torch.ops.torch_tp_jit.jit_conv_backward
+       double_backward_op = torch.ops.torch_tp_jit.jit_conv_double_backward

        def setup_context(ctx, inputs, output):
            ctx.jit, ctx.L1_in, ctx.L2_in, ctx.W, ctx.rows, ctx.cols, ctx.workspace_buffer, ctx.sender_perm = inputs
@@ -178,17 +206,8 @@ def setup_context_double_backward(ctx, inputs, output):
            ctx.inputs = inputs

        def double_backward(ctx, E, F, G):
-           jit, A, B, C, D, rows, cols, wspace, sender_perm = ctx.jit, ctx.L1_in, ctx.L2_in, ctx.grad_output, ctx.W, ctx.rows, ctx.cols, ctx.workspace_buffer, ctx.sender_perm
-
-           op1 = backward_op(jit, E, F, D, C, rows, cols, wspace, sender_perm)
-           op2 = backward_op(jit, A, B, G, C, rows, cols, wspace, sender_perm)
-           op3 = forward_op(jit, E, B, D, rows, cols, wspace, sender_perm)
-           op4 = backward_op(jit, E, B, D, C, rows, cols, wspace, sender_perm) # op4 and op5 could be combined with op3 and op6
-           op5 = backward_op(jit, A, F, D, C, rows, cols, wspace, sender_perm)
-           op6 = forward_op(jit, A, F, D, rows, cols, wspace, sender_perm)
-           op7 = forward_op(jit, A, B, G, rows, cols, wspace, sender_perm)
-
-           return None, op1[0] + op2[0], op1[1] + op2[1], op4[2] + op5[2], (op3 + op6 + op7), None, None, None, None
+           result = double_backward_op(ctx.jit, ctx.L1_in, ctx.L2_in, ctx.W, ctx.grad_output, E, F, G, ctx.rows, ctx.cols, ctx.workspace_buffer, ctx.sender_perm)
+           return None, result[0], result[1], result[2], result[3], None, None, None, None

        torch.library.register_autograd("torch_tp_jit::jit_conv_backward", double_backward, setup_context=setup_context_double_backward)
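Note: the seven-launch composition removed above is intended to compute the same four gradients as the fused op, which makes it a convenient cross-check when testing double_backward_A/B. A sketch assembled directly from the removed lines (A = L1_in, B = L2_in, C = the saved L3 gradient, D = W; E, F, G are the incoming grads of L1_grad, L2_grad, W_grad):

def double_backward_reference(jit, forward_op, backward_op,
                              A, B, C, D, E, F, G,
                              rows, cols, wspace, sender_perm):
    # Composition that the fused jit_conv_double_backward replaces.
    op1 = backward_op(jit, E, F, D, C, rows, cols, wspace, sender_perm)
    op2 = backward_op(jit, A, B, G, C, rows, cols, wspace, sender_perm)
    op3 = forward_op(jit, E, B, D, rows, cols, wspace, sender_perm)
    op4 = backward_op(jit, E, B, D, C, rows, cols, wspace, sender_perm)
    op5 = backward_op(jit, A, F, D, C, rows, cols, wspace, sender_perm)
    op6 = forward_op(jit, A, F, D, rows, cols, wspace, sender_perm)
    op7 = forward_op(jit, A, B, G, rows, cols, wspace, sender_perm)

    L1_grad = op1[0] + op2[0]
    L2_grad = op1[1] + op2[1]
    W_grad = op4[2] + op5[2]
    L3_dgrad = op3 + op6 + op7
    return L1_grad, L2_grad, W_grad, L3_dgrad

The fused op returns exactly this 4-tuple, which double_backward then maps onto the gradients of (L1_in, L2_in, W, L3_grad) in its return statement.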
