Skip to content

Commit b9c9135

Browse files
committed
Finished prototype of TensorProductConv.
1 parent 865ca13 commit b9c9135

4 files changed

Lines changed: 70 additions & 11 deletions

File tree

openequivariance/openequivariance/core/utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77

88
import json
99
import tempfile
10+
import hashlib
1011
from openequivariance.impl_torch.extlib import GPUTimer
1112

12-
1313
def sparse_outer_product_work(cg: np.ndarray) -> int:
    """Count the (i, j) outer-product slots that touch at least one nonzero CG entry."""
    active_slots = (cg != 0).any(axis=2)
    return np.sum(active_slots)
1515

@@ -170,3 +170,13 @@ def benchmark(func, num_warmup, num_iter, mode="gpu_time", kernel_names=[]):
170170
time_millis[i] = kernel_time
171171

172172
return time_millis
173+
174+
175+
def hash_attributes(attrs):
    """Attach a deterministic 63-bit hash to *attrs* under the key "hash".

    The digest folds in ``repr()`` of every value, visiting keys in sorted
    order so the result is independent of dict insertion order. The dict is
    mutated in place.

    :param attrs: dict of kernel attributes; gains a "hash" entry.
    """
    digest = hashlib.sha256()
    # Sorted traversal keeps the hash stable across insertion orders.
    for key in sorted(attrs.keys()):
        # repr() instead of the dunder __repr__() call; also avoids
        # shadowing the builtin `hash` with a local name.
        digest.update(repr(attrs[key]).encode("utf-8"))
    # Keep 16 hex digits (64 bits) and shift right one bit so the value
    # fits in a non-negative signed 64-bit integer.
    attrs["hash"] = int(digest.hexdigest()[:16], 16) >> 1

openequivariance/openequivariance/impl_jax/TensorProduct.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,9 @@
77
import hashlib
88
from openequivariance.core.e3nn_lite import TPProblem, Irreps
99
from openequivariance.core.LoopUnrollTP import LoopUnrollTP
10+
from openequivariance.core.utils import hash_attributes
1011
import jax.numpy as jnp
1112

12-
def hash_attributes(attrs):
13-
m = hashlib.sha256()
14-
15-
for key in sorted(attrs.keys()):
16-
m.update(attrs[key].__repr__().encode("utf-8"))
17-
18-
hash = int(m.hexdigest()[:16], 16) >> 1
19-
attrs["hash"] = hash
20-
2113
@partial(jax.custom_vjp, nondiff_argnums=(3,4,5))
2214
def forward(X, Y, W, L3_dim, irrep_dtype, attrs):
2315
forward_call = jax.ffi.ffi_call("tp_forward",
@@ -84,7 +76,6 @@ def jax_to_torch(x):
8476
return torch.tensor(np.asarray(x), requires_grad=True)
8577

8678
if __name__ == "__main__":
87-
tp_problem = None
8879
X_ir, Y_ir, Z_ir = Irreps("1x2e"), Irreps("1x3e"), Irreps("1x2e")
8980
instructions=[(0, 0, 0, "uvu", True)]
9081
problem = TPProblem(X_ir, Y_ir, Z_ir,
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import numpy as np
2+
from functools import partial
3+
from openequivariance.impl_jax import extlib
4+
5+
from openequivariance.core.e3nn_lite import TPProblem, Irreps
6+
from openequivariance.core.LoopUnrollConv import LoopUnrollConv
7+
from openequivariance.core.utils import hash_attributes
8+
9+
import jax
10+
import jax.numpy as jnp
11+
12+
from openequivariance.benchmark.logging_utils import getLogger
13+
logger = getLogger()
14+
15+
class TensorProductConv(LoopUnrollConv):
    """JAX implementation of a fused tensor-product convolution.

    Wraps :class:`LoopUnrollConv` with the JAX extension library, building
    the kernel-attribute dict (plus its cache hash) and the workspace
    buffers the FFI calls need.
    """

    def __init__(self, config, deterministic=False, kahan=False):
        """
        :param config: TPProblem describing the tensor product; must expose
            ``weight_numel`` (read below).
        :param deterministic: forwarded to LoopUnrollConv.
        :param kahan: forwarded to LoopUnrollConv (Kahan summation flag).
        """
        dp = extlib.DeviceProp(0)
        # BUG FIX: the original passed `self` explicitly to super().__init__();
        # a bound super() call already supplies self, so every positional
        # argument was shifted by one (config landed in the dp slot, etc.).
        super().__init__(
            config,
            dp,
            extlib.postprocess_kernel,
            idx_dtype=np.int64,
            torch_op=False,
            deterministic=deterministic,
            kahan=kahan,
        )

        # Attributes handed to the FFI kernels; hash_attributes adds a
        # deterministic "hash" entry used to key compiled-kernel lookups.
        self.attrs = {
            "kernel": self.jit_kernel,
            "forward_config": vars(self.forward_schedule.launch_config),
            "backward_config": vars(self.backward_schedule.launch_config),
            "double_backward_config": vars(self.double_backward_schedule.launch_config),
            "kernel_prop": self.kernelProp,
        }
        hash_attributes(self.attrs)

        self.weight_numel = config.weight_numel
        self.L3_dim = self.config.irreps_out.dim

        # Scratch buffer required by the convolution kernels.
        self.workspace = jnp.zeros((self.workspace_size,), dtype=jnp.uint8)
        logger.info(f"Convolution requires {self.workspace_size // (2 ** 20)}MB of workspace.")
        # Placeholder permutation passed when no transpose is requested.
        self.dummy_transpose_perm = jnp.zeros((1,), dtype=jnp.int64)
44+
45+
if __name__ == "__main__":
    # Smoke test: build a tiny single-instruction "uvu" problem and make
    # sure TensorProductConv initializes end to end.
    irreps_in1 = Irreps("1x2e")
    irreps_in2 = Irreps("1x3e")
    irreps_out = Irreps("1x2e")
    instructions = [(0, 0, 0, "uvu", True)]

    problem = TPProblem(
        irreps_in1,
        irreps_in2,
        irreps_out,
        instructions,
        shared_weights=False,
        internal_weights=False,
    )

    conv = TensorProductConv(problem, deterministic=False, kahan=False)
    print("COMPLETE!")

openequivariance/openequivariance/impl_jax/extlib/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import jax
2+
import hashlib
23

34
def postprocess_kernel(kernel):
    """Return *kernel* unchanged — only CUDA is supported for now, so no
    postprocessing step is needed."""
    return kernel
59

610
import openequivariance_extjax as oeq_extjax

0 commit comments

Comments
 (0)