Skip to content

Commit 4ea49dc

Browse files
committed
Added some type annotations.
1 parent b9c9135 commit 4ea49dc

2 files changed

Lines changed: 18 additions & 8 deletions

File tree

openequivariance/openequivariance/impl_jax/TensorProduct.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def backward_autograd(L3_dim, irrep_dtype, attrs, inputs, dZ):
5050
backward.defvjp(backward_with_inputs, double_backward)
5151

5252
class TensorProduct(LoopUnrollTP):
53-
def __init__(self, config):
53+
def __init__(self, config: TPProblem):
5454
dp = extlib.DeviceProp(0)
5555
super().__init__(config, dp, extlib.postprocess_kernel, torch_op=False)
5656

@@ -66,7 +66,7 @@ def __init__(self, config):
6666
self.weight_numel = config.weight_numel
6767
self.L3_dim = self.config.irreps_out.dim
6868

69-
def forward(self, X, Y, W):
69+
def forward(self, X: jax.ndarray, Y: jax.ndarray, W: jax.ndarray) -> jax.ndarray:
7070
return forward(X, Y, W, self.L3_dim, self.config.irrep_dtype, self.attrs)
7171

7272

openequivariance/openequivariance/impl_jax/TensorProductConv.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
from functools import partial
3+
from typing import Optional
34
from openequivariance.impl_jax import extlib
45

56
from openequivariance.core.e3nn_lite import TPProblem, Irreps
@@ -13,13 +14,12 @@
1314
logger = getLogger()
1415

1516
class TensorProductConv(LoopUnrollConv):
16-
def __init__(self, config, deterministic=False, kahan=False):
17+
def __init__(self, config: TPProblem, deterministic: bool = False, kahan: bool = False):
1718
dp = extlib.DeviceProp(0)
1819
super().__init__(
19-
self,
20-
config,
20+
config,
2121
dp, extlib.postprocess_kernel,
22-
idx_dtype=np.int64,
22+
idx_dtype=np.int32, # Note: this is distinct from PyTorch
2323
torch_op=False,
2424
deterministic=deterministic,
2525
kahan=kahan
@@ -30,7 +30,7 @@ def __init__(self, config, deterministic=False, kahan=False):
3030
"forward_config": vars(self.forward_schedule.launch_config),
3131
"backward_config": vars(self.backward_schedule.launch_config),
3232
"double_backward_config": vars(self.double_backward_schedule.launch_config),
33-
"kernel_prop": self.kernelProp
33+
"kernel_prop": self.kernel_prop
3434
}
3535
hash_attributes(self.attrs)
3636

@@ -39,8 +39,18 @@ def __init__(self, config, deterministic=False, kahan=False):
3939

4040
self.workspace = jnp.zeros((self.workspace_size,), dtype=jnp.uint8)
4141
logger.info(f"Convolution requires {self.workspace_size // (2 ** 20)}MB of workspace.")
42-
self.dummy_transpose_perm = jnp.zeros((1,), dtype=jnp.int64)
42+
self.dummy_transpose_perm = jnp.zeros((1,), dtype=jnp.int32)
43+
4344

45+
def forward(
46+
self,
47+
X: jax.ndarray,
48+
Y: jax.ndarray,
49+
W: jax.ndarray,
50+
rows: jax.ndarray,
51+
cols: jax.ndarray,
52+
sender_perm: Optional[jax.ndarray] = None) -> jax.ndarray:
53+
pass
4454

4555
if __name__=="__main__":
4656
X_ir, Y_ir, Z_ir = Irreps("1x2e"), Irreps("1x3e"), Irreps("1x2e")

0 commit comments

Comments
 (0)