Skip to content

Commit 4df7dd6

Browse files
asgloverAustin Glover
andauthored
Adding "_getter" functions in python for TorchBind warnings (#152)
* expose irrep dtype for tensor product * add getters to fake registration * add getters to fake registration for conv --------- Co-authored-by: Austin Glover <austin_glover@berkeley.com>
1 parent 371e2d3 commit 4df7dd6

3 files changed

Lines changed: 14 additions & 3 deletions

File tree

openequivariance/extension/libtorch_tp_jit.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ class __attribute__ ((visibility ("default"))) TorchJITProduct : public torch::C
134134
Map_t fwd_dict, bwd_dict, dbl_bwd_dict, kernel_dims;
135135
JITTPImpl<JITKernel> internal;
136136
KernelProp kernelProp;
137-
int64_t L3_dim;
137+
int64_t L3_dim, irrep_dtype;
138138

139139
TorchJITProduct(string kernel_plaintext, Map_t fwd_dict_i, Map_t bwd_dict_i, Map_t dbl_bwd_dict_i, Map_t kernel_dims_i) :
140140
fwd_dict(fwd_dict_i.copy()),
@@ -148,7 +148,8 @@ class __attribute__ ((visibility ("default"))) TorchJITProduct : public torch::C
148148
to_map(kernel_dims_i)
149149
),
150150
kernelProp(kernel_dims, false),
151-
L3_dim(kernelProp.L3_dim)
151+
L3_dim(kernelProp.L3_dim),
152+
irrep_dtype(kernel_dims_i.at("irrep_dtype"))
152153
{ }
153154

154155
tuple< tuple<string, string>,
@@ -647,6 +648,7 @@ TORCH_LIBRARY_FRAGMENT(libtorch_tp_jit, m) {
647648
return 0;
648649
})
649650
.def_readonly("L3_dim", &TorchJITProduct::L3_dim)
651+
.def_readonly("irrep_dtype", &TorchJITProduct::irrep_dtype)
650652
.def("__eq__", [](const c10::IValue & self, const c10::IValue& other) -> bool {
651653
return self.is(other);
652654
})

openequivariance/implementations/LoopUnrollTP.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,9 +183,12 @@ def exec_tensor_product_rawptr(*args, **kwargs):
183183
def backward_rawptr(*args, **kwargs):
184184
pass
185185

186-
def get_L3_dim(self):
186+
def L3_dim_getter(self):
187187
return self.kernel_dims["L3_dim"]
188188

189+
def irrep_dtype_getter(self):
190+
return self.kernel_dims["irrep_dtype"]
191+
189192
@torch.library.register_fake("libtorch_tp_jit::jit_tp_forward")
190193
def fake_forward(jit, L1_in, L2_in, W):
191194
L3_dim = None

openequivariance/implementations/convolution/LoopUnrollConv.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,12 @@ def backward_rawptrs(*args, **kwargs):
296296
def double_backward_rawptrs(*args, **kwargs):
297297
pass
298298

299+
def L3_dim_getter(self):
300+
return self.kernel_dims["L3_dim"]
301+
302+
def irrep_dtype_getter(self):
303+
return self.kernel_dims["irrep_dtype"]
304+
299305
@torch.library.register_fake("libtorch_tp_jit::jit_conv_forward")
300306
def fake_forward(
301307
jit, L1_in, L2_in, W, rows, cols, workspace_buffer, sender_perm

0 commit comments

Comments
 (0)