Skip to content

Commit 2524f2a

Browse files
committed
Added __call__ functions.
1 parent d94db28 commit 2524f2a

2 files changed

Lines changed: 16 additions & 1 deletion

File tree

openequivariance/openequivariance/impl_jax/TensorProduct.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,12 @@ def __init__(self, config: TPProblem):
6969
def forward(self, X: jax.ndarray, Y: jax.ndarray, W: jax.ndarray) -> jax.ndarray:
7070
return forward(X, Y, W, self.L3_dim, self.config.irrep_dtype, self.attrs)
7171

72+
def __call__(self,
73+
X: jax.numpy.ndarray,
74+
Y: jax.numpy.ndarray,
75+
W: jax.numpy.ndarray) -> jax.numpy.ndarray:
76+
return self.forward(X, Y, W)
77+
7278

7379
def jax_to_torch(x):
7480
import numpy as np

openequivariance/openequivariance/impl_jax/TensorProductConv.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __init__(self, config: TPProblem, deterministic: bool = False, kahan: bool =
5555
super().__init__(
5656
config,
5757
dp, extlib.postprocess_kernel,
58-
idx_dtype=np.int32, # Note: this is distinct from PyTorch
58+
idx_dtype=np.int32, # N.B. this is distinct from the PyTorch version
5959
torch_op=False,
6060
deterministic=deterministic,
6161
kahan=kahan
@@ -99,6 +99,15 @@ def forward(
9999
self.L3_dim,
100100
self.config.irrep_dtype,
101101
self.attrs)
102+
103+
def __call__(self,
104+
X: jax.numpy.ndarray,
105+
Y: jax.numpy.ndarray,
106+
W: jax.numpy.ndarray,
107+
rows: jax.numpy.ndarray,
108+
cols: jax.numpy.ndarray,
109+
sender_perm: Optional[jax.numpy.ndarray] = None) -> jax.numpy.ndarray:
110+
return self.forward(X, Y, W, rows, cols, sender_perm)
102111

103112
if __name__=="__main__":
104113
X_ir, Y_ir, Z_ir = Irreps("1x2e"), Irreps("1x3e"), Irreps("1x2e")

0 commit comments

Comments
 (0)