Skip to content

Commit 7b1ce90

Browse files
committed
Did some extra testing.
1 parent 673b5ee commit 7b1ce90

1 file changed

Lines changed: 30 additions & 9 deletions

File tree

openequivariance/openequivariance/impl_jax/TensorProduct.py

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,12 @@ def __init__(self, config):
7777
def forward(self, X, Y, W):
7878
return forward(X, Y, W, self.L3_dim, self.config.irrep_dtype, self.attrs)
7979

80+
81+
def jax_to_torch(x):
82+
import numpy as np
83+
import torch
84+
return torch.tensor(np.asarray(x), requires_grad=True)
85+
8086
if __name__ == "__main__":
8187
tp_problem = None
8288
X_ir, Y_ir, Z_ir = Irreps("1x2e"), Irreps("1x3e"), Irreps("1x2e")
@@ -86,28 +92,43 @@ def forward(self, X, Y, W):
8692
shared_weights=False,
8793
internal_weights=False)
8894
tensor_product = TensorProduct(problem)
89-
batch_size = 1
95+
batch_size = 100
9096

9197
X = jax.random.uniform(jax.random.PRNGKey(0), (batch_size, X_ir.dim), dtype=jax.numpy.float32)
9298
Y = jax.random.uniform(jax.random.PRNGKey(1), (batch_size, Y_ir.dim), dtype=jax.numpy.float32)
9399
W = jax.random.uniform(jax.random.PRNGKey(2), (batch_size, tensor_product.weight_numel), dtype=jax.numpy.float32)
94100
Z = tensor_product.forward(X, Y, W)
95101

96102
# Test forward jax vjp
97-
ctZ = jnp.ones_like(Z)
103+
ctZ = jax.random.uniform(jax.random.PRNGKey(3), Z.shape, dtype=jax.numpy.float32)
98104
result = jax.vjp(lambda x, y, w: tensor_product.forward(x, y, w), X, Y, W)[1](ctZ)
99105

100-
print(result)
101106
print("COMPLETED FORWARD PASS!")
102107

103-
# Test the double backward pass
104-
ddX = jnp.ones_like(X)
105-
ddY = jnp.ones_like(Y)
106-
ddW = jnp.ones_like(W)
108+
ddX = jax.random.uniform(jax.random.PRNGKey(4), X.shape, dtype=jax.numpy.float32)
109+
ddY = jax.random.uniform(jax.random.PRNGKey(5), Y.shape, dtype=jax.numpy.float32)
110+
ddW = jax.random.uniform(jax.random.PRNGKey(6), W.shape, dtype=jax.numpy.float32)
111+
107112
result_double_backward = jax.vjp(
108113
lambda x, y, w: jax.vjp(lambda a, b, c: tensor_product.forward(a, b, c), x, y, w)[1](ctZ),
109114
X, Y, W
110115
)[1]((ddX, ddY, ddW))
111116

112-
print(result_double_backward)
113-
print("COMPLETED DOUBLE BACKWARD PASS!")
117+
print("COMPLETED DOUBLE BACKWARD PASS!")
118+
119+
from e3nn import o3
120+
e3nn_tp = o3.TensorProduct(X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False)
121+
print(jax_to_torch(W).shape)
122+
123+
X_t = jax_to_torch(X)
124+
Y_t = jax_to_torch(Y)
125+
W_t = jax_to_torch(W)
126+
Z_t = jax_to_torch(Z)
127+
Z_e3nn = e3nn_tp(X_t, Y_t, W_t)
128+
print("E3NN RESULT:", (Z_e3nn - Z_t).norm())
129+
130+
Z_e3nn.backward(jax_to_torch(ctZ))
131+
#^^^ Print the norms of the differences in gradients instead
132+
print("E3NN GRADS NORM:", (jax_to_torch(result[0]) - X_t.grad).norm(),
133+
(jax_to_torch(result[1]) - Y_t.grad).norm(),
134+
(jax_to_torch(result[2]) - W_t.grad).norm())

0 commit comments

Comments
 (0)