Skip to content

Commit a1b6248

Browse files
committed
Prepping to add tests.
1 parent 2524f2a commit a1b6248

2 files changed

Lines changed: 3 additions & 78 deletions

File tree

openequivariance/openequivariance/impl_jax/TensorProduct.py

Lines changed: 1 addition & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -73,59 +73,4 @@ def __call__(self,
7373
X: jax.numpy.ndarray,
7474
Y: jax.numpy.ndarray,
7575
W: jax.numpy.ndarray) -> jax.numpy.ndarray:
76-
return self.forward(X, Y, W)
77-
78-
79-
def jax_to_torch(x):
80-
import numpy as np
81-
import torch
82-
return torch.tensor(np.asarray(x), requires_grad=True)
83-
84-
if __name__ == "__main__":
85-
X_ir, Y_ir, Z_ir = Irreps("1x2e"), Irreps("1x3e"), Irreps("1x2e")
86-
instructions=[(0, 0, 0, "uvu", True)]
87-
problem = TPProblem(X_ir, Y_ir, Z_ir,
88-
instructions,
89-
shared_weights=False,
90-
internal_weights=False)
91-
tensor_product = TensorProduct(problem)
92-
batch_size = 100
93-
94-
X = jax.random.uniform(jax.random.PRNGKey(0), (batch_size, X_ir.dim), dtype=jax.numpy.float32)
95-
Y = jax.random.uniform(jax.random.PRNGKey(1), (batch_size, Y_ir.dim), dtype=jax.numpy.float32)
96-
W = jax.random.uniform(jax.random.PRNGKey(2), (batch_size, tensor_product.weight_numel), dtype=jax.numpy.float32)
97-
Z = tensor_product.forward(X, Y, W)
98-
99-
# Test forward jax vjp
100-
ctZ = jax.random.uniform(jax.random.PRNGKey(3), Z.shape, dtype=jax.numpy.float32)
101-
result = jax.vjp(lambda x, y, w: tensor_product.forward(x, y, w), X, Y, W)[1](ctZ)
102-
103-
print("COMPLETED FORWARD PASS!")
104-
105-
ddX = jax.random.uniform(jax.random.PRNGKey(4), X.shape, dtype=jax.numpy.float32)
106-
ddY = jax.random.uniform(jax.random.PRNGKey(5), Y.shape, dtype=jax.numpy.float32)
107-
ddW = jax.random.uniform(jax.random.PRNGKey(6), W.shape, dtype=jax.numpy.float32)
108-
109-
result_double_backward = jax.vjp(
110-
lambda x, y, w: jax.vjp(lambda a, b, c: tensor_product.forward(a, b, c), x, y, w)[1](ctZ),
111-
X, Y, W
112-
)[1]((ddX, ddY, ddW))
113-
114-
print("COMPLETED DOUBLE BACKWARD PASS!")
115-
116-
from e3nn import o3
117-
e3nn_tp = o3.TensorProduct(X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False)
118-
print(jax_to_torch(W).shape)
119-
120-
X_t = jax_to_torch(X)
121-
Y_t = jax_to_torch(Y)
122-
W_t = jax_to_torch(W)
123-
Z_t = jax_to_torch(Z)
124-
Z_e3nn = e3nn_tp(X_t, Y_t, W_t)
125-
print("E3NN RESULT:", (Z_e3nn - Z_t).norm())
126-
127-
Z_e3nn.backward(jax_to_torch(ctZ))
128-
#^^^ Print the norms of the differences in gradients instead
129-
print("E3NN GRADS NORM:", (jax_to_torch(result[0]) - X_t.grad).norm(),
130-
(jax_to_torch(result[1]) - Y_t.grad).norm(),
131-
(jax_to_torch(result[2]) - W_t.grad).norm())
76+
return self.forward(X, Y, W)

openequivariance/openequivariance/impl_jax/TensorProductConv.py

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ def backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
3030
jax.ShapeDtypeStruct(W.shape, irrep_dtype)))
3131
return backward_call(X, Y, W, dZ, rows, cols, workspace, sender_perm, **attrs)
3232

33-
def backward_with_inputs(X, Y, W, dZ, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs):
34-
return backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs), (X, Y, W, dZ, rows, cols, sender_perm, workspace)
33+
def backward_with_inputs(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs):
34+
return backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs), (X, Y, W, dZ) #rows, cols, sender_perm, workspace)
3535

3636
def double_backward(rows, cols, workspace, sender_perm, irrep_dtype, attrs, inputs, derivatives):
3737
double_backward_call = jax.ffi.ffi_call("conv_double_backward",
@@ -109,23 +109,3 @@ def __call__(self,
109109
sender_perm: Optional[jax.numpy.ndarray] = None) -> jax.numpy.ndarray:
110110
return self.forward(X, Y, W, rows, cols, sender_perm)
111111

112-
if __name__=="__main__":
113-
X_ir, Y_ir, Z_ir = Irreps("1x2e"), Irreps("1x3e"), Irreps("1x2e")
114-
instructions=[(0, 0, 0, "uvu", True)]
115-
problem = TPProblem(X_ir, Y_ir, Z_ir,
116-
instructions,
117-
shared_weights=False,
118-
internal_weights=False)
119-
120-
conv = TensorProductConv(problem, deterministic=False, kahan=False)
121-
122-
node_ct, nonzero_ct = 3, 4
123-
X = jax.random.uniform(jax.random.PRNGKey(0), (node_ct, X_ir.dim), dtype=jax.numpy.float32)
124-
Y = jax.random.uniform(jax.random.PRNGKey(1), (nonzero_ct, Y_ir.dim), dtype=jax.numpy.float32)
125-
W = jax.random.uniform(jax.random.PRNGKey(2), (nonzero_ct, conv.weight_numel), dtype=jax.numpy.float32)
126-
rows = jnp.array([0, 1, 1, 2], dtype=jnp.int32)
127-
cols = jnp.array([1, 0, 2, 1], dtype=jnp.int32)
128-
Z = conv.forward(X, Y, W, rows, cols)
129-
print("Z:", Z)
130-
131-
print("COMPLETE!")

0 commit comments

Comments (0)