
Commit 3c2dd77

Arbitrary higher derivative support for JAX and jax.jit fix. (#178)

* Triple backward written.
* Triple backward seems to work.
* Fixed things up.
* Fixed issues.
* Precommit.

1 parent 4fadfc3 · commit 3c2dd77

4 files changed: 186 additions & 74 deletions
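A note on why a fixed set of GPU kernels can support arbitrary derivative order: the CG tensor product is multilinear, so every reverse-mode pass reduces to the same kernels contracted against different operands. A sketch in assumed notation (the symbol κ is not from the codebase):

```latex
Z = \kappa(X, Y, W), \qquad \text{with $\kappa$ linear in each of $X$, $Y$, $W$ separately.}
% Reverse mode contracts the same trilinear kernel against different operand pairs:
\frac{\partial L}{\partial X} = \kappa_X\!\Big(Y, W, \frac{\partial L}{\partial Z}\Big), \quad
\frac{\partial L}{\partial Y} = \kappa_Y\!\Big(X, W, \frac{\partial L}{\partial Z}\Big), \quad
\frac{\partial L}{\partial W} = \kappa_W\!\Big(X, Y, \frac{\partial L}{\partial Z}\Big).
```

Each right-hand side is again multilinear in its operands, so differentiating once more produces only sums of the same kernel shapes with some operands zeroed or swapped. That is the structure the new `triple_backward` functions below exploit.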


README.md

Lines changed: 4 additions & 0 deletions
````diff
@@ -183,6 +183,10 @@ Z = tp_conv.forward(
     X, Y, W, edge_index[0], edge_index[1]
 )
 print(jax.numpy.linalg.norm(Z))
+
+# Test JAX JIT
+jitted = jax.jit(lambda X, Y, W, e1, e2: tp_conv.forward(X, Y, W, e1, e2))
+print(jax.numpy.linalg.norm(jitted(X, Y, W, edge_index[0], edge_index[1])))
 ```

 ## Citation and Acknowledgements
````
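With the VJP chain now extended a level deeper (see `TensorProduct.py` below), reverse-mode derivatives of this snippet should nest. A minimal sketch, reusing the README's `tp_conv`, `X`, `Y`, `W`, and `edge_index`; the scalar loss is chosen here purely for illustration:

```python
import jax
import jax.numpy as jnp

# Scalar loss through the fused tensor product + convolution.
def loss(X):
    Z = tp_conv.forward(X, Y, W, edge_index[0], edge_index[1])
    return jnp.sum(Z * Z)

g = jax.grad(loss)(X)  # first order: dispatches to the backward kernel

# Hessian-vector product: differentiating the backward pass dispatches to
# double_backward; one more jax.grad on top would reach triple_backward.
hvp = jax.grad(lambda x: jnp.vdot(jax.grad(loss)(x), g))(X)
```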

openequivariance/openequivariance/jax/TensorProduct.py

Lines changed: 67 additions & 13 deletions
```diff
@@ -1,4 +1,5 @@
 import jax
+import jax.numpy as jnp
 import numpy as np
 from functools import partial
 from openequivariance.jax import extlib
@@ -16,10 +17,18 @@ def forward(X, Y, W, L3_dim, irrep_dtype, attrs):
     return forward_call(X, Y, W, **attrs)
 
 
-def forward_with_inputs(X, Y, W, L3_dim, irrep_dtype, attrs):
+def forward_fwd(X, Y, W, L3_dim, irrep_dtype, attrs):
     return forward(X, Y, W, L3_dim, irrep_dtype, attrs), (X, Y, W)
 
 
+def forward_bwd(L3_dim, irrep_dtype, attrs, inputs, dZ):
+    X, Y, W = inputs
+    return backward(X, Y, W, dZ, irrep_dtype, attrs)
+
+
+forward.defvjp(forward_fwd, forward_bwd)
+
+
 @partial(jax.custom_vjp, nondiff_argnums=(4, 5))
 def backward(X, Y, W, dZ, irrep_dtype, attrs):
     backward_call = jax.ffi.ffi_call(
@@ -30,33 +39,78 @@ def backward(X, Y, W, dZ, irrep_dtype, attrs):
             jax.ShapeDtypeStruct(W.shape, irrep_dtype),
         ),
     )
-
     return backward_call(X, Y, W, dZ, **attrs)
 
 
-def backward_with_inputs(X, Y, W, dZ, irrep_dtype, attrs):
+def backward_fwd(X, Y, W, dZ, irrep_dtype, attrs):
     return backward(X, Y, W, dZ, irrep_dtype, attrs), (X, Y, W, dZ)
 
 
-def double_backward(irrep_dtype, attrs, inputs, derivatives):
+def backward_bwd(irrep_dtype, attrs, inputs, derivs):
+    X, Y, W, dZ = inputs
+    ddX, ddY, ddW = derivs
+    return double_backward(X, Y, W, dZ, ddX, ddY, ddW, irrep_dtype, attrs)
+
+
+backward.defvjp(backward_fwd, backward_bwd)
+
+
+@partial(jax.custom_vjp, nondiff_argnums=(7, 8))
+def double_backward(X, Y, W, dZ, ddX, ddY, ddW, irrep_dtype, attrs):
     double_backward_call = jax.ffi.ffi_call(
         "tp_double_backward",
         (
-            jax.ShapeDtypeStruct(inputs[0].shape, irrep_dtype),
-            jax.ShapeDtypeStruct(inputs[1].shape, irrep_dtype),
-            jax.ShapeDtypeStruct(inputs[2].shape, irrep_dtype),
-            jax.ShapeDtypeStruct(inputs[3].shape, irrep_dtype),
+            jax.ShapeDtypeStruct(X.shape, irrep_dtype),
+            jax.ShapeDtypeStruct(Y.shape, irrep_dtype),
+            jax.ShapeDtypeStruct(W.shape, irrep_dtype),
+            jax.ShapeDtypeStruct(dZ.shape, irrep_dtype),
         ),
     )
-    return double_backward_call(*inputs, *derivatives, **attrs)
+    return double_backward_call(X, Y, W, dZ, ddX, ddY, ddW, **attrs)
+
+
+def double_backward_fwd(X, Y, W, dZ, ddX, ddY, ddW, irrep_dtype, attrs):
+    out = double_backward(X, Y, W, dZ, ddX, ddY, ddW, irrep_dtype, attrs)
+    return out, (X, Y, W, dZ, ddX, ddY, ddW)
+
+
+def zeros_like(x):
+    return jnp.zeros_like(x)
+
+
+def triple_backward(irrep_dtype, attrs, residuals, tangent_outputs):
+    X, Y, W, dZ, ddX, ddY, ddW = residuals
+    t_dX, t_dY, t_dW, t_ddZ = tangent_outputs
+
+    op1_inputs = (ddX, ddY, W, dZ, t_dX, t_dY, zeros_like(W))
+    g1_ddX, g1_ddY, g1_W, g1_dZ = double_backward(*op1_inputs, irrep_dtype, attrs)
+
+    op2_inputs = (X, Y, ddW, dZ, t_dX, t_dY, zeros_like(ddW))
+    g2_X, g2_Y, g2_ddW, g2_dZ = double_backward(*op2_inputs, irrep_dtype, attrs)
+
+    op3_inputs = (ddX, Y, W, dZ, zeros_like(ddX), zeros_like(Y), t_dW)
+    g3_ddX, g3_Y, g3_W, g3_dZ = double_backward(*op3_inputs, irrep_dtype, attrs)
+
+    op4_inputs = (X, ddY, W, dZ, zeros_like(X), zeros_like(ddY), t_dW)
+    g4_X, g4_ddY, g4_W, g4_dZ = double_backward(*op4_inputs, irrep_dtype, attrs)
+
+    g5_ddX, g5_Y, g5_W = backward(ddX, Y, W, t_ddZ, irrep_dtype, attrs)
+    g6_X, g6_ddY, g6_W = backward(X, ddY, W, t_ddZ, irrep_dtype, attrs)
+    g7_X, g7_Y, g7_ddW = backward(X, Y, ddW, t_ddZ, irrep_dtype, attrs)
+
+    grad_X = g2_X + g4_X + g6_X + g7_X
+    grad_Y = g2_Y + g3_Y + g5_Y + g7_Y
+    grad_W = g1_W + g3_W + g4_W + g5_W + g6_W
+    grad_dZ = g1_dZ + g2_dZ + g3_dZ + g4_dZ
 
+    grad_ddX = g1_ddX + g3_ddX + g5_ddX
+    grad_ddY = g1_ddY + g4_ddY + g6_ddY
+    grad_ddW = g2_ddW + g7_ddW
 
-def backward_autograd(L3_dim, irrep_dtype, attrs, inputs, dZ):
-    return backward(inputs[0], inputs[1], inputs[2], dZ, irrep_dtype, attrs)
+    return grad_X, grad_Y, grad_W, grad_dZ, grad_ddX, grad_ddY, grad_ddW
 
 
-forward.defvjp(forward_with_inputs, backward_autograd)
-backward.defvjp(backward_with_inputs, double_backward)
+double_backward.defvjp(double_backward_fwd, triple_backward)
 
 
 class TensorProduct(LoopUnrollTP):
```
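The file now forms a chain of `jax.custom_vjp` ops: the VJP rule of `forward` calls `backward`, the rule of `backward` calls `double_backward`, and the rule of `double_backward` (`triple_backward`) is built from further `double_backward` and `backward` calls, so reverse mode closes over itself. A self-contained toy with the same shape, where polynomial stand-ins replace the FFI kernels (illustrative only, not OpenEquivariance API):

```python
import jax

@jax.custom_vjp
def f(x):
    return x * x * x  # stand-in for the "tp_forward" FFI kernel

def f_fwd(x):
    return f(x), (x,)

def f_bwd(res, dz):
    (x,) = res
    return (f_grad(x, dz),)  # route the VJP through another custom_vjp op

@jax.custom_vjp
def f_grad(x, dz):
    return 3.0 * x * x * dz  # stand-in for the "tp_backward" FFI kernel

def f_grad_fwd(x, dz):
    return f_grad(x, dz), (x, dz)

def f_grad_bwd(res, g):
    x, dz = res
    return 6.0 * x * dz * g, 3.0 * x * x * g  # cotangents for (x, dz)

f.defvjp(f_fwd, f_bwd)
f_grad.defvjp(f_grad_fwd, f_grad_bwd)

print(jax.grad(jax.grad(jax.grad(f)))(2.0))  # 6.0: third derivative of x**3
```

In the toy, everything past the last hand-written rule is plain arithmetic that JAX differentiates natively; the real `triple_backward` instead re-expresses its VJP in terms of `double_backward` and `backward`, so every derivative order reuses the same FFI kernels.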

openequivariance/openequivariance/jax/TensorProductConv.py

Lines changed: 112 additions & 61 deletions
```diff
@@ -1,3 +1,5 @@
+import jax
+import jax.numpy as jnp
 import numpy as np
 from functools import partial
 from typing import Optional
@@ -8,31 +10,44 @@
 from openequivariance.core.utils import hash_attributes
 from openequivariance.jax.utils import reorder_jax
 
-import jax
-import jax.numpy as jnp
-
 from openequivariance.benchmark.logging_utils import getLogger
 
 logger = getLogger()
 
 
-@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8, 9))
+def zeros_like(x):
+    return jnp.zeros_like(x)
+
+
+@partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9))
 def forward(X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs):
     forward_call = jax.ffi.ffi_call(
         "conv_forward", jax.ShapeDtypeStruct((X.shape[0], L3_dim), irrep_dtype)
     )
     return forward_call(X, Y, W, rows, cols, workspace, sender_perm, **attrs)
 
 
-def forward_with_inputs(
+def forward_fwd(
     X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs
 ):
-    return forward(
+    out = forward(
         X, Y, W, rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs
-    ), (X, Y, W, rows, cols, sender_perm, workspace)
+    )
+    return out, (X, Y, W, rows, cols)
+
+
+def forward_bwd(workspace, sender_perm, L3_dim, irrep_dtype, attrs, res, dZ):
+    X, Y, W, rows, cols = res
+    dX, dY, dW = backward(
+        X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
+    )
+    return dX, dY, dW, None, None
 
 
-@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9))
+forward.defvjp(forward_fwd, forward_bwd)
+
+
+@partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9))
 def backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs):
     backward_call = jax.ffi.ffi_call(
         "conv_backward",
@@ -45,65 +60,121 @@ def backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
     return backward_call(X, Y, W, dZ, rows, cols, workspace, sender_perm, **attrs)
 
 
-def backward_with_inputs(
-    X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
-):
-    return backward(
-        X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs
-    ), (X, Y, W, dZ)  # rows, cols, sender_perm, workspace)
+def backward_fwd(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs):
+    out = backward(X, Y, W, dZ, rows, cols, workspace, sender_perm, irrep_dtype, attrs)
+    return out, (X, Y, W, dZ, rows, cols)
+
+
+def backward_bwd(workspace, sender_perm, irrep_dtype, attrs, res, derivatives):
+    X, Y, W, dZ, rows, cols = res
+    ddX, ddY, ddW = derivatives
+
+    gX, gY, gW, gdZ = double_backward(
+        X,
+        Y,
+        W,
+        dZ,
+        ddX,
+        ddY,
+        ddW,
+        rows,
+        cols,
+        workspace,
+        sender_perm,
+        irrep_dtype,
+        attrs,
+    )
+
+    return gX, gY, gW, gdZ, None, None
+
 
+backward.defvjp(backward_fwd, backward_bwd)
 
+
+@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12))
 def double_backward(
-    rows, cols, workspace, sender_perm, irrep_dtype, attrs, inputs, derivatives
+    X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, irrep_dtype, attrs
 ):
     double_backward_call = jax.ffi.ffi_call(
         "conv_double_backward",
         (
-            jax.ShapeDtypeStruct(inputs[0].shape, irrep_dtype),
-            jax.ShapeDtypeStruct(inputs[1].shape, irrep_dtype),
-            jax.ShapeDtypeStruct(inputs[2].shape, irrep_dtype),
-            jax.ShapeDtypeStruct(inputs[3].shape, irrep_dtype),
+            jax.ShapeDtypeStruct(X.shape, irrep_dtype),
+            jax.ShapeDtypeStruct(Y.shape, irrep_dtype),
+            jax.ShapeDtypeStruct(W.shape, irrep_dtype),
+            jax.ShapeDtypeStruct(dZ.shape, irrep_dtype),
        ),
    )
    return double_backward_call(
-        *inputs, *derivatives, rows, cols, workspace, sender_perm, **attrs
+        X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, **attrs
    )
 
 
-def backward_autograd(
-    rows, cols, workspace, sender_perm, L3_dim, irrep_dtype, attrs, inputs, dZ
+def double_backward_fwd(
+    X, Y, W, dZ, ddX, ddY, ddW, rows, cols, workspace, sender_perm, irrep_dtype, attrs
 ):
-    return backward(
-        inputs[0],
-        inputs[1],
-        inputs[2],
+    out = double_backward(
+        X,
+        Y,
+        W,
         dZ,
+        ddX,
+        ddY,
+        ddW,
         rows,
         cols,
         workspace,
         sender_perm,
         irrep_dtype,
         attrs,
     )
+    return out, (X, Y, W, dZ, ddX, ddY, ddW, rows, cols)
 
 
-forward.defvjp(forward_with_inputs, backward_autograd)
-backward.defvjp(backward_with_inputs, double_backward)
+def triple_backward(
+    workspace,
+    sender_perm,
+    irrep_dtype,
+    attrs,
+    residuals,
+    tangent_outputs,
+):
+    X, Y, W, dZ, ddX, ddY, ddW, rows, cols = residuals
+    t_dX, t_dY, t_dW, t_ddZ = tangent_outputs
 
+    common_args = (rows, cols, workspace, sender_perm, irrep_dtype, attrs)
 
-class TensorProductConv(LoopUnrollConv):
-    r"""
-    Identical to ``oeq.torch.TensorProductConv`` with functionality in JAX, with one
-    key difference: integer arrays passed to this function must have dtype
-    ``np.int32`` (as opposed to ``np.int64`` in the PyTorch version).
-
-    :param problem: Specification of the tensor product.
-    :param deterministic: if ``False``, uses atomics for the convolution. If ``True``, uses a deterministic
-        fixup-based algorithm. `Default`: ``False``.
-    :param kahan: If ``True``, uses Kahan summation to improve accuracy during aggregation. To use this option,
-        the input tensors must be in float32 precision AND you must set ``deterministic=True``. *Default*: ``False``.
-    """
+    op1_inputs = (ddX, ddY, W, dZ, t_dX, t_dY, zeros_like(W))
+    g1_ddX, g1_ddY, g1_W, g1_dZ = double_backward(*op1_inputs, *common_args)
 
+    op2_inputs = (X, Y, ddW, dZ, t_dX, t_dY, zeros_like(ddW))
+    g2_X, g2_Y, g2_ddW, g2_dZ = double_backward(*op2_inputs, *common_args)
+
+    op3_inputs = (ddX, Y, W, dZ, zeros_like(ddX), zeros_like(Y), t_dW)
+    g3_ddX, g3_Y, g3_W, g3_dZ = double_backward(*op3_inputs, *common_args)
+
+    op4_inputs = (X, ddY, W, dZ, zeros_like(X), zeros_like(ddY), t_dW)
+    g4_X, g4_ddY, g4_W, g4_dZ = double_backward(*op4_inputs, *common_args)
+
+    g5_ddX, g5_Y, g5_W = backward(ddX, Y, W, t_ddZ, *common_args)
+    g6_X, g6_ddY, g6_W = backward(X, ddY, W, t_ddZ, *common_args)
+    g7_X, g7_Y, g7_ddW = backward(X, Y, ddW, t_ddZ, *common_args)
+
+    grad_X = g2_X + g4_X + g6_X + g7_X
+    grad_Y = g2_Y + g3_Y + g5_Y + g7_Y
+    grad_W = g1_W + g3_W + g4_W + g5_W + g6_W
+    grad_dZ = g1_dZ + g2_dZ + g3_dZ + g4_dZ
+
+    grad_ddX = g1_ddX + g3_ddX + g5_ddX
+    grad_ddY = g1_ddY + g4_ddY + g6_ddY
+    grad_ddW = g2_ddW + g7_ddW
+
+    return grad_X, grad_Y, grad_W, grad_dZ, grad_ddX, grad_ddY, grad_ddW, None, None
+
+
+double_backward.defvjp(double_backward_fwd, triple_backward)
+
+
+class TensorProductConv(LoopUnrollConv):
     def __init__(
         self, config: TPProblem, deterministic: bool = False, kahan: bool = False
     ):
@@ -112,7 +183,7 @@ def __init__(
             config,
             dp,
             extlib.postprocess_kernel,
-            idx_dtype=np.int32,  # N.B. this is distinct from the PyTorch version
+            idx_dtype=np.int32,
             torch_op=False,
             deterministic=deterministic,
             kahan=kahan,
@@ -145,26 +216,6 @@ def forward(
         cols: jax.numpy.ndarray,
         sender_perm: Optional[jax.numpy.ndarray] = None,
     ) -> jax.numpy.ndarray:
-        r"""
-        Computes the fused CG tensor product + convolution.
-
-        :param X: Tensor of shape ``[|V|, problem.irreps_in1.dim()]``, datatype ``problem.irrep_dtype``.
-        :param Y: Tensor of shape ``[|E|, problem.irreps_in1.dim()]``, datatype ``problem.irrep_dtype``.
-        :param W: Tensor of datatype ``problem.weight_dtype`` and shape
-
-            * ``[|E|, problem.weight_numel]`` if ``problem.shared_weights=False``
-            * ``[problem.weight_numel]`` if ``problem.shared_weights=True``
-
-        :param rows: Tensor of shape ``[|E|]`` with row indices for each nonzero in the adjacency matrix,
-            datatype ``np.int32``. Must be row-major sorted along with ``cols`` when ``deterministic=True``.
-        :param cols: Tensor of shape ``[|E|]`` with column indices for each nonzero in the adjacency matrix,
-            datatype ``np.int32``.
-        :param sender_perm: Tensor of shape ``[|E|]`` and ``np.int32`` datatype containing a
-            permutation that transposes the adjacency matrix nonzeros from row-major to column-major order.
-            Must be provided when ``deterministic=True``.
-
-        :return: Tensor of shape ``[|V|, problem.irreps_out.dim()]``, datatype ``problem.irrep_dtype``.
-        """
         if not self.deterministic:
             sender_perm = self.dummy_transpose_perm
         else:
```
tests/example_test.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -161,3 +161,6 @@ def test_tutorial_jax(with_jax):
     tp_conv = oeq.jax.TensorProductConv(problem, deterministic=False)
     Z = tp_conv.forward(X, Y, W, edge_index[0], edge_index[1])
     print(jax.numpy.linalg.norm(Z))
+
+    jitted = jax.jit(lambda X, Y, W, e1, e2: tp_conv.forward(X, Y, W, e1, e2))
+    print(jax.numpy.linalg.norm(jitted(X, Y, W, edge_index[0], edge_index[1])))
```
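A natural extension of this test (a sketch, not part of the commit) is to verify the new derivative orders numerically with JAX's built-in checker:

```python
from jax.test_util import check_grads

# order=3 exercises backward, double_backward, and triple_backward in turn;
# modes=("rev",) because only reverse-mode (VJP) rules are defined.
f = lambda X, Y, W: tp_conv.forward(X, Y, W, edge_index[0], edge_index[1])
check_grads(f, (X, Y, W), order=3, modes=("rev",))
```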
