Skip to content

Commit c6ba580

Browse files
authored
Register lowering for conv_fwd_jvp_p and conv_bwd_jvp_p (#190)
* register lowering for conv_fwd_jvp_p and conv_bwd_jvp_p * prek -a * add tp too
1 parent 64ef9f8 commit c6ba580

2 files changed

Lines changed: 40 additions & 1 deletion

File tree

openequivariance/openequivariance/jax/jvp/conv_prim.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,16 @@ def conv_fwd_jvp_abstract_eval(
161161

162162
conv_fwd_jvp_p.def_impl(conv_fwd_jvp_impl)
163163
conv_fwd_jvp_p.def_abstract_eval(conv_fwd_jvp_abstract_eval)
164+
mlir.register_lowering(
165+
conv_fwd_jvp_p,
166+
mlir.lower_fun(conv_fwd_jvp_impl, multiple_results=False),
167+
platform="cuda",
168+
)
169+
mlir.register_lowering(
170+
conv_fwd_jvp_p,
171+
mlir.lower_fun(conv_fwd_jvp_impl, multiple_results=False),
172+
platform="rocm",
173+
)
164174

165175

166176
# ==============================================================================
@@ -285,6 +295,16 @@ def conv_bwd_jvp_abstract_eval(
285295

286296
conv_bwd_jvp_p.def_impl(conv_bwd_jvp_impl)
287297
conv_bwd_jvp_p.def_abstract_eval(conv_bwd_jvp_abstract_eval)
298+
mlir.register_lowering(
299+
conv_bwd_jvp_p,
300+
mlir.lower_fun(conv_bwd_jvp_impl, multiple_results=True),
301+
platform="cuda",
302+
)
303+
mlir.register_lowering(
304+
conv_bwd_jvp_p,
305+
mlir.lower_fun(conv_bwd_jvp_impl, multiple_results=True),
306+
platform="rocm",
307+
)
288308

289309

290310
# ==============================================================================

openequivariance/openequivariance/jax/jvp/tp_prim.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,16 @@ def tp_fwd_jvp_abstract_eval(X, Y, W, dX, dY, dW, *, L3_dim, kernel, hash):
132132

133133
tp_fwd_jvp_p.def_impl(tp_fwd_jvp_impl)
134134
tp_fwd_jvp_p.def_abstract_eval(tp_fwd_jvp_abstract_eval)
135+
mlir.register_lowering(
136+
tp_fwd_jvp_p,
137+
mlir.lower_fun(tp_fwd_jvp_impl, multiple_results=False),
138+
platform="cuda",
139+
)
140+
mlir.register_lowering(
141+
tp_fwd_jvp_p,
142+
mlir.lower_fun(tp_fwd_jvp_impl, multiple_results=False),
143+
platform="rocm",
144+
)
135145

136146

137147
# ==============================================================================
@@ -225,7 +235,16 @@ def tp_bwd_jvp_abstract_eval(X, Y, W, dZ, tX, tY, tW, tdZ, *, kernel, hash):
225235

226236
tp_bwd_jvp_p.def_impl(tp_bwd_jvp_impl)
227237
tp_bwd_jvp_p.def_abstract_eval(tp_bwd_jvp_abstract_eval)
228-
238+
mlir.register_lowering(
239+
tp_bwd_jvp_p,
240+
mlir.lower_fun(tp_bwd_jvp_impl, multiple_results=True),
241+
platform="cuda",
242+
)
243+
mlir.register_lowering(
244+
tp_bwd_jvp_p,
245+
mlir.lower_fun(tp_bwd_jvp_impl, multiple_results=True),
246+
platform="rocm",
247+
)
229248

230249
# ==============================================================================
231250
# 9. Transpose Rule for Backward JVP

0 commit comments

Comments
 (0)