Skip to content

Commit 9935d8d

Browse files
Multidevice Support (#102)
* Began adding multidevice tests.
* Multidevice test is succeeding.
* Avoid setting max shared memory for all cards; only do this for cards that match our cc.
* HIP multidevice test passing.
1 parent b0d4445 commit 9935d8d

5 files changed

Lines changed: 134 additions & 53 deletions

File tree

README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,18 @@ We do not (yet) support:
290290

291291
If you have a use case for any of the unsupported features above, let us know.
292292

## Multidevice / Stream Support

To use OpenEquivariance on multiple GPUs of a single
compute node, we currently require that all GPUs
share the same compute capability. This is because
our kernels are compiled for the compute capability
and shared memory capacity of the numerically first
visible GPU card. On heterogeneous systems, you can
still use OpenEquivariance on all GPUs that match the
compute capability of the first visible device.

We are working on support for CUDA streams!
293305
## Citation and Acknowledgements
294306
If you find this code useful, please cite our paper:
295307

openequivariance/extension/util/backend_cuda.hpp

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,6 @@ class __attribute__((visibility("default"))) KernelLaunchConfig {
137137
smem(smem)
138138
{ }
139139

140-
141140
KernelLaunchConfig(int64_t num_blocks_i, int64_t num_threads_i, int64_t smem_i) :
142141
KernelLaunchConfig( static_cast<uint32_t>(num_blocks_i),
143142
static_cast<uint32_t>(num_threads_i),
@@ -156,8 +155,8 @@ class __attribute__((visibility("default"))) CUJITKernel {
156155

157156
bool compiled = false;
158157
char* code = nullptr;
158+
int cu_major, cu_minor;
159159

160-
CUdevice dev;
161160
CUlibrary library;
162161

163162
vector<string> kernel_names;
@@ -185,6 +184,10 @@ class __attribute__((visibility("default"))) CUJITKernel {
185184
}
186185

187186
void compile(vector<string> kernel_names_i, vector<vector<int>> template_param_list, int opt_level=3) {
187+
DeviceProp dp(0); // We only query the first device on the system at the moment
188+
cu_major = dp.major;
189+
cu_minor = dp.minor;
190+
188191
if(compiled) {
189192
throw std::logic_error("JIT object has already been compiled!");
190193
}
@@ -215,8 +218,7 @@ class __attribute__((visibility("default"))) CUJITKernel {
215218

216219
}
217220

218-
DeviceProp dp(0); // TODO: We only query the first device at the moment
219-
std::string sm = "-arch=sm_" + std::to_string(dp.major) + std::to_string(dp.minor);
221+
std::string sm = "-arch=sm_" + std::to_string(cu_major) + std::to_string(cu_minor);
220222

221223
std::vector<const char*> opts = {
222224
"--std=c++17",
@@ -268,19 +270,29 @@ class __attribute__((visibility("default"))) CUJITKernel {
268270
kernels.emplace_back();
269271
CUDA_SAFE_CALL(cuLibraryGetKernel(&(kernels[i]), library, name));
270272
}
271-
272-
CUDA_SAFE_CALL(cuDeviceGet(&dev, 0));
273273
}
274274

275275
void set_max_smem(int kernel_id, uint32_t max_smem_bytes) {
276+
if(!compiled)
277+
throw std::logic_error("JIT object has not been compiled!");
276278
if(kernel_id >= kernels.size())
277279
throw std::logic_error("Kernel index out of range!");
278280

279-
CUDA_SAFE_CALL(cuKernelSetAttribute(
280-
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
281-
max_smem_bytes,
282-
kernels[kernel_id],
283-
dev));
281+
int device_count;
282+
CUDA_SAFE_CALL(cuDeviceGetCount(&device_count));
283+
284+
for(int i = 0; i < device_count; i++) {
285+
DeviceProp dp(i);
286+
if(dp.major == cu_major && dp.minor == cu_minor) {
287+
CUdevice dev;
288+
CUDA_SAFE_CALL(cuDeviceGet(&dev, i));
289+
CUDA_SAFE_CALL(cuKernelSetAttribute(
290+
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
291+
max_smem_bytes,
292+
kernels[kernel_id],
293+
dev));
294+
}
295+
}
284296
}
285297

286298
void execute(int kernel_id, void* args[], KernelLaunchConfig config) {
@@ -291,6 +303,10 @@ class __attribute__((visibility("default"))) CUJITKernel {
291303
CUDA_SAFE_CALL(cuCtxGetCurrent(&pctx));
292304

293305
if(pctx == NULL) {
306+
int device_id;
307+
CUdevice dev;
308+
CUDA_ERRCHK(cudaGetDevice(&device_id));
309+
CUDA_SAFE_CALL(cuDeviceGet(&dev, device_id));
294310
CUDA_SAFE_CALL(cuDevicePrimaryCtxRetain(&pctx, dev));
295311
CUDA_SAFE_CALL(cuCtxSetCurrent(pctx));
296312
}

openequivariance/extension/util/backend_hip.hpp

Lines changed: 50 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <vector>
77
#include <string>
88
#include <iostream>
9+
#include <memory>
910

1011
using namespace std;
1112

@@ -146,20 +147,54 @@ class __attribute__((visibility("default"))) KernelLaunchConfig {
146147
{ }
147148
};
148149

150+
class __attribute__((visibility("default"))) KernelLibrary {
151+
hipModule_t library;
152+
vector<hipFunction_t> kernels;
153+
154+
public:
155+
int device;
156+
KernelLibrary(hiprtcProgram &prog, vector<char> &kernel_binary, vector<string> &kernel_names) {
157+
hipGetDevice(&device);
158+
HIP_ERRCHK(hipModuleLoadData(&library, kernel_binary.data()));
159+
160+
for (size_t i = 0; i < kernel_names.size(); i++) {
161+
const char *name;
162+
163+
HIPRTC_SAFE_CALL(hiprtcGetLoweredName(
164+
prog,
165+
kernel_names[i].c_str(), // name expression
166+
&name // lowered name
167+
));
168+
169+
kernels.emplace_back();
170+
HIP_ERRCHK(hipModuleGetFunction(&(kernels[i]), library, name));
171+
}
172+
}
173+
174+
hipFunction_t operator[](int kernel_id) {
175+
if(kernel_id >= kernels.size())
176+
throw std::logic_error("Kernel index out of range!");
177+
178+
return kernels[kernel_id];
179+
}
180+
181+
~KernelLibrary() {
182+
HIP_ERRCHK(hipModuleUnload(library));
183+
}
184+
};
185+
149186
class __attribute__((visibility("default"))) HIPJITKernel {
150187
private:
151188
hiprtcProgram prog;
152-
153189
bool compiled = false;
154-
char* code = nullptr;
155-
156-
hipModule_t library;
157190

158191
vector<string> kernel_names;
159-
vector<hipFunction_t> kernels;
192+
unique_ptr<KernelLibrary> kernels;
160193

161194
public:
162195
string kernel_plaintext;
196+
vector<char> kernel_binary;
197+
163198
HIPJITKernel(string plaintext) :
164199
kernel_plaintext(plaintext) {
165200

@@ -247,42 +282,24 @@ class __attribute__((visibility("default"))) HIPJITKernel {
247282

248283
size_t codeSize;
249284
HIPRTC_SAFE_CALL(hiprtcGetCodeSize(prog, &codeSize));
250-
code = new char[codeSize];
251-
HIPRTC_SAFE_CALL(hiprtcGetCode(prog, code));
252-
253-
vector<char> kernel_binary(codeSize);
254-
hiprtcGetCode(prog, kernel_binary.data());
255-
256-
//HIP_SAFE_CALL(cuInit(0));
257-
HIP_ERRCHK(hipModuleLoadData(&library, kernel_binary.data()));
258-
259-
for (size_t i = 0; i < kernel_names.size(); i++) {
260-
const char *name;
261-
262-
HIPRTC_SAFE_CALL(hiprtcGetLoweredName(
263-
prog,
264-
kernel_names[i].c_str(), // name expression
265-
&name // lowered name
266-
));
267-
268-
kernels.emplace_back();
269-
HIP_ERRCHK(hipModuleGetFunction(&(kernels[i]), library, name));
270-
}
285+
kernel_binary.resize(codeSize);
286+
hiprtcGetCode(prog, kernel_binary.data());
287+
kernels.reset(new KernelLibrary(prog, kernel_binary, kernel_names));
271288
}
272289

273290
void set_max_smem(int kernel_id, uint32_t max_smem_bytes) {
274-
if(kernel_id >= kernels.size())
275-
throw std::logic_error("Kernel index out of range!");
276-
277-
// Ignore for AMD GPUs
291+
// Ignore for AMD GPUs
278292
}
279293

280294
void execute(int kernel_id, void* args[], KernelLaunchConfig config) {
281-
if(kernel_id >= kernels.size())
282-
throw std::logic_error("Kernel index out of range!");
295+
int device_id; HIP_ERRCHK(hipGetDevice(&device_id));
296+
if(device_id != kernels->device) {
297+
kernels.reset();
298+
kernels.reset(new KernelLibrary(prog, kernel_binary, kernel_names));
299+
}
283300

284301
HIP_ERRCHK(
285-
hipModuleLaunchKernel( (kernels[kernel_id]),
302+
hipModuleLaunchKernel( ((*kernels)[kernel_id]),
286303
config.num_blocks, 1, 1, // grid dim
287304
config.num_threads, 1, 1, // block dim
288305
config.smem, config.hStream, // shared mem and stream
@@ -291,10 +308,6 @@ class __attribute__((visibility("default"))) HIPJITKernel {
291308
}
292309

293310
~HIPJITKernel() {
294-
if(compiled) {
295-
HIP_ERRCHK(hipModuleUnload(library));
296-
delete[] code;
297-
}
298311
HIPRTC_SAFE_CALL(hiprtcDestroyProgram(&prog));
299312
}
300313
};

tests/export_test.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
import numpy as np
55
import openequivariance as oeq
66
from torch_geometric import EdgeIndex
7-
from openequivariance.implementations.TensorProduct import TensorProduct
8-
from openequivariance.benchmark.correctness_utils import correctness_forward, correctness_backward, correctness_double_backward
97

108
@pytest.fixture(scope='session')
119
def problem_and_irreps():
@@ -15,9 +13,6 @@ def problem_and_irreps():
1513
shared_weights=False, internal_weights=False,
1614
irrep_dtype=np.float32, weight_dtype=np.float32)
1715

18-
gen = torch.Generator(device='cuda')
19-
gen.manual_seed(0)
20-
2116
return problem, X_ir, Y_ir, Z_ir,
2217

2318
@pytest.fixture(params=['batch', 'conv_det', 'conv_atomic'], scope='session')

tests/multidevice_test.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import textwrap, torch, subprocess, os
2+
import numpy as np
3+
4+
def test_multidevice():
5+
result = subprocess.run([
6+
"python", "-m", "torch.distributed.run",
7+
"--standalone", "--nnodes=1", "--nproc-per-node=gpu",
8+
__file__],
9+
capture_output=True,
10+
check=False)
11+
12+
if result.returncode != 0:
13+
error_string = f'''
14+
Invocation: {' '.join(result.args)}
15+
Test failed with return code {result.returncode}.
16+
\nOutput:\n\n{result.stdout.decode()}
17+
\nError:\n\n{result.stderr.decode()}
18+
'''
19+
assert False, textwrap.dedent(error_string)
20+
21+
assert True
22+
23+
if __name__ == "__main__":
24+
import openequivariance as oeq
25+
26+
# Use MACE-large to test >64KB shared memory allocation
27+
from openequivariance.benchmark.benchmark_configs import mace_problems
28+
problem = mace_problems[0]
29+
30+
local_rank = int(os.environ["LOCAL_RANK"])
31+
device = f'cuda:{local_rank}'
32+
torch.set_default_device(device)
33+
34+
X_ir, Y_ir, Z_ir = problem.irreps_in1, problem.irreps_in2, problem.irreps_out
35+
tp = oeq.TensorProduct(problem)
36+
37+
batch_size = 1000
38+
gen = torch.Generator(device=device)
39+
gen.manual_seed(0)
40+
X = torch.rand(batch_size, X_ir.dim, device=device, generator=gen)
41+
Y = torch.rand(batch_size, Y_ir.dim, device=device, generator=gen)
42+
W = torch.rand(batch_size, problem.weight_numel, device=device, generator=gen)
43+
44+
with torch.cuda.device(device):
45+
result = tp.forward(X, Y, W)

0 commit comments

Comments
 (0)