@@ -77,6 +77,12 @@ def __init__(self, config):
    def forward(self, X, Y, W):
        return forward(X, Y, W, self.L3_dim, self.config.irrep_dtype, self.attrs)
 
+
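+# Convert a JAX array into a torch tensor with requires_grad=True so its autograd
+# gradients can be compared against the JAX vjp results below.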
+def jax_to_torch(x):
+    import numpy as np
+    import torch
+    return torch.tensor(np.asarray(x), requires_grad=True)
+
8086if __name__ == "__main__" :
8187 tp_problem = None
8288 X_ir , Y_ir , Z_ir = Irreps ("1x2e" ), Irreps ("1x3e" ), Irreps ("1x2e" )
@@ -86,28 +92,43 @@ def forward(self, X, Y, W):
        shared_weights=False,
        internal_weights=False)
    tensor_product = TensorProduct(problem)
-    batch_size = 1
+    batch_size = 100
 
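+    # Random batched inputs: X, Y, and one weight vector per batch element (shared_weights=False).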
    X = jax.random.uniform(jax.random.PRNGKey(0), (batch_size, X_ir.dim), dtype=jax.numpy.float32)
    Y = jax.random.uniform(jax.random.PRNGKey(1), (batch_size, Y_ir.dim), dtype=jax.numpy.float32)
    W = jax.random.uniform(jax.random.PRNGKey(2), (batch_size, tensor_product.weight_numel), dtype=jax.numpy.float32)
    Z = tensor_product.forward(X, Y, W)
 
    # Test forward jax vjp
-    ctZ = jnp.ones_like(Z)
+    ctZ = jax.random.uniform(jax.random.PRNGKey(3), Z.shape, dtype=jax.numpy.float32)
    result = jax.vjp(lambda x, y, w: tensor_product.forward(x, y, w), X, Y, W)[1](ctZ)
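+    # result is the tuple of cotangents (dX, dY, dW) returned by the first-order vjp evaluated at ctZ.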
 
-    print(result)
    print("COMPLETED FORWARD PASS!")
 
-    # Test the double backward pass
-    ddX = jnp.ones_like(X)
-    ddY = jnp.ones_like(Y)
-    ddW = jnp.ones_like(W)
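+    # Test the double backward pass with random cotangents in place of all-ones.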
+    ddX = jax.random.uniform(jax.random.PRNGKey(4), X.shape, dtype=jax.numpy.float32)
+    ddY = jax.random.uniform(jax.random.PRNGKey(5), Y.shape, dtype=jax.numpy.float32)
+    ddW = jax.random.uniform(jax.random.PRNGKey(6), W.shape, dtype=jax.numpy.float32)
+
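+    # The outer vjp differentiates the inner vjp, exercising the second-order (double backward) path.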
    result_double_backward = jax.vjp(
        lambda x, y, w: jax.vjp(lambda a, b, c: tensor_product.forward(a, b, c), x, y, w)[1](ctZ),
        X, Y, W
    )[1]((ddX, ddY, ddW))
 
-    print(result_double_backward)
-    print("COMPLETED DOUBLE BACKWARD PASS!")
+    print("COMPLETED DOUBLE BACKWARD PASS!")
+
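+    # Cross-check against e3nn: build an equivalent torch TensorProduct and compare outputs and gradients.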
+    from e3nn import o3
+    e3nn_tp = o3.TensorProduct(X_ir, Y_ir, Z_ir, instructions, shared_weights=False, internal_weights=False)
+    print(jax_to_torch(W).shape)
+
+    X_t = jax_to_torch(X)
+    Y_t = jax_to_torch(Y)
+    W_t = jax_to_torch(W)
+    Z_t = jax_to_torch(Z)
+    Z_e3nn = e3nn_tp(X_t, Y_t, W_t)
+    print("E3NN RESULT:", (Z_e3nn - Z_t).norm())
+
+    Z_e3nn.backward(jax_to_torch(ctZ))
+    # Compare gradients: print the norms of the differences between the JAX vjp results and the e3nn autograd gradients.
132+ print ("E3NN GRADS NORM:" , (jax_to_torch (result [0 ]) - X_t .grad ).norm (),
133+ (jax_to_torch (result [1 ]) - Y_t .grad ).norm (),
134+ (jax_to_torch (result [2 ]) - W_t .grad ).norm ())