Skip to content

Commit 218ee99

Browse files
authored
Merge pull request #50 from junjihashimoto/feature/multi-gpu
Add createContextByGpuIdx to support multiple GPUs
2 parents 16a7b4b + 0c74183 commit 218ee99

1 file changed

Lines changed: 117 additions & 0 deletions

File tree

gpu.h

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@
2424
#include "emscripten/emscripten.h"
2525
#endif
2626

27+
#ifdef USE_DAWN_API
28+
#include "dawn/native/DawnNative.h"
29+
30+
// Compatibility alias: lets code that spells the buffer-usage bitmask type as
// WGPUBufferUsageFlags compile when building against Dawn's headers.
// NOTE(review): presumably Dawn's webgpu.h only provides WGPUBufferUsage — confirm.
using WGPUBufferUsageFlags = WGPUBufferUsage;
31+
#endif
32+
2733
namespace gpu {
2834

2935
/**
@@ -846,6 +852,117 @@ inline Context createContext(const WGPUInstanceDescriptor &desc = {},
846852
return context;
847853
}
848854

855+
856+
#ifdef USE_DAWN_API
857+
/**
858+
* @brief Factory function to create a GPU context, which aggregates WebGPU API
859+
* handles to interact with the GPU including the instance, adapter, device, and
860+
* queue.
861+
*
862+
* The function takes gpu index to support for multi GPUs.
863+
* To activate this function, it needs not only webgpu's headers but also DAWN's headers.
864+
*
865+
* If dawn is used, it also sets up an error callback for device loss.
866+
*
867+
* @param[in] gpuIdx GPU index
868+
* @param[in] desc Instance descriptor for the WebGPU instance (optional)
869+
* @param[in] devDescriptor Device descriptor for the WebGPU device (optional)
870+
* @return Context instance representing the created GPU context
871+
*
872+
* @code
873+
* Context ctx = createContextByGpuIdx(1);
874+
* @endcode
875+
*/
876+
inline Context createContextByGpuIdx(int gpuIdx,
877+
const WGPUInstanceDescriptor &desc = {},
878+
const WGPUDeviceDescriptor &devDescriptor = {}) {
879+
Context context;
880+
{
881+
#ifdef __EMSCRIPTEN__
882+
// Emscripten does not support the instance descriptor
883+
// and throws an assertion error if it is not nullptr.
884+
context.instance = wgpuCreateInstance(nullptr);
885+
#else
886+
context.instance = wgpuCreateInstance(&desc);
887+
#endif
888+
// check status
889+
check(context.instance, "Initialize WebGPU", __FILE__, __LINE__);
890+
}
891+
892+
LOG(kDefLog, kInfo, "Requesting adapter");
893+
{
894+
std::vector<dawn::native::Adapter> adapters =
895+
dawn::native::Instance(reinterpret_cast<dawn::native::InstanceBase*>(context.instance))
896+
.EnumerateAdapters();
897+
LOG(kDefLog, kInfo, "The number of GPUs=%d\n", adapters.size());
898+
// Note: Second gpu is not available on Macos, but the number of GPUs is 2 on Macos.
899+
// Calling wgpuAdapterGetInfo function for the second gpu becomes segfault.
900+
// When you check all GPUs on linux, uncomment out following codes.
901+
//
902+
// for (size_t i = 0; i < adapters.size(); i++) {
903+
// WGPUAdapterInfo info {};
904+
// auto ptr = adapters[i].Get();
905+
// if (ptr && adapters[i]) {
906+
// wgpuAdapterGetInfo(ptr, &info);
907+
// LOG(kDefLog, kInfo, "GPU(Adapter)[%d] = %s\n", i, info.description);
908+
// wgpuAdapterInfoFreeMembers(info);
909+
// }
910+
// }
911+
912+
{
913+
LOG(kDefLog, kInfo, "Use GPU(Adapter)[%d]\n", gpuIdx);
914+
auto ptr = adapters[gpuIdx].Get();
915+
if (ptr) {
916+
WGPUAdapterInfo info {};
917+
wgpuAdapterGetInfo(ptr, &info);
918+
LOG(kDefLog, kInfo, "GPU(Adapter)[%d] = %s\n", gpuIdx, info.description);
919+
wgpuAdapterInfoFreeMembers(info);
920+
}
921+
context.adapter = adapters[gpuIdx].Get();
922+
dawn::native::GetProcs().adapterAddRef(context.adapter);
923+
}
924+
}
925+
926+
LOG(kDefLog, kInfo, "Requesting device");
927+
{
928+
struct DeviceData {
929+
WGPUDevice device = nullptr;
930+
bool requestEnded = false;
931+
};
932+
DeviceData devData;
933+
auto onDeviceRequestEnded = [](WGPURequestDeviceStatus status,
934+
WGPUDevice device, char const *message,
935+
void *pUserData) {
936+
DeviceData &devData = *reinterpret_cast<DeviceData *>(pUserData);
937+
check(status == WGPURequestDeviceStatus_Success,
938+
"Could not get WebGPU device.", __FILE__, __LINE__);
939+
LOG(kDefLog, kTrace, "Device Request succeeded %x",
940+
static_cast<void *>(device));
941+
devData.device = device;
942+
devData.requestEnded = true;
943+
};
944+
wgpuAdapterRequestDevice(context.adapter, &devDescriptor,
945+
onDeviceRequestEnded, (void *)&devData);
946+
LOG(kDefLog, kInfo, "Waiting for device request to end");
947+
while (!devData.requestEnded) {
948+
processEvents(context.instance);
949+
}
950+
LOG(kDefLog, kInfo, "Device request ended");
951+
assert(devData.requestEnded);
952+
context.device = devData.device;
953+
wgpuDeviceSetUncapturedErrorCallback(
954+
context.device,
955+
[](WGPUErrorType type, char const *message, void *devData) {
956+
LOG(kDefLog, kError, "Device uncaptured error: %s", message);
957+
throw std::runtime_error("Device uncaptured exception.");
958+
},
959+
nullptr);
960+
}
961+
context.queue = wgpuDeviceGetQueue(context.device);
962+
return context;
963+
}
964+
#endif
965+
849966
inline void wait(Context &ctx, std::future<void> &future) {
850967
while (future.wait_for(std::chrono::seconds(0)) !=
851968
std::future_status::ready) {

0 commit comments

Comments
 (0)