+ import jax
+ import jax.numpy as jnp
import numpy as np
from functools import partial
from typing import Optional
from openequivariance.core.utils import hash_attributes
from openequivariance.jax.utils import reorder_jax

- import jax
- import jax.numpy as jnp
-
from openequivariance.benchmark.logging_utils import getLogger

logger = getLogger()


- @partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8, 9))
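+ # Helper used by triple_backward below to zero out unused tangent slots.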
+ def zeros_like(x):
+     return jnp.zeros_like(x)
+
+
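+ # Only workspace, sender_perm, L3_dim, irrep_dtype, and attrs are marked
+ # non-differentiable; rows and cols stay in differentiable positions and
+ # receive None cotangents in forward_bwd below.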
+ @partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9))
def forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs):
    forward_call = jax.ffi.ffi_call(
        "conv_forward", jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype)
    )
    return forward_call(X, Y, W, rows, cols, workspace, sender_perm, **attrs)


- def forward_with_inputs(
+ def forward_fwd(
    X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs
):
-     return forward(
+     out = forward(
        X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs
-     ), (X, Y, W, rows, cols, sender_perm, workspace)
+     )
+     return out, (X, Y, W, rows, cols)
+
+
+ def forward_bwd(workspace, sender_perm, L3_dim, irrep_dtype, attrs, res, dZ):
+     X, Y, W, rows, cols = res
+     dX, dY, dW = backward(
+         X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
+     )
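+     # rows and cols are integer index arrays, so no gradient flows to them.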
+     return dX, dY, dW, None, None


- @partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9))
+ forward.defvjp(forward_fwd, forward_bwd)
+
+
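+ # backward carries its own custom VJP so that nested jax.grad calls can reach
+ # double_backward (and, through it, triple_backward below).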
+ @partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9))
def backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs):
    backward_call = jax.ffi.ffi_call(
        "conv_backward",
@@ -45,65 +60,121 @@ def backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
    return backward_call(X, Y, W, dZ, rows, cols, workspace, sender_perm, **attrs)


- def backward_with_inputs(
-     X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
- ):
-     return backward(
-         X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
-     ), (X, Y, W, dZ)  # rows, cols, sender_perm, workspace)
+ def backward_fwd(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs):
+     out = backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs)
+     return out, (X, Y, W, dZ, rows, cols)
+
+
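+ # VJP of backward: the incoming cotangents (ddX, ddY, ddW) are contracted
+ # against the saved residuals in a single fused double-backward kernel.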
+ def backward_bwd(workspace, sender_perm, irrep_dtype, attrs, res, derivatives):
+     X, Y, W, dZ, rows, cols = res
+     ddX, ddY, ddW = derivatives
+
+     gX, gY, gW, gdZ = double_backward(
+         X,
+         Y,
+         W,
+         dZ,
+         ddX,
+         ddY,
+         ddW,
+         rows,
+         cols,
+         workspace,
+         sender_perm,
+         irrep_dtype,
+         attrs,
+     )
+
+     return gX, gY, gW, gdZ, None, None
+

+ backward.defvjp(backward_fwd, backward_bwd)

+
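+ # double_backward takes the first-order inputs (X, Y, W, dZ) plus the
+ # cotangent seeds (ddX, ddY, ddW) and returns gradients w.r.t. X, Y, W, dZ.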
+ @partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12))
def double_backward(
-     rows, cols, workspace, sender_perm, irrep_dtype, attrs, inputs, derivatives
+     X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, irrep_dtype, attrs
):
    double_backward_call = jax.ffi.ffi_call(
        "conv_double_backward",
        (
-             jax.ShapeDtypeStruct(inputs[0].shape, irrep_dtype),
-             jax.ShapeDtypeStruct(inputs[1].shape, irrep_dtype),
-             jax.ShapeDtypeStruct(inputs[2].shape, irrep_dtype),
-             jax.ShapeDtypeStruct(inputs[3].shape, irrep_dtype),
+             jax.ShapeDtypeStruct(X.shape, irrep_dtype),
+             jax.ShapeDtypeStruct(Y.shape, irrep_dtype),
+             jax.ShapeDtypeStruct(W.shape, irrep_dtype),
+             jax.ShapeDtypeStruct(dZ.shape, irrep_dtype),
        ),
    )
    return double_backward_call(
-         *inputs, *derivatives, rows, cols, workspace, sender_perm, **attrs
+         X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, **attrs
    )


- def backward_autograd(
-     rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs, inputs, dZ
+ def double_backward_fwd(
+     X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, irrep_dtype, attrs
):
-     return backward(
-         inputs[0],
-         inputs[1],
-         inputs[2],
+     out = double_backward(
+         X,
+         Y,
+         W,
        dZ,
+         ddX,
+         ddY,
+         ddW,
        rows,
        cols,
        workspace,
        sender_perm,
        irrep_dtype,
        attrs,
    )
+     return out, (X, Y, W, dZ, ddX, ddY, ddW, rows, cols)


- forward.defvjp(forward_with_inputs, backward_autograd)
- backward.defvjp(backward_with_inputs, double_backward)
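+ # VJP of double_backward. There is no dedicated third-order kernel; instead,
+ # the incoming tangents are split across seven calls to double_backward and
+ # backward, each with the unused slots zeroed out, and the partial gradients
+ # are summed. This relies on the kernel being linear in each tensor operand.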
+ def triple_backward(
+     workspace,
+     sender_perm,
+     irrep_dtype,
+     attrs,
+     residuals,
+     tangent_outputs,
+ ):
+     X, Y, W, dZ, ddX, ddY, ddW, rows, cols = residuals
+     t_dX, t_dY, t_dW, t_ddZ = tangent_outputs

+     common_args = (rows, cols, workspace, sender_perm, irrep_dtype, attrs)

- class TensorProductConv(LoopUnrollConv):
-     r"""
-     Identical to ``oeq.torch.TensorProductConv``, implemented in JAX, with one
-     key difference: integer arrays passed to this function must have dtype
-     ``np.int32`` (as opposed to ``np.int64`` in the PyTorch version).
-
-     :param problem: Specification of the tensor product.
-     :param deterministic: If ``False``, uses atomics for the convolution. If ``True``, uses a deterministic
-         fixup-based algorithm. *Default*: ``False``.
-     :param kahan: If ``True``, uses Kahan summation to improve accuracy during aggregation. To use this option,
-         the input tensors must be in float32 precision AND you must set ``deterministic=True``. *Default*: ``False``.
-     """
+     op1_inputs = (ddX, ddY, W, dZ, t_dX, t_dY, zeros_like(W))
+     g1_ddX, g1_ddY, g1_W, g1_dZ = double_backward(*op1_inputs, *common_args)

+     op2_inputs = (X, Y, ddW, dZ, t_dX, t_dY, zeros_like(ddW))
+     g2_X, g2_Y, g2_ddW, g2_dZ = double_backward(*op2_inputs, *common_args)
+
+     op3_inputs = (ddX, Y, W, dZ, zeros_like(ddX), zeros_like(Y), t_dW)
+     g3_ddX, g3_Y, g3_W, g3_dZ = double_backward(*op3_inputs, *common_args)
+
+     op4_inputs = (X, ddY, W, dZ, zeros_like(X), zeros_like(ddY), t_dW)
+     g4_X, g4_ddY, g4_W, g4_dZ = double_backward(*op4_inputs, *common_args)
+
+     g5_ddX, g5_Y, g5_W = backward(ddX, Y, W, t_ddZ, *common_args)
+     g6_X, g6_ddY, g6_W = backward(X, ddY, W, t_ddZ, *common_args)
+     g7_X, g7_Y, g7_ddW = backward(X, Y, ddW, t_ddZ, *common_args)
+
+     grad_X = g2_X + g4_X + g6_X + g7_X
+     grad_Y = g2_Y + g3_Y + g5_Y + g7_Y
+     grad_W = g1_W + g3_W + g4_W + g5_W + g6_W
+     grad_dZ = g1_dZ + g2_dZ + g3_dZ + g4_dZ
+
+     grad_ddX = g1_ddX + g3_ddX + g5_ddX
+     grad_ddY = g1_ddY + g4_ddY + g6_ddY
+     grad_ddW = g2_ddW + g7_ddW
+
+     return grad_X, grad_Y, grad_W, grad_dZ, grad_ddX, grad_ddY, grad_ddW, None, None
+
+
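+ # With this VJP registered, the convolution supports third-order
+ # differentiation end to end.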
+ double_backward.defvjp(double_backward_fwd, triple_backward)
+
+
+ class TensorProductConv(LoopUnrollConv):
    def __init__(
        self, config: TPProblem, deterministic: bool = False, kahan: bool = False
    ):
@@ -112,7 +183,7 @@ def __init__(
            config,
            dp,
            extlib.postprocess_kernel,
-             idx_dtype=np.int32,  # N.B. this is distinct from the PyTorch version
+             idx_dtype=np.int32,
            torch_op=False,
            deterministic=deterministic,
            kahan=kahan,
@@ -145,26 +216,6 @@ def forward(
        cols: jax.numpy.ndarray,
        sender_perm: Optional[jax.numpy.ndarray] = None,
    ) -> jax.numpy.ndarray:
-         r"""
-         Computes the fused CG tensor product + convolution.
-
-         :param X: Tensor of shape ``[|V|, problem.irreps_in1.dim()]``, datatype ``problem.irrep_dtype``.
-         :param Y: Tensor of shape ``[|E|, problem.irreps_in2.dim()]``, datatype ``problem.irrep_dtype``.
-         :param W: Tensor of datatype ``problem.weight_dtype`` and shape
-
-             * ``[|E|, problem.weight_numel]`` if ``problem.shared_weights=False``
-             * ``[problem.weight_numel]`` if ``problem.shared_weights=True``
-
-         :param rows: Tensor of shape ``[|E|]`` with row indices for each nonzero in the adjacency matrix,
-             datatype ``np.int32``. Must be row-major sorted along with ``cols`` when ``deterministic=True``.
-         :param cols: Tensor of shape ``[|E|]`` with column indices for each nonzero in the adjacency matrix,
-             datatype ``np.int32``.
-         :param sender_perm: Tensor of shape ``[|E|]`` and ``np.int32`` datatype containing a
-             permutation that transposes the adjacency matrix nonzeros from row-major to column-major order.
-             Must be provided when ``deterministic=True``.
-
-         :return: Tensor of shape ``[|V|, problem.irreps_out.dim()]``, datatype ``problem.irrep_dtype``.
-         """
        if not self.deterministic:
            sender_perm = self.dummy_transpose_perm
        else: