
Commit 6452140

Added JAX reordering function.
1 parent d924503 commit 6452140

3 files changed: 72 additions & 4 deletions


openequivariance/openequivariance/impl_jax/TensorProduct.py

Lines changed: 9 additions & 4 deletions
```diff
@@ -5,7 +5,7 @@
 from openequivariance.core.e3nn_lite import TPProblem
 from openequivariance.core.LoopUnrollTP import LoopUnrollTP
 from openequivariance.core.utils import hash_attributes
-
+from openequivariance.impl_jax.utils import reorder_jax
 
 @partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
 def forward(X, Y, W, L3_dim, irrep_dtype, attrs):
@@ -82,7 +82,13 @@ def __call__(
         self, X: jax.numpy.ndarray, Y: jax.numpy.ndarray, W: jax.numpy.ndarray
     ) -> jax.numpy.ndarray:
         return self.forward(X, Y, W)
-
+
+    def reorder_weights_from_e3nn(self, weights, has_batch_dim=True):
+        return reorder_jax(self.forward_schedule, weights, "forward", not self.config.shared_weights)
+
+    def reorder_weights_to_e3nn(self, weights, has_batch_dim=True):
+        return reorder_jax(self.forward_schedule, weights, "backward", not self.config.shared_weights)
+
     def forward_cpu(self, L1_in, L2_in, L3_out, weights) -> None:
         result = self.forward(
             jax.numpy.asarray(L1_in),
@@ -105,5 +111,4 @@ def backward_cpu(
         )
         L1_grad[:] = np.asarray(L1_grad_jax)
         L2_grad[:] = np.asarray(L2_grad_jax)
-        weights_grad[:] = np.asarray(weights_grad_jax)
-
+        weights_grad[:] = np.asarray(weights_grad_jax)
```
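Both new methods delegate to the `reorder_jax` utility introduced in this commit, converting a flat weight buffer from e3nn's layout into the layout consumed by the kernel's forward schedule ("forward") and back again ("backward"). A minimal usage sketch; the instance `tp`, the inputs `X`/`Y`, and the array `e3nn_weights` are hypothetical names, not part of this commit:

```python
# Hypothetical setup: `tp` is a constructed TensorProduct and `e3nn_weights`
# is a weight array in e3nn's ordering.
W = tp.reorder_weights_from_e3nn(e3nn_weights)  # e3nn layout -> kernel layout
out = tp(X, Y, W)                               # __call__ -> self.forward
w_back = tp.reorder_weights_to_e3nn(W)          # kernel layout -> e3nn layout
```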

openequivariance/openequivariance/impl_jax/TensorProductConv.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -6,6 +6,7 @@
 from openequivariance.core.e3nn_lite import TPProblem
 from openequivariance.core.LoopUnrollConv import LoopUnrollConv
 from openequivariance.core.utils import hash_attributes
+from openequivariance.impl_jax.utils import reorder_jax
 
 import jax
 import jax.numpy as jnp
@@ -163,6 +164,12 @@ def __call__(
     ) -> jax.numpy.ndarray:
         return self.forward(X, Y, W, rows, cols, sender_perm)
 
+    def reorder_weights_from_e3nn(self, weights, has_batch_dim=True):
+        return reorder_jax(self.forward_schedule, weights, "forward", not self.config.shared_weights)
+
+    def reorder_weights_to_e3nn(self, weights, has_batch_dim=True):
+        return reorder_jax(self.forward_schedule, weights, "backward", not self.config.shared_weights)
+
     def forward_cpu(self, L1_in, L2_in, weights, L3_out, graph):
         rows = graph.rows.astype(np.int32)
         cols = graph.cols.astype(np.int32)
```
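`TensorProductConv` gains the identical pair of methods; the only difference at the call site is that its `__call__` also takes the graph's edge arrays. A hedged sketch, with `conv`, `e3nn_weights`, `rows`, `cols`, and `sender_perm` as hypothetical names:

```python
# Hypothetical names: `conv` is a constructed TensorProductConv; `rows`,
# `cols`, and `sender_perm` describe the graph edges, mirroring forward_cpu.
W = conv.reorder_weights_from_e3nn(e3nn_weights)
out = conv(X, Y, W, rows, cols, sender_perm)
```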
openequivariance/openequivariance/impl_jax/utils.py (new file, per the `openequivariance.impl_jax.utils` imports above)

Lines changed: 56 additions & 0 deletions
```diff
@@ -0,0 +1,56 @@
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+def reorder_jax_helper(schedule, weights_in, direction, has_batch_dim):
+    assert direction in ["forward", "backward"]
+
+    specs = schedule.weight_reordering_info(weights_in, has_batch_dim)
+    weights_out = jnp.zeros_like(weights_in)
+
+    for spec in specs:
+        parent_range = spec["parent_range"]
+        parent_shape = spec["parent_shape"]
+        weights_subrange = spec["weights_subrange"]
+        child_range = spec["child_range"]
+        transpose_perm = spec["transpose_perm"]
+
+        if direction == "forward":
+            reshape_size = spec["reshape_size"]
+
+            sliced_weights = weights_in[parent_range].reshape(parent_shape)[
+                weights_subrange
+            ]
+
+            value_to_assign = sliced_weights.transpose(transpose_perm).reshape(reshape_size)
+            weights_out = weights_out.at[child_range].set(value_to_assign)
+
+        elif direction == "backward":
+            transpose_child_shape = spec["transpose_child_shape"]
+            child_shape = spec["child_shape"]
+
+            sliced_weights = (
+                weights_in[child_range]
+                .reshape(transpose_child_shape)
+                .transpose(transpose_perm)
+            )
+
+            value_to_insert = sliced_weights.flatten().reshape(child_shape)
+
+            slab = weights_out[parent_range]
+            slab_reshaped = slab.reshape(parent_shape)
+            slab_reshaped = slab_reshaped.at[weights_subrange].set(value_to_insert)
+            weights_out = weights_out.at[parent_range].set(slab_reshaped.reshape(slab.shape))
+
+    return weights_out
+
+def reorder_numpy_jax_helper(schedule, weights_in, direction, has_batch_dim):
+    weights_in_jax = jnp.array(weights_in)
+    result = reorder_jax_helper(schedule, weights_in_jax, direction, has_batch_dim)
+    return np.array(result)
+
+def reorder_jax(schedule, weights_in, direction, has_batch_dim):
+    if isinstance(weights_in, (jnp.ndarray, jax.Array)):
+        return reorder_jax_helper(schedule, weights_in, direction, has_batch_dim)
+    else:
+        return reorder_numpy_jax_helper(schedule, weights_in, direction, has_batch_dim)
```
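Two JAX-specific mechanics in `reorder_jax_helper` are worth noting: `jax.numpy` arrays are immutable, so every segment write goes through the functional `arr.at[idx].set(value)` form, and the "forward" and "backward" directions apply a reshape-transpose-reshape index permutation and its inverse. A self-contained illustration of both (the block shapes below are invented for the demo, not taken from a real schedule):

```python
import jax.numpy as jnp

# JAX arrays are immutable: segment writes use `arr.at[idx].set(value)`
# and return a new array, rather than assigning in place.
w = jnp.arange(6.0)
out = jnp.zeros_like(w)

# "Forward": view a flat block as (2, 3), transpose, flatten into the output.
fwd = w.reshape(2, 3).transpose(1, 0).reshape(6)
out = out.at[0:6].set(fwd)

# "Backward": reshape to the transposed shape and apply the inverse
# transpose, recovering the original flat order.
back = out[0:6].reshape(3, 2).transpose(1, 0).reshape(6)
assert bool(jnp.all(back == w))
```

The public `reorder_jax` entry point keeps this transparent for NumPy callers: non-JAX inputs are converted with `jnp.array`, reordered, and converted back with `np.array`, so callers get back the array type they passed in.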
