Skip to content

Commit f1f453f

Browse files
asglover and Austin Glover authored
check compatibility before compile attempt, better error message (#163)
Co-authored-by: Austin Glover <austin_glover@berkeley.com>
1 parent 08c891c commit f1f453f

1 file changed

Lines changed: 27 additions & 0 deletions

File tree

openequivariance/extension/util/backend_cuda.hpp

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
#include <string>
77
#include <iostream>
88
#include <vector>
9+
#include <algorithm>
910

1011
using namespace std;
1112
using Stream = cudaStream_t;
@@ -160,13 +161,24 @@ class __attribute__((visibility("default"))) CUJITKernel {
160161

161162
CUlibrary library;
162163

164+
vector<int> supported_archs;
165+
163166
vector<string> kernel_names;
164167
vector<CUkernel> kernels;
165168

166169
public:
167170
string kernel_plaintext;
168171
CUJITKernel(string plaintext) :
169172
kernel_plaintext(plaintext) {
173+
174+
int num_supported_archs;
175+
NVRTC_SAFE_CALL(
176+
nvrtcGetNumSupportedArchs(&num_supported_archs));
177+
178+
supported_archs.resize(num_supported_archs);
179+
NVRTC_SAFE_CALL(
180+
nvrtcGetSupportedArchs(supported_archs.data()));
181+
170182

171183
NVRTC_SAFE_CALL(
172184
nvrtcCreateProgram( &prog, // prog
@@ -196,6 +208,21 @@ class __attribute__((visibility("default"))) CUJITKernel {
196208
throw std::logic_error("Kernel names and template parameters must have the same size!");
197209
}
198210

211+
int device_arch = cu_major * 10 + cu_minor;
212+
if (std::find(supported_archs.begin(), supported_archs.end(), device_arch) == supported_archs.end()){
213+
int nvrtc_version_major, nvrtc_version_minor;
214+
NVRTC_SAFE_CALL(
215+
nvrtcVersion(&nvrtc_version_major, &nvrtc_version_minor));
216+
217+
throw std::runtime_error("NVRTC version "
218+
+ std::to_string(nvrtc_version_major)
219+
+ "."
220+
+ std::to_string(nvrtc_version_minor)
221+
+ " does not support device architecture "
222+
+ std::to_string(device_arch)
223+
);
224+
}
225+
199226
for(unsigned int kernel = 0; kernel < kernel_names_i.size(); kernel++) {
200227
string kernel_name = kernel_names_i[kernel];
201228
vector<int> &template_params = template_param_list[kernel];

0 commit comments

Comments
 (0)