Skip to content

Commit 0c74183

Browse files
Add createContextByGpuIdx to support for multi GPUs
1 parent d4f96b8 commit 0c74183

1 file changed

Lines changed: 117 additions & 0 deletions

File tree

gpu.h

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,12 @@
2424
#include "emscripten/emscripten.h"
2525
#endif
2626

27+
#ifdef USE_DAWN_API
28+
#include "dawn/native/DawnNative.h"
29+
30+
typedef WGPUBufferUsage WGPUBufferUsageFlags;
31+
#endif
32+
2733
namespace gpu {
2834

2935
/**
@@ -785,6 +791,117 @@ inline Context createContext(const WGPUInstanceDescriptor &desc = {},
785791
return context;
786792
}
787793

794+
795+
#ifdef USE_DAWN_API
796+
/**
797+
* @brief Factory function to create a GPU context, which aggregates WebGPU API
798+
* handles to interact with the GPU including the instance, adapter, device, and
799+
* queue.
800+
*
801+
* The function takes gpu index to support for multi GPUs.
802+
* To activate this function, it needs not only webgpu's headers but also DAWN's headers.
803+
*
804+
* If dawn is used, it also sets up an error callback for device loss.
805+
*
806+
* @param[in] gpuIdx GPU index
807+
* @param[in] desc Instance descriptor for the WebGPU instance (optional)
808+
* @param[in] devDescriptor Device descriptor for the WebGPU device (optional)
809+
* @return Context instance representing the created GPU context
810+
*
811+
* @code
812+
* Context ctx = createContextByGpuIdx(1);
813+
* @endcode
814+
*/
815+
inline Context createContextByGpuIdx(int gpuIdx,
816+
const WGPUInstanceDescriptor &desc = {},
817+
const WGPUDeviceDescriptor &devDescriptor = {}) {
818+
Context context;
819+
{
820+
#ifdef __EMSCRIPTEN__
821+
// Emscripten does not support the instance descriptor
822+
// and throws an assertion error if it is not nullptr.
823+
context.instance = wgpuCreateInstance(nullptr);
824+
#else
825+
context.instance = wgpuCreateInstance(&desc);
826+
#endif
827+
// check status
828+
check(context.instance, "Initialize WebGPU", __FILE__, __LINE__);
829+
}
830+
831+
LOG(kDefLog, kInfo, "Requesting adapter");
832+
{
833+
std::vector<dawn::native::Adapter> adapters =
834+
dawn::native::Instance(reinterpret_cast<dawn::native::InstanceBase*>(context.instance))
835+
.EnumerateAdapters();
836+
LOG(kDefLog, kInfo, "The number of GPUs=%d\n", adapters.size());
837+
// Note: Second gpu is not available on Macos, but the number of GPUs is 2 on Macos.
838+
// Calling wgpuAdapterGetInfo function for the second gpu becomes segfault.
839+
// When you check all GPUs on linux, uncomment out following codes.
840+
//
841+
// for (size_t i = 0; i < adapters.size(); i++) {
842+
// WGPUAdapterInfo info {};
843+
// auto ptr = adapters[i].Get();
844+
// if (ptr && adapters[i]) {
845+
// wgpuAdapterGetInfo(ptr, &info);
846+
// LOG(kDefLog, kInfo, "GPU(Adapter)[%d] = %s\n", i, info.description);
847+
// wgpuAdapterInfoFreeMembers(info);
848+
// }
849+
// }
850+
851+
{
852+
LOG(kDefLog, kInfo, "Use GPU(Adapter)[%d]\n", gpuIdx);
853+
auto ptr = adapters[gpuIdx].Get();
854+
if (ptr) {
855+
WGPUAdapterInfo info {};
856+
wgpuAdapterGetInfo(ptr, &info);
857+
LOG(kDefLog, kInfo, "GPU(Adapter)[%d] = %s\n", gpuIdx, info.description);
858+
wgpuAdapterInfoFreeMembers(info);
859+
}
860+
context.adapter = adapters[gpuIdx].Get();
861+
dawn::native::GetProcs().adapterAddRef(context.adapter);
862+
}
863+
}
864+
865+
LOG(kDefLog, kInfo, "Requesting device");
866+
{
867+
struct DeviceData {
868+
WGPUDevice device = nullptr;
869+
bool requestEnded = false;
870+
};
871+
DeviceData devData;
872+
auto onDeviceRequestEnded = [](WGPURequestDeviceStatus status,
873+
WGPUDevice device, char const *message,
874+
void *pUserData) {
875+
DeviceData &devData = *reinterpret_cast<DeviceData *>(pUserData);
876+
check(status == WGPURequestDeviceStatus_Success,
877+
"Could not get WebGPU device.", __FILE__, __LINE__);
878+
LOG(kDefLog, kTrace, "Device Request succeeded %x",
879+
static_cast<void *>(device));
880+
devData.device = device;
881+
devData.requestEnded = true;
882+
};
883+
wgpuAdapterRequestDevice(context.adapter, &devDescriptor,
884+
onDeviceRequestEnded, (void *)&devData);
885+
LOG(kDefLog, kInfo, "Waiting for device request to end");
886+
while (!devData.requestEnded) {
887+
processEvents(context.instance);
888+
}
889+
LOG(kDefLog, kInfo, "Device request ended");
890+
assert(devData.requestEnded);
891+
context.device = devData.device;
892+
wgpuDeviceSetUncapturedErrorCallback(
893+
context.device,
894+
[](WGPUErrorType type, char const *message, void *devData) {
895+
LOG(kDefLog, kError, "Device uncaptured error: %s", message);
896+
throw std::runtime_error("Device uncaptured exception.");
897+
},
898+
nullptr);
899+
}
900+
context.queue = wgpuDeviceGetQueue(context.device);
901+
return context;
902+
}
903+
#endif
904+
788905
inline void wait(Context &ctx, std::future<void> &future) {
789906
while (future.wait_for(std::chrono::seconds(0)) !=
790907
std::future_status::ready) {

0 commit comments

Comments
 (0)