
Commit 5f33074

Removes TorchBind Custom Classes (#183)
* Refactoring PyTorch tensor product.
* Tensor product refactored successfully.
* Eliminated opaque bindings, batch tests working.
* Ready to run convolution tests.
* Ready to fix fakes.
* DeviceBuffer is gone.
* Fixed the fakes.
* Fixing compile tests.
* Fixed JAX error.
* Fixed compile issues.
* Ruff.
* Eliminated raw pointer functions and generic classes.
* Ruff.
* Fixed benchmark.py.
* Updated JAX extension version due to backend change.
1 parent 72122ce commit 5f33074

21 files changed

Lines changed: 582 additions & 1512 deletions

openequivariance/openequivariance/_torch/TensorProduct.py

Lines changed: 99 additions & 314 deletions
Large diffs are not rendered by default.

openequivariance/openequivariance/_torch/TensorProductConv.py

Lines changed: 217 additions & 541 deletions
Large diffs are not rendered by default.

openequivariance/openequivariance/_torch/extlib/__init__.py

Lines changed: 1 addition & 13 deletions
@@ -51,7 +51,7 @@
 
 extra_cflags = ["-O3"]
 generic_sources = ["generic_module.cpp"]
-torch_sources = ["libtorch_tp_jit.cpp"]
+torch_sources = ["libtorch_tp_jit.cpp", "json11/json11.cpp"]
 
 include_dirs, extra_link_args = (["util"], ["-Wl,--no-as-needed"])
 
@@ -149,22 +149,13 @@ def torch_ext_so_path():
 
 if BUILT_EXTENSION:
     from generic_module import (
-        JITTPImpl,
-        JITConvImpl,
         GroupMM_F32,
         GroupMM_F64,
         DeviceProp,
-        DeviceBuffer,
         GPUTimer,
     )
 else:
 
-    def JITTPImpl(*args, **kwargs):
-        _raise_import_error_helper("JITTPImpl")
-
-    def JITConvImpl(*args, **kwargs):
-        _raise_import_error_helper("JITConvImpl")
-
     def GroupMM_F32(*args, **kwargs):
         _raise_import_error_helper("GroupMM_F32")
 
@@ -174,8 +165,5 @@ def GroupMM_F64(*args, **kwargs):
     def DeviceProp(*args, **kwargs):
         _raise_import_error_helper("DeviceProp")
 
-    def DeviceBuffer(*args, **kwargs):
-        _raise_import_error_helper("DeviceBuffer")
-
     def GPUTimer(*args, **kwargs):
         _raise_import_error_helper("GPUTimer")

openequivariance/openequivariance/_torch/utils.py

Lines changed: 9 additions & 0 deletions
@@ -1,4 +1,5 @@
 import torch
+import numpy as np
 from types import MappingProxyType
 from openequivariance.core.utils import DTypeEnum
 
@@ -66,3 +67,11 @@ def reorder_torch(schedule, weights_in, direction, has_batch_dim):
         DTypeEnum.UINT8: torch.uint8,
     }
 )
+
+
+def string_to_tensor(text: str) -> torch.Tensor:
+    bytes_data = text.encode("utf-8")
+    np_bytes = np.frombuffer(bytes_data, dtype=np.uint8)
+    result = torch.tensor(np_bytes, device="cpu")
+    result.requires_grad = False
+    return result
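The new `string_to_tensor` helper appears to carry string payloads (such as the JSON-serialized kernel bundle) as plain uint8 CPU tensors, which fake-tensor tracing can handle without TorchBind custom classes. A minimal round-trip sketch, assuming a hypothetical `tensor_to_string` inverse that is not part of this commit:

```python
import numpy as np
import torch

def string_to_tensor(text: str) -> torch.Tensor:
    # Same helper as in the diff: UTF-8 bytes stored in a uint8 CPU tensor.
    np_bytes = np.frombuffer(text.encode("utf-8"), dtype=np.uint8)
    result = torch.tensor(np_bytes, device="cpu")
    result.requires_grad = False
    return result

def tensor_to_string(t: torch.Tensor) -> str:
    # Hypothetical inverse, for illustration only: decode the bytes back to text.
    return t.cpu().numpy().tobytes().decode("utf-8")

payload = string_to_tensor('{"kernel": "..."}')
assert tensor_to_string(payload) == '{"kernel": "..."}'
```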

openequivariance/openequivariance/benchmark/ConvBenchmarkSuite.py

Lines changed: 2 additions & 1 deletion
@@ -8,6 +8,7 @@
 import openequivariance as oeq
 from openequivariance.benchmark.logging_utils import getLogger
 from openequivariance.core.ConvolutionBase import CoordGraph
+from openequivariance.benchmark.benchmark_utils import NpEncoder
 
 logger = getLogger()
 
@@ -145,7 +146,7 @@ def run(
                 f"{output_folder}/{self.exp_count}_{impl.name()}_{graph.name}.json"
             )
             with open(fname, "w") as f:
-                json.dump(result, f, indent=2)
+                json.dump(result, f, indent=2, cls=NpEncoder)
             self.exp_count += 1
 
         logger.info(f"Finished {tc_name}, graph {graph.name}")

openequivariance/openequivariance/benchmark/TestBenchmarkSuite.py

Lines changed: 5 additions & 2 deletions
@@ -21,6 +21,7 @@
     benchmark_forward,
     benchmark_backward,
     benchmark_double_backward,
+    NpEncoder,
 )
 
 logger = getLogger()
@@ -235,10 +236,12 @@ def run(
 
             fname = pathlib.Path(f"{output_folder}/{test_ID}_{impl.name()}.json")
 
-            pretty_result = json.dumps(obj=result, indent=2).replace("\\n", "\n")
+            pretty_result = json.dumps(obj=result, indent=2, cls=NpEncoder).replace(
+                "\\n", "\n"
+            )
             logger.debug(pretty_result)
             with open(fname, "w") as f:
-                json.dump(result, f, indent=2)
+                json.dump(result, f, indent=2, cls=NpEncoder)
 
             self.results.append(result)
             logger.info(f"Finished Test ID: {test_ID}")

openequivariance/openequivariance/benchmark/benchmark_utils.py

Lines changed: 12 additions & 0 deletions
@@ -1,3 +1,4 @@
+import json
 import numpy as np
 
 from openequivariance.benchmark.random_buffer_utils import (
@@ -290,3 +291,14 @@ def benchmark_double_backward(
     )
 
     return result
+
+
+class NpEncoder(json.JSONEncoder):
+    def default(self, obj):
+        if isinstance(obj, np.integer):
+            return int(obj)
+        if isinstance(obj, np.floating):
+            return float(obj)
+        if isinstance(obj, np.ndarray):
+            return obj.tolist()
+        return super(NpEncoder, self).default(obj)
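`NpEncoder` converts NumPy scalars and arrays to JSON-native types, so the benchmark suites can dump results that contain NumPy values via `json.dump(..., cls=NpEncoder)` instead of raising `TypeError`. A short usage sketch with made-up result values:

```python
import json
import numpy as np
from openequivariance.benchmark.benchmark_utils import NpEncoder

# Without cls=NpEncoder this raises:
# TypeError: Object of type int64 is not JSON serializable
result = {
    "batch_size": np.int64(65536),
    "throughput_tflops": np.float32(1.7),
    "times_ms": np.array([0.9, 1.1, 1.0]),
}
print(json.dumps(result, indent=2, cls=NpEncoder))
```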

openequivariance/openequivariance/core/LoopUnrollConv.py

Lines changed: 18 additions & 4 deletions
@@ -1,14 +1,18 @@
 import numpy as np
+import json
 
 from openequivariance.core.ConvolutionBase import ConvolutionBase
 from openequivariance.core.ComputationSchedule import (
     ComputationSchedule,
     SMEMCapacityException,
 )
 
-from openequivariance.core.utils import dtype_to_enum
 from openequivariance.templates.jinja_utils import get_jinja_environment
-from openequivariance.core.utils import filter_and_analyze_problem
+from openequivariance.core.utils import (
+    filter_and_analyze_problem,
+    dtype_to_enum,
+    hash_str_64,
+)
 
 
 class LoopUnrollConv(ConvolutionBase):
@@ -203,5 +207,15 @@ def generate_double_backward_schedule(warps_per_block):
         )
         self.jit_kernel = postprocess_kernel(self.jit_kernel)
 
-        # with open("scratch.txt", "w") as f:
-        #     f.write(self.jit_kernel)
+        self.kernel_string = json.dumps(
+            {
+                "kernel": self.jit_kernel,
+                "forward_config": vars(self.forward_schedule.launch_config),
+                "backward_config": vars(self.backward_schedule.launch_config),
+                "double_backward_config": vars(
+                    self.double_backward_schedule.launch_config
+                ),
+                "kernel_prop": self.kernel_prop,
+            }
+        )
+        self.hash = hash_str_64(self.kernel_string)

openequivariance/openequivariance/core/LoopUnrollTP.py

Lines changed: 20 additions & 2 deletions
@@ -1,15 +1,19 @@
 import numpy as np
+import json
 
 from openequivariance.templates.jinja_utils import get_jinja_environment
 from openequivariance.core.ComputationSchedule import ComputationSchedule
 from openequivariance.core.TensorProductBase import TensorProductBase
-from openequivariance.core.utils import dtype_to_enum
+from openequivariance.benchmark.logging_utils import getLogger
+from openequivariance.core.utils import dtype_to_enum, hash_str_64
 
 from openequivariance.core.utils import (
     filter_and_analyze_problem,
     count_cg_non_zero,
 )
 
+logger = getLogger()
+
 
 class LoopUnrollTP(TensorProductBase):
     def __init__(self, config, dp, postprocess_kernel, torch_op):
@@ -91,7 +95,7 @@ def generate_double_backward_schedule(warps_per_block):
             )
         )
 
-        self.kernelProp = {
+        self.kernel_prop = {
            "L1_dim": self.L1.dim,
            "L2_dim": self.L2.dim,
            "L3_dim": self.L3.dim,
@@ -106,6 +110,20 @@ def generate_double_backward_schedule(warps_per_block):
            "idx_dtype": 0,
         }
 
+        self.kernel_string = json.dumps(
+            {
+                "kernel": self.jit_kernel,
+                "forward_config": vars(self.forward_schedule.launch_config),
+                "backward_config": vars(self.backward_schedule.launch_config),
+                "double_backward_config": vars(
+                    self.double_backward_schedule.launch_config
+                ),
+                "kernel_prop": self.kernel_prop,
+            }
+        )
+        self.hash = hash_str_64(self.kernel_string)
+        logger.info(f"Kernel File Size: {len(self.jit_kernel) // 1024} KB")
+
     def calculate_flops_forward(self, batch_size: int) -> dict:
         if self.is_uvw:
             return super().calculate_flops_forward(batch_size)
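Both `LoopUnrollTP` and `LoopUnrollConv` now bundle the generated kernel source, the three launch configurations, and the kernel properties into a single JSON string and hash it. How the Torch extension consumes this bundle is not shown in this excerpt; the sketch below only illustrates, under that assumption, how a consumer could unpack it with `json.loads` (the variable `tp` stands for a constructed `LoopUnrollTP`):

```python
import json

bundle = json.loads(tp.kernel_string)           # keys match the json.dumps call above
kernel_source = bundle["kernel"]                # generated kernel source text
fwd_cfg = bundle["forward_config"]              # dict from vars(launch_config)
bwd_cfg = bundle["backward_config"]
dbl_bwd_cfg = bundle["double_backward_config"]
props = bundle["kernel_prop"]                   # e.g. L1_dim, L2_dim, L3_dim, idx_dtype
```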

openequivariance/openequivariance/core/utils.py

Lines changed: 1 addition & 1 deletion
@@ -7,9 +7,9 @@
 
 import json
 import tempfile
+import hashlib
 
 from enum import IntEnum
-import hashlib
 
 
 class DTypeEnum(IntEnum):
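The `hashlib` import is regrouped with the other standard-library imports; `core/utils.py` also provides the `hash_str_64` helper that the refactored classes now call on `kernel_string`. Its implementation is not part of this diff; a hypothetical sketch of such a helper, assuming it derives a stable 64-bit identifier by truncating a cryptographic digest:

```python
import hashlib

def hash_str_64(s: str) -> str:
    # Hypothetical, for illustration only: 16 hex characters = 64 bits,
    # stable across processes (unlike Python's built-in hash()).
    return hashlib.sha256(s.encode("utf-8")).hexdigest()[:16]

# Used as a cache/identity key for compiled kernels, e.g.:
# self.hash = hash_str_64(self.kernel_string)
```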
