66#include < vector>
77#include < string>
88#include < iostream>
9+ #include < memory>
910
1011using namespace std ;
1112
@@ -146,20 +147,54 @@ class __attribute__((visibility("default"))) KernelLaunchConfig {
146147 { }
147148};
148149
150+ class __attribute__ ((visibility(" default" ))) KernelLibrary {
151+ hipModule_t library;
152+ vector<hipFunction_t> kernels;
153+
154+ public:
155+ int device;
156+ KernelLibrary (hiprtcProgram &prog, vector<char > &kernel_binary, vector<string> &kernel_names) {
157+ hipGetDevice (&device);
158+ HIP_ERRCHK (hipModuleLoadData (&library, kernel_binary.data ()));
159+
160+ for (size_t i = 0 ; i < kernel_names.size (); i++) {
161+ const char *name;
162+
163+ HIPRTC_SAFE_CALL (hiprtcGetLoweredName (
164+ prog,
165+ kernel_names[i].c_str (), // name expression
166+ &name // lowered name
167+ ));
168+
169+ kernels.emplace_back ();
170+ HIP_ERRCHK (hipModuleGetFunction (&(kernels[i]), library, name));
171+ }
172+ }
173+
174+ hipFunction_t operator [](int kernel_id) {
175+ if (kernel_id >= kernels.size ())
176+ throw std::logic_error (" Kernel index out of range!" );
177+
178+ return kernels[kernel_id];
179+ }
180+
181+ ~KernelLibrary () {
182+ HIP_ERRCHK (hipModuleUnload (library));
183+ }
184+ };
185+
149186class __attribute__ ((visibility(" default" ))) HIPJITKernel {
150187private:
151188 hiprtcProgram prog;
152-
153189 bool compiled = false ;
154- char * code = nullptr ;
155-
156- hipModule_t library;
157190
158191 vector<string> kernel_names;
159- vector<hipFunction_t > kernels;
192+ unique_ptr<KernelLibrary > kernels;
160193
161194public:
162195 string kernel_plaintext;
196+ vector<char > kernel_binary;
197+
163198 HIPJITKernel (string plaintext) :
164199 kernel_plaintext (plaintext) {
165200
@@ -247,42 +282,24 @@ class __attribute__((visibility("default"))) HIPJITKernel {
247282
248283 size_t codeSize;
249284 HIPRTC_SAFE_CALL (hiprtcGetCodeSize (prog, &codeSize));
250- code = new char [codeSize];
251- HIPRTC_SAFE_CALL (hiprtcGetCode (prog, code));
252-
253- vector<char > kernel_binary (codeSize);
254- hiprtcGetCode (prog, kernel_binary.data ());
255-
256- // HIP_SAFE_CALL(cuInit(0));
257- HIP_ERRCHK (hipModuleLoadData (&library, kernel_binary.data ()));
258-
259- for (size_t i = 0 ; i < kernel_names.size (); i++) {
260- const char *name;
261-
262- HIPRTC_SAFE_CALL (hiprtcGetLoweredName (
263- prog,
264- kernel_names[i].c_str (), // name expression
265- &name // lowered name
266- ));
267-
268- kernels.emplace_back ();
269- HIP_ERRCHK (hipModuleGetFunction (&(kernels[i]), library, name));
270- }
285+ kernel_binary.resize (codeSize);
286+ hiprtcGetCode (prog, kernel_binary.data ());
287+ kernels.reset (new KernelLibrary (prog, kernel_binary, kernel_names));
271288 }
272289
273290 void set_max_smem (int kernel_id, uint32_t max_smem_bytes) {
274- if (kernel_id >= kernels.size ())
275- throw std::logic_error (" Kernel index out of range!" );
276-
277- // Ignore for AMD GPUs
291+ // Ignore for AMD GPUs
278292 }
279293
280294 void execute (int kernel_id, void * args[], KernelLaunchConfig config) {
281- if (kernel_id >= kernels.size ())
282- throw std::logic_error (" Kernel index out of range!" );
295+ int device_id; HIP_ERRCHK (hipGetDevice (&device_id));
296+ if (device_id != kernels->device ) {
297+ kernels.reset ();
298+ kernels.reset (new KernelLibrary (prog, kernel_binary, kernel_names));
299+ }
283300
284301 HIP_ERRCHK (
285- hipModuleLaunchKernel ( (kernels[kernel_id]),
302+ hipModuleLaunchKernel ( ((* kernels) [kernel_id]),
286303 config.num_blocks , 1 , 1 , // grid dim
287304 config.num_threads , 1 , 1 , // block dim
288305 config.smem , config.hStream , // shared mem and stream
@@ -291,10 +308,6 @@ class __attribute__((visibility("default"))) HIPJITKernel {
291308 }
292309
293310 ~HIPJITKernel () {
294- if (compiled) {
295- HIP_ERRCHK (hipModuleUnload (library));
296- delete[] code;
297- }
298311 HIPRTC_SAFE_CALL (hiprtcDestroyProgram (&prog));
299312 }
300313};
0 commit comments