|
6 | 6 | #include <string> |
7 | 7 | #include <iostream> |
8 | 8 | #include <vector> |
| 9 | +#include <algorithm> |
9 | 10 |
|
10 | 11 | using namespace std; |
11 | 12 | using Stream = cudaStream_t; |
@@ -160,13 +161,24 @@ class __attribute__((visibility("default"))) CUJITKernel { |
160 | 161 |
|
161 | 162 | CUlibrary library; |
162 | 163 |
|
| 164 | + vector<int> supported_archs; |
| 165 | + |
163 | 166 | vector<string> kernel_names; |
164 | 167 | vector<CUkernel> kernels; |
165 | 168 |
|
166 | 169 | public: |
167 | 170 | string kernel_plaintext; |
168 | 171 | CUJITKernel(string plaintext) : |
169 | 172 | kernel_plaintext(plaintext) { |
| 173 | + |
| 174 | + int num_supported_archs; |
| 175 | + NVRTC_SAFE_CALL( |
| 176 | + nvrtcGetNumSupportedArchs(&num_supported_archs)); |
| 177 | + |
| 178 | + supported_archs.resize(num_supported_archs); |
| 179 | + NVRTC_SAFE_CALL( |
| 180 | + nvrtcGetSupportedArchs(supported_archs.data())); |
| 181 | + |
170 | 182 |
|
171 | 183 | NVRTC_SAFE_CALL( |
172 | 184 | nvrtcCreateProgram( &prog, // prog |
@@ -196,6 +208,21 @@ class __attribute__((visibility("default"))) CUJITKernel { |
196 | 208 | throw std::logic_error("Kernel names and template parameters must have the same size!"); |
197 | 209 | } |
198 | 210 |
|
| 211 | + int device_arch = cu_major * 10 + cu_minor; |
| 212 | + if (std::find(supported_archs.begin(), supported_archs.end(), device_arch) == supported_archs.end()){ |
| 213 | + int nvrtc_version_major, nvrtc_version_minor; |
| 214 | + NVRTC_SAFE_CALL( |
| 215 | + nvrtcVersion(&nvrtc_version_major, &nvrtc_version_minor)); |
| 216 | + |
| 217 | + throw std::runtime_error("NVRTC version " |
| 218 | + + std::to_string(nvrtc_version_major) |
| 219 | + + "." |
| 220 | + + std::to_string(nvrtc_version_minor) |
| 221 | + + " does not support device architecture " |
| 222 | + + std::to_string(device_arch) |
| 223 | + ); |
| 224 | + } |
| 225 | + |
199 | 226 | for(unsigned int kernel = 0; kernel < kernel_names_i.size(); kernel++) { |
200 | 227 | string kernel_name = kernel_names_i[kernel]; |
201 | 228 | vector<int> &template_params = template_param_list[kernel]; |
|
0 commit comments