Skip to content

Commit 0d07cd9

Browse files
committed
More plumbing.
1 parent 19f284b commit 0d07cd9

3 files changed

Lines changed: 21 additions & 11 deletions

File tree

openequivariance/openequivariance/core/LoopUnrollTP.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,11 @@ def generate_double_backward_schedule(warps_per_block):
9999
"opt_level": 3,
100100
"irrep_dtype": dtype_to_enum[self.config.irrep_dtype],
101101
"weight_dtype": dtype_to_enum[self.config.weight_dtype],
102+
103+
# Not relevant, included for compatibility with convolution
104+
"workspace_size": 0,
105+
"deterministic": 1,
106+
"idx_dtype": 0
102107
}
103108

104109

openequivariance/openequivariance/impl_jax/TensorProductConv.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from openequivariance.benchmark.logging_utils import getLogger
1414
logger = getLogger()
1515

16-
@partial(jax.custom_vjp, nondiff_argnums=(3,4,5,6,7,8,9))
16+
#@partial(jax.custom_vjp, nondiff_argnums=(3,4,5,6,7,8,9))
1717
def forward(X, Y, W, rows, cols, sender_perm, workspace, L3_dim, irrep_dtype, attrs):
1818
forward_call = jax.ffi.ffi_call("conv_forward",
1919
jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype))
@@ -53,17 +53,17 @@ def __init__(self, config: TPProblem, deterministic: bool = False, kahan: bool =
5353

5454
def forward(
5555
self,
56-
X: jax.ndarray,
57-
Y: jax.ndarray,
58-
W: jax.ndarray,
59-
rows: jax.ndarray,
60-
cols: jax.ndarray,
61-
sender_perm: Optional[jax.ndarray] = None) -> jax.ndarray:
56+
X: jax.numpy.ndarray,
57+
Y: jax.numpy.ndarray,
58+
W: jax.numpy.ndarray,
59+
rows: jax.numpy.ndarray,
60+
cols: jax.numpy.ndarray,
61+
sender_perm: Optional[jax.numpy.ndarray] = None) -> jax.numpy.ndarray:
6262

63-
if self.deterministic:
63+
if not self.deterministic:
6464
sender_perm = self.dummy_transpose_perm
6565
else:
66-
assert sender_perm is not None, "Must provide sender_perm for non-deterministic convolutions."
66+
assert sender_perm is not None, "Must provide sender_perm for deterministic convolutions."
6767

6868
return forward(
6969
X, Y, W,

openequivariance_extjax/src/libjax_tp_jit.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,12 @@ std::vector<std::string> kernel_prop_keys = {
144144
"shared_weights",
145145
"opt_level",
146146
"irrep_dtype",
147-
"weight_dtype"};
147+
"weight_dtype",
148+
149+
// Convolution only
150+
"workspace_size",
151+
"deterministic",
152+
"idx_dtype"};
148153

149154
std::unordered_map<string, int64_t> parse_ffi_dict(ffi::Dictionary &dict, const std::vector<string> &keys) {
150155
std::unordered_map<string, int64_t> result;
@@ -240,7 +245,7 @@ inline void check_tensor(const ffi::AnyBuffer &buffer,
240245
}
241246

242247
if (buffer.element_type() != expected_dtype) {
243-
throw std::logic_error("Datatype mismatch.");
248+
throw std::logic_error("Datatype mismatch for tensor " + tensor_name);
244249
}
245250
}
246251

0 commit comments

Comments
 (0)