Skip to content

Commit 54f2ee8

Browse files
committed
Merge branch 'main' into ir_mul
2 parents 1eac81f + 519d003 commit 54f2ee8

5 files changed

Lines changed: 46 additions & 3 deletions

File tree

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
## Latest Changes
22

3+
### v0.6.4 (2026-03-05)
4+
Bugfix: added missing MLIR lowerings for
5+
a pair of JAX primitives (thanks @teddykoker!)
6+
37
### v0.6.3 (2025-02-23)
48
OpenEquivariance v0.6.3 brings long-needed improvements to the
59
PyTorch frontend. We strongly encourage all users to upgrade

openequivariance/openequivariance/jax/jvp/conv_prim.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,16 @@ def conv_fwd_jvp_abstract_eval(
161161

162162
conv_fwd_jvp_p.def_impl(conv_fwd_jvp_impl)
163163
conv_fwd_jvp_p.def_abstract_eval(conv_fwd_jvp_abstract_eval)
164+
mlir.register_lowering(
165+
conv_fwd_jvp_p,
166+
mlir.lower_fun(conv_fwd_jvp_impl, multiple_results=False),
167+
platform="cuda",
168+
)
169+
mlir.register_lowering(
170+
conv_fwd_jvp_p,
171+
mlir.lower_fun(conv_fwd_jvp_impl, multiple_results=False),
172+
platform="rocm",
173+
)
164174

165175

166176
# ==============================================================================
@@ -285,6 +295,16 @@ def conv_bwd_jvp_abstract_eval(
285295

286296
conv_bwd_jvp_p.def_impl(conv_bwd_jvp_impl)
287297
conv_bwd_jvp_p.def_abstract_eval(conv_bwd_jvp_abstract_eval)
298+
mlir.register_lowering(
299+
conv_bwd_jvp_p,
300+
mlir.lower_fun(conv_bwd_jvp_impl, multiple_results=True),
301+
platform="cuda",
302+
)
303+
mlir.register_lowering(
304+
conv_bwd_jvp_p,
305+
mlir.lower_fun(conv_bwd_jvp_impl, multiple_results=True),
306+
platform="rocm",
307+
)
288308

289309

290310
# ==============================================================================

openequivariance/openequivariance/jax/jvp/tp_prim.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,16 @@ def tp_fwd_jvp_abstract_eval(X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash):
132132

133133
tp_fwd_jvp_p.def_impl(tp_fwd_jvp_impl)
134134
tp_fwd_jvp_p.def_abstract_eval(tp_fwd_jvp_abstract_eval)
135+
mlir.register_lowering(
136+
tp_fwd_jvp_p,
137+
mlir.lower_fun(tp_fwd_jvp_impl, multiple_results=False),
138+
platform="cuda",
139+
)
140+
mlir.register_lowering(
141+
tp_fwd_jvp_p,
142+
mlir.lower_fun(tp_fwd_jvp_impl, multiple_results=False),
143+
platform="rocm",
144+
)
135145

136146

137147
# ==============================================================================
@@ -225,7 +235,16 @@ def tp_bwd_jvp_abstract_eval(X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash):
225235

226236
tp_bwd_jvp_p.def_impl(tp_bwd_jvp_impl)
227237
tp_bwd_jvp_p.def_abstract_eval(tp_bwd_jvp_abstract_eval)
228-
238+
mlir.register_lowering(
239+
tp_bwd_jvp_p,
240+
mlir.lower_fun(tp_bwd_jvp_impl, multiple_results=True),
241+
platform="cuda",
242+
)
243+
mlir.register_lowering(
244+
tp_bwd_jvp_p,
245+
mlir.lower_fun(tp_bwd_jvp_impl, multiple_results=True),
246+
platform="rocm",
247+
)
229248

230249
# ==============================================================================
231250
# 9. Transpose Rule for Backward JVP

openequivariance/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "scikit_build_core.build"
44

55
[project]
66
name = "openequivariance"
7-
version = "0.6.3"
7+
version = "0.6.4"
88
authors = [
99
{ name="Austin Glover" },
1010
{ name="Vivek Bharadwaj" },

openequivariance_extjax/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build"
88

99
[project]
1010
name = "openequivariance_extjax"
11-
version = "0.6.3"
11+
version = "0.6.4"
1212

1313
authors = [
1414
{ name="Austin Glover" },

0 commit comments

Comments
 (0)