Skip to content

Commit ce68f69

Browse files
committed
Forward call is working.
1 parent 0d07cd9 commit ce68f69

2 files changed

Lines changed: 16 additions & 23 deletions

File tree

openequivariance/openequivariance/impl_jax/TensorProductConv.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@
1414
logger = getLogger()
1515

1616
#@partial(jax.custom_vjp, nondiff_argnums=(3,4,5,6,7,8,9))
17-
def forward(X, Y, W, rows, cols, sender_perm, workspace, L3_dim, irrep_dtype, attrs):
17+
def forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs):
1818
forward_call = jax.ffi.ffi_call("conv_forward",
1919
jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype))
20-
return forward_call(X, Y, W, rows, cols, sender_perm, workspace, **attrs)
20+
return forward_call(X, Y, W, rows, cols, workspace, sender_perm, **attrs)
2121

22-
def forward_with_inputs(X, Y, W, rows, cols, sender_perm, workspace, L3_dim, irrep_dtype, attrs):
23-
return forward(X, Y, W, rows, cols, sender_perm, workspace, L3_dim, irrep_dtype, attrs), (X, Y, W, rows, cols, sender_perm, workspace)
22+
def forward_with_inputs(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs):
23+
return forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs), (X, Y, W, rows, cols, sender_perm, workspace)
2424

2525
class TensorProductConv(LoopUnrollConv):
2626
def __init__(self, config: TPProblem, deterministic: bool = False, kahan: bool = False):
@@ -67,8 +67,9 @@ def forward(
6767

6868
return forward(
6969
X, Y, W,
70-
rows, cols, sender_perm,
70+
rows, cols,
7171
self.workspace,
72+
sender_perm,
7273
self.L3_dim,
7374
self.config.irrep_dtype,
7475
self.attrs)

openequivariance_extjax/src/libjax_tp_jit.cpp

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -46,26 +46,20 @@ xla::ffi::DataType enum_to_xla_dtype(int64_t i){
4646
}
4747

4848
inline void* data_ptr(ffi::AnyBuffer &buffer) {
49-
switch (buffer.element_type()) {
50-
case xla::ffi::DataType::F32:
51-
return reinterpret_cast<void*>(buffer.typed_data<float>());
52-
case xla::ffi::DataType::F64:
53-
return reinterpret_cast<void*>(buffer.typed_data<double>());
54-
case xla::ffi::DataType::S64:
55-
return reinterpret_cast<void*>(buffer.typed_data<int64_t>());
56-
case xla::ffi::DataType::U8:
57-
return reinterpret_cast<void*>(buffer.typed_data<uint8_t>());
58-
default:
59-
throw logic_error("Unsupported tensor datatype!");
60-
}
49+
return buffer.untyped_data();
50+
}
51+
52+
inline void* data_ptr(ffi::Result<ffi::AnyBuffer> &buffer) {
53+
return data_ptr(*buffer);
6154
}
6255

6356
inline int byte_count(ffi::AnyBuffer &buffer) {
6457
switch (buffer.element_type()) {
58+
case xla::ffi::DataType::U32:
59+
case xla::ffi::DataType::S32:
6560
case xla::ffi::DataType::F32:
6661
return 4;
6762
case xla::ffi::DataType::F64:
68-
return 8;
6963
case xla::ffi::DataType::S64:
7064
return 8;
7165
case xla::ffi::DataType::U8:
@@ -75,10 +69,6 @@ inline int byte_count(ffi::AnyBuffer &buffer) {
7569
}
7670
}
7771

78-
inline void* data_ptr(ffi::Result<ffi::AnyBuffer> &buffer) {
79-
return data_ptr(*buffer);
80-
}
81-
8272
#ifdef CUDA_BACKEND
8373
void zero_buffer(ffi::AnyBuffer &buffer) {
8474
cudaMemset(
@@ -245,7 +235,9 @@ inline void check_tensor(const ffi::AnyBuffer &buffer,
245235
}
246236

247237
if (buffer.element_type() != expected_dtype) {
248-
throw std::logic_error("Datatype mismatch for tensor " + tensor_name);
238+
throw std::logic_error("Datatype mismatch for tensor " + tensor_name +
239+
". Expected datatype " + std::to_string(static_cast<int64_t>(expected_dtype)) +
240+
", got " + std::to_string(static_cast<int64_t>(buffer.element_type())));
249241
}
250242
}
251243

0 commit comments

Comments (0)