
Commit 371e2d3

vbharadwaj-bk and Austin Glover authored
AMP Autocast Registration (#150)
* python autocast opt out
* add irrep_dtype
* loop tp changes
* LoopUnrollChanges
* test commit
* remove test comment
* put autocast registration behind guard
* Preparing Austin's commit for autocast merge.
* Linted.

---------

Co-authored-by: Austin Glover <austin_glover@berkeley.com>
1 parent f76b21e commit 371e2d3

3 files changed: 71 additions & 5 deletions
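What the registration does: once an op is registered with torch.library.register_autocast(op, "cuda", torch.float32), any floating-point tensor arguments it receives inside a torch.autocast("cuda", ...) region are cast back to float32 before the kernel runs, so kernels compiled for float32 irreps never see half-precision inputs. Below is a minimal, self-contained sketch of the same mechanism on a toy op; demo::affine is hypothetical and not part of this commit, and it assumes a CUDA device plus a PyTorch version that provides torch.library.register_autocast.

import torch

# Toy float32-only custom op standing in for the JIT tensor-product kernels.
@torch.library.custom_op("demo::affine", mutates_args=())
def affine(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    assert x.dtype == torch.float32  # the "kernel" only supports float32
    return x * w

# Pin the op to float32 under autocast, mirroring what this commit does for
# jit_tp_forward / jit_tp_backward / jit_tp_double_backward.
torch.library.register_autocast("demo::affine", "cuda", torch.float32)

x = torch.randn(4, device="cuda")
w = torch.randn(4, device="cuda")
with torch.autocast("cuda", dtype=torch.float16):
    h = torch.nn.functional.linear(x, torch.randn(4, 4, device="cuda"))  # runs in float16
    y = affine(h, w)  # h is recast to float32 before the op body executes
print(y.dtype)  # torch.float32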


openequivariance/extension/libtorch_tp_jit.cpp

Lines changed: 4 additions & 2 deletions
@@ -342,7 +342,7 @@ class TorchJITConv : public torch::CustomClassHolder {
     Map_t fwd_dict, bwd_dict, dbl_bwd_dict, kernel_dims;
     JITConvImpl<JITKernel> internal;
     KernelProp kernelProp;
-    int64_t L3_dim;
+    int64_t L3_dim, irrep_dtype;

     TorchJITConv(string kernel_plaintext, Map_t fwd_dict_i, Map_t bwd_dict_i, Map_t dbl_bwd_dict_i, Map_t kernel_dims_i) :
         fwd_dict(fwd_dict_i.copy()),
@@ -356,7 +356,8 @@ class TorchJITConv : public torch::CustomClassHolder {
             to_map(kernel_dims_i)
         ),
         kernelProp(kernel_dims, true),
-        L3_dim(kernelProp.L3_dim)
+        L3_dim(kernelProp.L3_dim),
+        irrep_dtype(kernel_dims_i.at("irrep_dtype"))
     { }

     tuple<tuple<string, string>,
@@ -676,6 +677,7 @@ TORCH_LIBRARY_FRAGMENT(libtorch_tp_jit, m) {
             return 0;
         })
         .def_readonly("L3_dim", &TorchJITConv::L3_dim)
+        .def_readonly("irrep_dtype", &TorchJITConv::irrep_dtype)
         .def("__eq__", [](const c10::IValue & self, const c10::IValue& other) -> bool {
             return self.is(other);
         })
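Exposing irrep_dtype alongside L3_dim through def_readonly lets the Python side query the kernel's irrep dtype directly from the bound custom class (it is read as jit.irrep_dtype in the LoopUnrollConv.py changes below), rather than inferring the output dtype from whichever tensor happens to arrive at the op.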

openequivariance/implementations/LoopUnrollTP.py

Lines changed: 16 additions & 0 deletions
@@ -232,6 +232,21 @@ def double_backward(ctx, E, F, G):
             setup_context=setup_context_double_backward,
         )

+    @classmethod
+    def register_autocast(cls):
+        global torch
+        import torch
+
+        torch.library.register_autocast(
+            "libtorch_tp_jit::jit_tp_forward", "cuda", torch.float32
+        )
+        torch.library.register_autocast(
+            "libtorch_tp_jit::jit_tp_backward", "cuda", torch.float32
+        )
+        torch.library.register_autocast(
+            "libtorch_tp_jit::jit_tp_double_backward", "cuda", torch.float32
+        )
+
     @staticmethod
     def name():
         return "LoopUnrollTP"
@@ -290,3 +305,4 @@ def calculate_flops_backward(self, batch_size: int) -> dict:
 if extlib.TORCH_COMPILE:
     LoopUnrollTP.register_torch_fakes()
     LoopUnrollTP.register_autograd()
+    LoopUnrollTP.register_autocast()
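For context, this is the shape of the mixed-precision loop the registration targets. The sketch below is generic AMP usage, not OpenEquivariance-specific: torch.nn.Linear stands in for a model that would contain the tensor-product layers, and a CUDA device is assumed.

import torch

model = torch.nn.Linear(64, 64, device="cuda")   # stand-in for a model containing TP layers
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = torch.amp.GradScaler("cuda")

for _ in range(3):
    x = torch.randn(32, 64, device="cuda")
    with torch.autocast("cuda", dtype=torch.float16):
        # Ordinary layers run in float16 here; ops registered for autocast at
        # float32 (like the jit_tp_* ops above) would receive float32 inputs.
        loss = model(x).square().mean()
    scaler.scale(loss).backward()
    scaler.step(opt)
    scaler.update()
    opt.zero_grad()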

openequivariance/implementations/convolution/LoopUnrollConv.py

Lines changed: 51 additions & 3 deletions
@@ -6,7 +6,10 @@
     SMEMCapacityException,
 )

-from openequivariance.implementations.dtype_enum import dtype_to_enum
+from openequivariance.implementations.dtype_enum import (
+    dtype_to_enum,
+    enum_to_torch_dtype,
+)
 from openequivariance.templates.jinja_utils import get_jinja_environment
 from openequivariance import extlib
 from openequivariance.extlib import JITConvImpl, postprocess_kernel, DeviceProp
@@ -297,20 +300,49 @@ def double_backward_rawptrs(*args, **kwargs):
         def fake_forward(
             jit, L1_in, L2_in, W, rows, cols, workspace_buffer, sender_perm
         ):
-            L3_dim = None
+            L3_dim, irrep_dtype = None, None
             if hasattr(jit, "wrapped_obj"):
                 L3_dim = jit.wrapped_obj.kernel_dims["L3_dim"]
+                irrep_dtype = jit.wrapped_obj.kernel_dims["irrep_dtype"]
             else:
                 L3_dim = jit.L3_dim
+                irrep_dtype = jit.irrep_dtype

-            return L1_in.new_empty(L1_in.shape[0], L3_dim)
+            return torch.empty(
+                L1_in.shape[0],
+                L3_dim,
+                device="cuda",
+                dtype=enum_to_torch_dtype[irrep_dtype],
+            )

         @torch.library.register_fake("libtorch_tp_jit::jit_conv_backward")
         def fake_backward(
             jit, L1_in, L2_in, W, L3_grad, rows, cols, workspace_buffer, sender_perm
         ):
             return torch.empty_like(L1_in), torch.empty_like(L2_in), torch.empty_like(W)

+        @torch.library.register_fake("libtorch_tp_jit::jit_conv_double_backward")
+        def fake_double_backward(
+            jit,
+            L1_in,
+            L2_in,
+            W,
+            L3_grad,
+            L1_dgrad,
+            L2_dgrad,
+            w_dgrad,
+            rows,
+            cols,
+            workspace_buffer,
+            transpose_perm=None,
+        ):
+            return [
+                L1_in.new_empty(*L1_in.shape),
+                L2_in.new_empty(*L2_in.shape),
+                W.new_empty(*W.shape),
+                L3_grad.new_empty(*L3_grad.shape),
+            ]
+
     @classmethod
     def register_autograd(cls):
         backward_op = torch.ops.libtorch_tp_jit.jit_conv_backward
@@ -393,7 +425,23 @@ def double_backward(ctx, E, F, G):
             setup_context=setup_context_double_backward,
         )

+    @classmethod
+    def register_autocast(cls):
+        global torch
+        import torch
+
+        torch.library.register_autocast(
+            "libtorch_tp_jit::jit_conv_forward", "cuda", torch.float32
+        )
+        torch.library.register_autocast(
+            "libtorch_tp_jit::jit_conv_backward", "cuda", torch.float32
+        )
+        torch.library.register_autocast(
+            "libtorch_tp_jit::jit_conv_double_backward", "cuda", torch.float32
+        )
+

 if extlib.TORCH_COMPILE:
     LoopUnrollConv.register_torch_fakes()
     LoopUnrollConv.register_autograd()
+    LoopUnrollConv.register_autocast()
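One detail of the fake_forward change worth noting: Tensor.new_empty inherits the dtype of its source tensor, so the old L1_in.new_empty(...) call would report whatever dtype L1_in happens to carry, while the real kernel's output dtype is fixed by irrep_dtype. A small illustration (plain PyTorch, runs on CPU):

import torch

L1_in = torch.randn(8, 16, dtype=torch.float16)       # e.g. a half-precision input under AMP
print(L1_in.new_empty(8, 4).dtype)                     # torch.float16 -- inherited from L1_in
print(torch.empty(8, 4, dtype=torch.float32).dtype)    # torch.float32 -- fixed explicitly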
