@@ -30,8 +30,8 @@ def backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
3030 jax .ShapeDtypeStruct (W .shape , irrep_dtype )))
3131 return backward_call (X , Y , W , dZ , rows , cols , workspace , sender_perm , ** attrs )
3232
33- def backward_with_inputs (X , Y , W , dZ , rows , cols , workspace , sender_perm , L3_dim , irrep_dtype , attrs ):
34- return backward (X , Y , W , dZ , rows , cols , workspace , sender_perm , L3_dim , irrep_dtype , attrs ), (X , Y , W , dZ , rows , cols , sender_perm , workspace )
33+ def backward_with_inputs (X , Y , W , dZ , rows , cols , workspace , sender_perm , irrep_dtype , attrs ):
34+ return backward (X , Y , W , dZ , rows , cols , workspace , sender_perm , irrep_dtype , attrs ), (X , Y , W , dZ ) # rows, cols, sender_perm, workspace)
3535
3636def double_backward (rows , cols , workspace , sender_perm , irrep_dtype , attrs , inputs , derivatives ):
3737 double_backward_call = jax .ffi .ffi_call ("conv_double_backward" ,
@@ -109,23 +109,3 @@ def __call__(self,
109109 sender_perm : Optional [jax .numpy .ndarray ] = None ) -> jax .numpy .ndarray :
110110 return self .forward (X , Y , W , rows , cols , sender_perm )
111111
112- if __name__ == "__main__" :
113- X_ir , Y_ir , Z_ir = Irreps ("1x2e" ), Irreps ("1x3e" ), Irreps ("1x2e" )
114- instructions = [(0 , 0 , 0 , "uvu" , True )]
115- problem = TPProblem (X_ir , Y_ir , Z_ir ,
116- instructions ,
117- shared_weights = False ,
118- internal_weights = False )
119-
120- conv = TensorProductConv (problem , deterministic = False , kahan = False )
121-
122- node_ct , nonzero_ct = 3 , 4
123- X = jax .random .uniform (jax .random .PRNGKey (0 ), (node_ct , X_ir .dim ), dtype = jax .numpy .float32 )
124- Y = jax .random .uniform (jax .random .PRNGKey (1 ), (nonzero_ct , Y_ir .dim ), dtype = jax .numpy .float32 )
125- W = jax .random .uniform (jax .random .PRNGKey (2 ), (nonzero_ct , conv .weight_numel ), dtype = jax .numpy .float32 )
126- rows = jnp .array ([0 , 1 , 1 , 2 ], dtype = jnp .int32 )
127- cols = jnp .array ([1 , 0 , 2 , 1 ], dtype = jnp .int32 )
128- Z = conv .forward (X , Y , W , rows , cols )
129- print ("Z:" , Z )
130-
131- print ("COMPLETE!" )
0 commit comments