
Commit 6be2be5

Completed AMD HIP Support for UVW Tensor Products (#94)
* Many configurations are working, but we are failing some simple tests.
* Updated benchmark.py script to remove keyword parameters.
1 parent: 6f288d6

17 files changed, with 83 additions and 64 deletions.

README.md

Lines changed: 4 additions & 4 deletions
```diff
@@ -23,8 +23,8 @@ which has a closed-source kernel package. We also offer fused
 equivariant graph convolutions that can reduce
 computation and memory consumption significantly.
 
-We currently support NVIDIA GPUs and have just added beta support on AMD GPUs for
-UVU tensor products! See [the coverage table](#tensor-products-we-accelerate) for more
+We currently support NVIDIA GPUs and just added beta support on AMD GPUs for
+all tensor products! See [the coverage table](#tensor-products-we-accelerate) for more
 details.
 
 **Warning**: This is an early release, bug reports are welcome.
@@ -242,9 +242,9 @@ python tests/mace_driver.py carbon.xyz -o outputs/mace_tests -i e3nn cue oeq
 | Operation                | CUDA      | HIP  |
 |--------------------------|-----------|------|
 | UVU Batch                | ✅        | ✅   |
-| UVW Batch                | ✅        | 🚧🔨 |
+| UVW Batch                | ✅        | ✅   |
 | UVU Convolution          | ✅        | ✅   |
-| UVW Convolution          | ✅        | 🚧🔨 |
+| UVW Convolution          | ✅        | ✅   |
 | Symmetric Tensor Product | ✅ (beta) | 🚧🔨 |
 
 e3nn supports a variety of connection modes for CG tensor products. We support
```

openequivariance/extension/convolution.hpp

Lines changed: 16 additions & 5 deletions
```diff
@@ -96,18 +96,28 @@ class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionI
 public:
     JIT_IMPL jit;
     KernelLaunchConfig forward_config;
-    KernelLaunchConfig backward_config;
+    KernelLaunchConfig backward_config;
+    bool is_uvw;
 
     JITConvImpl(
         std::string jit_kernel,
         KernelLaunchConfig forward_config_i,
-        KernelLaunchConfig backward_config_i) :
+        KernelLaunchConfig backward_config_i,
+        bool is_uvw_i) :
             jit(jit_kernel),
             forward_config(forward_config_i),
-            backward_config(backward_config_i) {
+            backward_config(backward_config_i),
+            is_uvw(is_uvw_i) {
 
         vector<string> kernels = {"forward", "backward", "fixup_forward", "fixup_backward"};
-        jit.compile(kernels, {{}, {}, {}, {}});
+
+        int opt_level = 3;
+        #ifdef HIP_BACKEND
+        if(is_uvw) {
+            opt_level = 1;
+        }
+        #endif
+        jit.compile(kernels, {{}, {}, {}, {}}, opt_level);
 
         if(forward_config.smem > 0) {
             jit.set_max_smem(0, forward_config.smem);
@@ -134,7 +144,8 @@ class __attribute__ ((visibility ("default"))) JITConvImpl : public ConvolutionI
                 bwd_dict["num_blocks"],
                 bwd_dict["num_threads"],
                 bwd_dict["smem"]
-            )) { }
+            ),
+            kernel_dims["is_uvw"] == 1) { }
 
     void exec_conv(
         void* L1_in,
```
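Both JIT wrapper classes (`JITConvImpl` here and `JITTPImpl` below) now inline the same optimization-level policy in their constructors. Factored out as a minimal sketch, assuming `HIP_BACKEND` is defined by the build when targeting AMD:

```cpp
// Sketch of the policy this commit threads through the JIT wrappers:
// UVW kernels compiled through hipRTC get -O1, while every other kernel,
// and everything on the CUDA path, keeps the default -O3.
inline int jit_opt_level(bool is_uvw) {
#ifdef HIP_BACKEND
    if(is_uvw) {
        return 1;
    }
#endif
    return 3;
}
```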

openequivariance/extension/tensorproducts.hpp

Lines changed: 15 additions & 4 deletions
```diff
@@ -54,18 +54,28 @@ class __attribute__ ((visibility ("default"))) JITTPImpl : public GenericTensorP
 public:
     JIT_IMPL jit;
     KernelLaunchConfig forward_config, backward_config, double_backward_config;
+    bool is_uvw;
 
     JITTPImpl(
         std::string jit_kernel,
         KernelLaunchConfig forward_config_i,
         KernelLaunchConfig backward_config_i,
-        KernelLaunchConfig double_backward_config_i) :
+        KernelLaunchConfig double_backward_config_i,
+        bool is_uvw_i) :
             jit(jit_kernel),
             forward_config(forward_config_i),
             backward_config(backward_config_i),
-            double_backward_config(double_backward_config_i) {
+            double_backward_config(double_backward_config_i),
+            is_uvw(is_uvw_i) {
         vector<string> kernels = {"forward", "backward", "double_backward_A", "double_backward_B"};
-        jit.compile(kernels, {{}, {}, {}, {}});
+
+        int opt_level = 3;
+        #ifdef HIP_BACKEND
+        if(is_uvw) {
+            opt_level = 1;
+        }
+        #endif
+        jit.compile(kernels, {{}, {}, {}, {}}, opt_level);
 
         if(forward_config.smem > 0) {
             jit.set_max_smem(0, forward_config.smem);
@@ -103,7 +113,8 @@ class __attribute__ ((visibility ("default"))) JITTPImpl : public GenericTensorP
                 dbl_bwd_dict["num_blocks"],
                 dbl_bwd_dict["num_threads"],
                 dbl_bwd_dict["smem"]
-            )
+            ),
+            kernel_dims["is_uvw"] == 1
         ) { }
 
     void exec_tensor_product(
```
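The second hunk above is the tail of a map-based delegating constructor: the Torch layer now forwards plain dictionaries and the wrapper unpacks the launch parameters itself. A self-contained sketch of that unpacking, with key names taken from the diff (the three-field `KernelLaunchConfig` below is an assumed stand-in for the project's real type):

```cpp
#include <cstdint>
#include <string>
#include <unordered_map>

// Assumed minimal stand-in for the project's KernelLaunchConfig.
struct KernelLaunchConfig {
    int64_t num_blocks;
    int64_t num_threads;
    int64_t smem;
};

// Unpack one launch-parameter dictionary; at() throws if a key is missing.
KernelLaunchConfig config_from(const std::unordered_map<std::string, int64_t> &d) {
    return { d.at("num_blocks"), d.at("num_threads"), d.at("smem") };
}
```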

openequivariance/extension/torch_tp_jit.cpp

Lines changed: 16 additions & 24 deletions
```diff
@@ -37,6 +37,14 @@ namespace py=pybind11;
 
 using Map_t=torch::Dict<string, int64_t>;
 
+std::unordered_map<string, int64_t> to_map(const Map_t &map) {
+    std::unordered_map<string, int64_t> result;
+    for(auto it = map.begin(); it != map.end(); ++it) {
+        result[it->key()] = it->value();
+    }
+    return result;
+}
+
 inline void* data_ptr(const torch::Tensor &tensor) {
     if(tensor.dtype() == torch::kFloat)
         return reinterpret_cast<void*>(tensor.data_ptr<float>());
@@ -62,21 +70,11 @@ class __attribute__ ((visibility ("default"))) TorchJITProduct : public torch::C
             dbl_bwd_dict(dbl_bwd_dict_i.copy()),
             kernel_dims(kernel_dims_i.copy()),
             internal(kernel_plaintext,
-                KernelLaunchConfig(
-                    fwd_dict.at("num_blocks"),
-                    fwd_dict.at("num_threads"),
-                    fwd_dict.at("smem")
+                to_map(fwd_dict_i),
+                to_map(bwd_dict_i),
+                to_map(dbl_bwd_dict_i),
+                to_map(kernel_dims_i)
             ),
-                KernelLaunchConfig(
-                    bwd_dict.at("num_blocks"),
-                    bwd_dict.at("num_threads"),
-                    bwd_dict.at("smem")
-                ),
-                KernelLaunchConfig(
-                    dbl_bwd_dict.at("num_blocks"),
-                    dbl_bwd_dict.at("num_threads"),
-                    dbl_bwd_dict.at("smem")
-                )),
             L3_dim(kernel_dims.at("L3_dim")),
             shared_weights(kernel_dims.at("shared_weights")) { }
 
@@ -225,17 +223,11 @@ class TorchJITConv : public torch::CustomClassHolder {
         fwd_dict(fwd_dict_i.copy()),
         bwd_dict(bwd_dict_i.copy()),
         kernel_dims(kernel_dims_i.copy()),
-        internal(kernel_plaintext,
-            KernelLaunchConfig(
-                fwd_dict.at("num_blocks"),
-                fwd_dict.at("num_threads"),
-                fwd_dict.at("smem")
+        internal(kernel_plaintext,
+            to_map(fwd_dict_i),
+            to_map(bwd_dict_i),
+            to_map(kernel_dims_i)
             ),
-            KernelLaunchConfig(
-                bwd_dict.at("num_blocks"),
-                bwd_dict.at("num_threads"),
-                bwd_dict.at("smem")
-            )),
         L3_dim(kernel_dims.at("L3_dim")) { }
 
     tuple<tuple<string, string>, tuple<string, Map_t>, tuple<string, Map_t>, tuple<string, Map_t>> __obj_flatten__() {
```
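A usage sketch for the new `to_map` helper. `torch::Dict` is the TorchScript-compatible dictionary type, and flattening it to a plain STL map keeps the backend classes free of Torch types. The helper body is repeated so the snippet compiles on its own; the include path and the dictionary values are assumptions for illustration:

```cpp
#include <cstdint>
#include <string>
#include <unordered_map>
#include <torch/custom_class.h>  // assumed include for torch::Dict

// Same body as the helper added in this commit.
std::unordered_map<std::string, int64_t> to_map(
        const torch::Dict<std::string, int64_t> &map) {
    std::unordered_map<std::string, int64_t> result;
    for(auto it = map.begin(); it != map.end(); ++it) {
        result[it->key()] = it->value();
    }
    return result;
}

int main() {
    torch::Dict<std::string, int64_t> fwd_dict;
    fwd_dict.insert("num_blocks", 160);    // hypothetical launch parameters
    fwd_dict.insert("num_threads", 256);
    fwd_dict.insert("smem", 48 * 1024);
    auto plain = to_map(fwd_dict);         // plain.at("num_threads") == 256
    return 0;
}
```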

openequivariance/extension/util/backend_cuda.hpp

Lines changed: 2 additions & 2 deletions
```diff
@@ -178,13 +178,13 @@ class __attribute__((visibility("default"))) CUJITKernel {
             NULL)); // includeNames
     }
 
-    void compile(string kernel_name, const vector<int> template_params) {
+    void compile(string kernel_name, const vector<int> template_params, int opt_level=3) {
         vector<string> kernel_names = {kernel_name};
         vector<vector<int>> template_param_list = {template_params};
         compile(kernel_names, template_param_list);
     }
 
-    void compile(vector<string> kernel_names_i, vector<vector<int>> template_param_list) {
+    void compile(vector<string> kernel_names_i, vector<vector<int>> template_param_list, int opt_level=3) {
         if(compiled) {
             throw std::logic_error("JIT object has already been compiled!");
         }
```

Note: on the CUDA path, `opt_level` is accepted only to keep the two backends' `compile` signatures identical; the single-kernel overload still delegates without forwarding it, and NVRTC compilation is unchanged.

openequivariance/extension/util/backend_hip.hpp

Lines changed: 5 additions & 3 deletions
```diff
@@ -173,13 +173,13 @@ class __attribute__((visibility("default"))) HIPJITKernel {
             NULL)); // includeNames
     }
 
-    void compile(string kernel_name, const vector<int> template_params) {
+    void compile(string kernel_name, const vector<int> template_params, int opt_level=3) {
         vector<string> kernel_names = {kernel_name};
         vector<vector<int>> template_param_list = {template_params};
-        compile(kernel_names, template_param_list);
+        compile(kernel_names, template_param_list, opt_level);
     }
 
-    void compile(vector<string> kernel_names_i, vector<vector<int>> template_param_list) {
+    void compile(vector<string> kernel_names_i, vector<vector<int>> template_param_list, int opt_level=3) {
         if(compiled) {
             throw std::logic_error("JIT object has already been compiled!");
         }
@@ -214,9 +214,11 @@ class __attribute__((visibility("default"))) HIPJITKernel {
         int device = 0;
         HIP_ERRCHK(hipGetDeviceProperties(&props, device));
         std::string sarg = std::string("--gpu-architecture=") + props.gcnArchName;
+        std::string opt_arg = "-O" + std::to_string(opt_level);
 
         std::vector<const char*> opts = {
             "--std=c++17",
+            opt_arg.c_str(),
             sarg.c_str()
         };
```
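One subtlety in the hunk above: only raw `c_str()` pointers are stored in `opts`, so `opt_arg` and `sarg` must outlive the options vector — they do, since all three share a scope that extends through the compile call. A minimal sketch of the whole option-assembly step, with `prog` and `gcn_arch` standing in for members of `HIPJITKernel` and error handling elided:

```cpp
#include <string>
#include <vector>
#include <hip/hiprtc.h>

// Assemble hipRTC options, including the new -O flag, and compile.
hiprtcResult compile_with_opt_level(hiprtcProgram prog,
                                    const std::string &gcn_arch,
                                    int opt_level) {
    std::string sarg = "--gpu-architecture=" + gcn_arch;
    std::string opt_arg = "-O" + std::to_string(opt_level);  // e.g. "-O1"

    // Both strings outlive opts, so storing c_str() pointers is safe here.
    std::vector<const char*> opts = {
        "--std=c++17",
        opt_arg.c_str(),
        sarg.c_str()
    };
    return hiprtcCompileProgram(prog, static_cast<int>(opts.size()), opts.data());
}
```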

openequivariance/implementations/ComputationSchedule.py

Lines changed: 6 additions & 4 deletions
```diff
@@ -5,8 +5,6 @@
 from openequivariance.implementations.TensorProductBase import *
 logger = getLogger()
 
-# This class assumes a warp size of 32
-
 class IrrepMapping:
     '''
     Maps irreps from a source to a destination set.
@@ -265,9 +263,13 @@ def __init__(self,
         # Stream weights on the fly before pre-loading
         self.stream_weights = stream_weights
 
-        # Step 1: Break the irreps and the instructions into chunks of at most 32 x 32 x 32.
+        # Step 1: Break the irreps and the instructions into chunks
+
+        chunk_size = warp_size
+        if include_scratch: # There is at least one UVW computation if this flag is set. Cap the chunk size to 32.
+            chunk_size = 32
 
-        self.problem_splitter = ProblemSplitter(config, warp_size)
+        self.problem_splitter = ProblemSplitter(config, chunk_size)
         self.updated_config = self.problem_splitter.output
         self.L1, self.L2, self.L3 = self.updated_config.irreps_in1, self.updated_config.irreps_in2, self.updated_config.irreps_out
         self.new_instructions = self.updated_config.instructions
```
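The dropped "assumes a warp size of 32" comment and the new `chunk_size` logic go together: chunks normally follow the hardware warp/wavefront width (32 on NVIDIA, typically 64 on AMD), but when `include_scratch` is set — i.e., at least one UVW instruction is present — the chunk size is capped at 32. A one-line C++ rendering of that policy, as a sketch:

```cpp
// Chunk irreps at the warp/wavefront width, except for schedules that need
// scratch space (UVW instructions), which are capped at 32.
int select_chunk_size(int warp_size, bool include_scratch) {
    return include_scratch ? 32 : warp_size;
}
```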

openequivariance/implementations/LoopUnrollTP.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -67,7 +67,7 @@ def generate_double_backward_schedule(warps_per_block):
                 generate_schedule(warp_count)
                 break
             except Exception as e:
-                warp_count //= 2
+                warp_count -= 2
                 if warp_count == 0:
                     raise RuntimeError("Tensor product schedule generation failed, shared memory inadequate!")
 
@@ -76,8 +76,8 @@ def generate_double_backward_schedule(warps_per_block):
             backward_schedule=self.backward_schedule,
             double_backward_schedule=self.double_backward_schedule))
 
-        with open("scratch.txt", "w") as f:
-            f.write(self.jit_kernel)
+        #with open("scratch.txt", "w") as f:
+        #    f.write(self.jit_kernel)
 
         internal_cls = None
         if self.torch_op and extlib.TORCH_COMPILE:
@@ -94,7 +94,8 @@ def generate_double_backward_schedule(warps_per_block):
                 vars(self.backward_schedule.launch_config),
                 vars(self.double_backward_schedule.launch_config),
                 {"L3_dim": self.L3.dim,
-                 "shared_weights": int(self.config.shared_weights)})
+                 "shared_weights": int(self.config.shared_weights),
+                 "is_uvw": int(self.is_uvw)})
             logger.info("Kernel compiled!")
 
             logger.info(f"Kernel File Size: {len(self.jit_kernel) // 1024} KB")
```

openequivariance/implementations/convolution/LoopUnrollConv.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -115,7 +115,8 @@ def generate_backward_schedule(warps_per_block):
         self.internal = internal_cls(self.jit_kernel,
             vars(self.forward_schedule.launch_config),
             vars(self.backward_schedule.launch_config),
-            {"L3_dim": self.L3.dim})
+            {"L3_dim": self.L3.dim,
+             "is_uvw": int(self.is_uvw)})
         logger.info("Kernel compiled!")
 
         self.reorder_weights_e3nn_to_oeq = lambda input, output, has_batch_dim: \
```

openequivariance/templates/loop_unroll_batch.cuh

Lines changed: 1 addition & 1 deletion
```diff
@@ -18,7 +18,7 @@ using IRREP_T = {{ forward_schedule.irrep_dtype_cstr }};
 using WEIGHT_T = {{ forward_schedule.weight_dtype_cstr }};
 
 {%- for i, segment in enumerate(forward_schedule.segments) %}
-{{ generate_segment_kernel_forward(i, segment) }}
+{{ generate_segment_kernel_forward(i, segment, forward_schedule.launch_config.warp_size) }}
 {%- endfor %}
 
 __global__ void forward(
```
