Skip to content

Commit 95e587d

Browse files
committed
refactors the byIdx context function and sets USE_DAWN_API compile def on native
1 parent 9a08f8a commit 95e587d

4 files changed

Lines changed: 201 additions & 117 deletions

File tree

cmake/dawn.cmake

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ if(EMSCRIPTEN)
77
set(EM_SDK_DIR $ENV{EMSDK} CACHE INTERNAL "")
88
set(DAWN_BUILD_DIR "${DAWN_DIR}/build_web" CACHE INTERNAL "")
99
set(DAWN_EMSCRIPTEN_TOOLCHAIN ${EM_SDK_DIR}/upstream/emscripten CACHE INTERNAL "" FORCE)
10+
else()
11+
add_compile_definitions(USE_DAWN_API)
1012
endif()
1113

1214
# Enable find for no dawn rebuilds with flutter run

cmake/gpu.cmake

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@ add_library(gpu STATIC ${GPU_SOURCES} ${GPU_HEADERS})
3232
set_target_properties(gpu PROPERTIES LINKER_LANGUAGE CXX)
3333
target_include_directories(gpu PUBLIC "${PROJECT_ROOT}")
3434
if(NOT EMSCRIPTEN)
35+
target_include_directories(gpu PUBLIC "${DAWN_BUILD_DIR}/gen/include/")
3536
target_include_directories(gpu PUBLIC "${DAWN_BUILD_DIR}/gen/include/dawn/")
37+
target_include_directories(gpu PUBLIC "${DAWN_DIR}/include/")
3638
else()
3739
target_include_directories(gpu PUBLIC "${DAWN_BUILD_DIR}/gen/src/emdawnwebgpu/include/")
3840
target_include_directories(gpu PUBLIC "${DAWN_BUILD_DIR}/gen/src/emdawnwebgpu/include/webgpu/")

examples/hello_world/run.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,13 @@ int main(int argc, char **argv) {
2828
printf("--------------\n\n");
2929

3030
// std::unique_ptr<Context> ctx = createContext();
31+
#ifdef USE_DAWN_API
32+
Context ctx = createContextByGpuIdx(0);
33+
auto adaptersList = listAdapters(ctx);
34+
LOG(kDefLog, kInfo, "Available GPU adapters:\n%s", adaptersList.c_str());
35+
#else
3136
Context ctx = createContext();
37+
#endif
3238
static constexpr size_t N = 10000;
3339
std::array<float, N> inputArr, outputArr;
3440
for (int i = 0; i < N; ++i) {

gpu.hpp

Lines changed: 191 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,8 @@
1616
#include <utility> // std::pair
1717
#include <vector>
1818

19-
#ifndef __EMSCRIPTEN__
2019

21-
#else
20+
#ifdef __EMSCRIPTEN__
2221
#include "emscripten/emscripten.h"
2322
#endif
2423

@@ -255,6 +254,26 @@ inline std::string toString(const Shape &shape) {
255254
*/
256255
inline std::string toString(size_t value) { return std::to_string(value); }
257256

257+
/**
258+
* @brief Converts a WGPUStringView to an std::string.
259+
*
260+
* If the view's data is null, an empty string is returned. If the view's
261+
* length equals WGPU_STRLEN, it is assumed to be null‑terminated; otherwise,
262+
* the explicit length is used.
263+
*
264+
* @param strView The WGPUStringView to convert.
265+
* @return std::string The resulting standard string.
266+
*/
267+
inline std::string formatWGPUStringView(WGPUStringView strView) {
268+
if (!strView.data) {
269+
return "";
270+
}
271+
if (strView.length == WGPU_STRLEN) {
272+
return std::string(strView.data);
273+
}
274+
return std::string(strView.data, strView.length);
275+
}
276+
258277
/**
259278
* @brief simple in-place string replacement helper function for substituting
260279
* placeholders in a WGSL string template.
@@ -1076,136 +1095,191 @@ inline Context createContext(const WGPUInstanceDescriptor &desc = {},
10761095
return waitForContextFuture<Context>(contextFuture);
10771096
}
10781097

1079-
#ifdef USE_DAWN_API
1098+
#ifndef __EMSCRIPTEN__
1099+
#if USE_DAWN_API
10801100
/**
1081-
* @brief Factory function to create a GPU context, which aggregates WebGPU API
1082-
* handles to interact with the GPU including the instance, adapter, device, and
1083-
* queue.
1101+
* @brief Retrieves the list of available GPU adapters from the Dawn instance.
10841102
*
1085-
* The function takes gpu index to support for multi GPUs.
1086-
* To activate this function, it needs not only webgpu's headers but also DAWN's
1087-
* headers.
1103+
* This function creates a Dawn instance using the provided context's instance
1104+
* handle, then enumerates and returns the available GPU adapters as a vector.
10881105
*
1089-
* If dawn is used, it also sets up an error callback for device loss.
1106+
* @param ctx The Context containing the WebGPU instance handle.
1107+
* @return std::vector<dawn::native::Adapter> A vector of available GPU
1108+
* adapters.
1109+
*
1110+
* @code
1111+
* std::vector<dawn::native::Adapter> adapters = getAdapters(ctx);
1112+
* @endcode
1113+
*/
1114+
inline std::vector<dawn::native::Adapter> getAdapters(Context &ctx) {
1115+
dawn::native::Instance dawnInstance(
1116+
reinterpret_cast<dawn::native::InstanceBase *>(ctx.instance));
1117+
return dawnInstance.EnumerateAdapters();
1118+
}
1119+
1120+
/**
1121+
* @brief Formats the given vector of Dawn adapters into a single concatenated string.
10901122
*
1091-
* @param[in] gpuIdx GPU index
1092-
* @param[in] desc Instance descriptor for the WebGPU instance (optional)
1093-
* @param[in] devDescriptor Device descriptor for the WebGPU device (optional)
1094-
* @return Context instance representing the created GPU context
1123+
* This function iterates over each Dawn adapter in the provided vector, retrieves its
1124+
* description using the WebGPU API, and converts the description from a WGPUStringView
1125+
* to an std::string using the formatWGPUStringView helper. The resulting descriptions
1126+
* are concatenated into a single string separated by newline characters.
10951127
*
1128+
* @param adapters A vector of Dawn adapters obtained from a WebGPU instance.
1129+
* @return std::string A newline-delimited string listing each adapter's description.
1130+
*
10961131
* @code
1097-
* Context ctx = createContextByGpuIdx(1);
1132+
* std::string adapterList = formatAdapters(adapters);
10981133
* @endcode
10991134
*/
1100-
inline Context
1101-
createContextByGpuIdx(int gpuIdx, const WGPUInstanceDescriptor &desc = {},
1102-
const WGPUDeviceDescriptor &devDescriptor = {}) {
1103-
Context context;
1104-
{
1105-
#ifdef __EMSCRIPTEN__
1106-
// Emscripten does not support the instance descriptor
1107-
// and throws an assertion error if it is not nullptr.
1108-
context.instance = wgpuCreateInstance(nullptr);
1109-
#else
1110-
context.instance = wgpuCreateInstance(&desc);
1111-
#endif
1112-
// check status
1113-
check(context.instance, "Initialize WebGPU", __FILE__, __LINE__);
1135+
inline std::string formatAdapters(const std::vector<dawn::native::Adapter> &adapters) {
1136+
std::string adapterList;
1137+
for (size_t i = 0; i < adapters.size(); ++i) {
1138+
auto adapterPtr = adapters[i].Get();
1139+
if (adapterPtr) {
1140+
WGPUAdapterInfo info = {};
1141+
wgpuAdapterGetInfo(adapterPtr, &info);
1142+
std::string desc = formatWGPUStringView(info.description);
1143+
adapterList += "GPU Adapter [" + std::to_string(i) + "]: " + desc + "\n";
1144+
wgpuAdapterInfoFreeMembers(info);
1145+
}
11141146
}
1147+
return adapterList;
1148+
}
11151149

1116-
LOG(kDefLog, kInfo, "Requesting adapter");
1117-
{
1118-
std::vector<dawn::native::Adapter> adapters =
1119-
dawn::native::Instance(
1120-
reinterpret_cast<dawn::native::InstanceBase *>(context.instance))
1121-
.EnumerateAdapters();
1122-
LOG(kDefLog, kInfo, "The number of GPUs=%d\n", adapters.size());
1123-
// Note: Second gpu is not available on Macos, but the number of GPUs is 2
1124-
// on Macos.
1125-
// Calling wgpuAdapterGetInfo function for the second gpu becomes
1126-
// segfault. When you check all GPUs on linux, uncomment out following
1127-
// codes.
1128-
//
1129-
// for (size_t i = 0; i < adapters.size(); i++) {
1130-
// WGPUAdapterInfo info {};
1131-
// auto ptr = adapters[i].Get();
1132-
// if (ptr && adapters[i]) {
1133-
// wgpuAdapterGetInfo(ptr, &info);
1134-
// LOG(kDefLog, kInfo, "GPU(Adapter)[%d] = %s\n", i, info.description);
1135-
// wgpuAdapterInfoFreeMembers(info);
1136-
// }
1137-
// }
1138-
1139-
{
1140-
LOG(kDefLog, kInfo, "Use GPU(Adapter)[%d]\n", gpuIdx);
1141-
auto ptr = adapters[gpuIdx].Get();
1142-
if (ptr) {
1143-
WGPUAdapterInfo info{};
1144-
wgpuAdapterGetInfo(ptr, &info);
1145-
LOG(kDefLog, kInfo, "GPU(Adapter)[%d] = %s\n", gpuIdx,
1146-
info.description);
1147-
wgpuAdapterInfoFreeMembers(info);
1148-
}
1149-
context.adapter = adapters[gpuIdx].Get();
1150-
dawn::native::GetProcs().adapterAddRef(context.adapter);
1151-
}
1150+
/**
1151+
* @brief Lists the available GPU adapters in the current WebGPU instance.
1152+
*
1153+
* This function retrieves the list of available GPU adapters using the
1154+
* getAdapters helper function, then formats and returns the adapter
1155+
* descriptions as a single string using the formatAdapters helper function.
1156+
*
1157+
* @param ctx The Context containing the WebGPU instance handle.
1158+
* @return std::string A newline-delimited string listing each adapter's
1159+
* description.
1160+
*
1161+
* @code
1162+
* std::string adapterList = listAdapters(ctx);
1163+
* @endcode
1164+
*/
1165+
inline std::string listAdapters(Context &ctx) {
1166+
auto adapters = getAdapters(ctx);
1167+
return formatAdapters(adapters);
1168+
}
1169+
1170+
/**
1171+
* @brief Asynchronously creates a GPU context using the specified GPU index.
1172+
*
1173+
* This function creates a WebGPU instance, retrieves the available GPU
1174+
* adapters, and selects the adapter at the specified index. It then requests a
1175+
* device from the selected adapter and sets up a logging callback for device
1176+
* errors. The function returns a future that will be fulfilled with the
1177+
* created Context once all operations are complete.
1178+
*
1179+
* @param gpuIdx The index of the GPU adapter to use.
1180+
* @param desc Instance descriptor for the WebGPU instance (optional)
1181+
* @param devDescriptor Device descriptor for the WebGPU device (optional)
1182+
* @return std::future<Context> A future that will eventually hold the created
1183+
* Context.
1184+
*
1185+
* @code
1186+
* std::future<Context> contextFuture = createContextByGpuIdxAsync(0);
1187+
* Context ctx = waitForContextFuture(contextFuture);
1188+
* @endcode
1189+
*/
1190+
inline std::future<Context>
1191+
createContextByGpuIdxAsync(int gpuIdx, const WGPUInstanceDescriptor &desc = {},
1192+
const WGPUDeviceDescriptor &devDescriptor = {}) {
1193+
auto promise = std::make_shared<std::promise<Context>>();
1194+
Context ctx;
1195+
1196+
ctx.instance = wgpuCreateInstance(&desc);
1197+
1198+
if (!ctx.instance) {
1199+
promise->set_exception(std::make_exception_ptr(
1200+
std::runtime_error("Failed to create WebGPU instance.")));
1201+
return promise->get_future();
11521202
}
1203+
check(ctx.instance, "Initialize WebGPU", __FILE__, __LINE__);
11531204

1154-
LOG(kDefLog, kInfo, "Requesting device");
1155-
{
1156-
struct DeviceData {
1157-
WGPUDevice device = nullptr;
1158-
bool requestEnded = false;
1159-
};
1160-
DeviceData devData;
1161-
1162-
auto onDeviceRequestEnded = [](WGPURequestDeviceStatus status,
1163-
WGPUDevice device, WGPUStringView message,
1164-
void *pUserData, void *) {
1165-
DeviceData &devData = *reinterpret_cast<DeviceData *>(pUserData);
1166-
check(status == WGPURequestDeviceStatus_Success,
1167-
"Could not get WebGPU device.", __FILE__, __LINE__);
1168-
LOG(kDefLog, kTrace, "Device Request succeeded %x",
1169-
static_cast<void *>(device));
1170-
devData.device = device;
1171-
devData.requestEnded = true;
1172-
};
1205+
// Use helper functions to obtain and format the adapters.
1206+
auto adapters = getAdapters(ctx);
11731207

1174-
WGPURequestDeviceCallbackInfo deviceCallbackInfo = {
1175-
.mode = WGPUCallbackMode_AllowSpontaneous,
1176-
.callback = onDeviceRequestEnded,
1177-
.userdata1 = &devData,
1178-
.userdata2 = nullptr};
1179-
wgpuAdapterRequestDevice(context.adapter, &devDescriptor,
1180-
deviceCallbackInfo);
1181-
1182-
LOG(kDefLog, kInfo, "Waiting for device request to end");
1183-
while (!devData.requestEnded) {
1184-
processEvents(context.instance);
1185-
}
1186-
LOG(kDefLog, kInfo, "Device request ended");
1187-
assert(devData.requestEnded);
1188-
context.device = devData.device;
1189-
1190-
WGPULoggingCallbackInfo loggingCallbackInfo = {
1191-
.nextInChain = nullptr,
1192-
.callback =
1193-
[](WGPULoggingType type, WGPUStringView message, void *userdata1,
1194-
void *userdata2) {
1195-
LOG(kDefLog, kError, "Device logging callback: %.*s",
1196-
static_cast<int>(message.length), message.data);
1197-
if (type == WGPULoggingType_Error) {
1198-
throw std::runtime_error("Device error logged.");
1199-
}
1200-
},
1201-
.userdata1 = nullptr,
1202-
.userdata2 = nullptr};
1203-
wgpuDeviceSetLoggingCallback(context.device, loggingCallbackInfo);
1208+
if (gpuIdx >= adapters.size()) {
1209+
promise->set_exception(
1210+
std::make_exception_ptr(std::runtime_error("Invalid GPU index.")));
1211+
return promise->get_future();
1212+
}
1213+
LOG(kDefLog, kInfo, "Using GPU Adapter[%d]", gpuIdx);
1214+
auto adapterPtr = adapters[gpuIdx].Get();
1215+
if (adapterPtr) {
1216+
WGPUAdapterInfo info = {};
1217+
wgpuAdapterGetInfo(adapterPtr, &info);
1218+
LOG(kDefLog, kInfo, "GPU(Adapter)[%d] = %s", gpuIdx,
1219+
formatWGPUStringView(info.description).c_str());
1220+
wgpuAdapterInfoFreeMembers(info);
1221+
}
1222+
ctx.adapter = reinterpret_cast<WGPUAdapter>(adapterPtr);
1223+
dawn::native::GetProcs().adapterAddRef(ctx.adapter);
1224+
1225+
LOG(kDefLog, kInfo, "Requesting device");
1226+
// Request the device asynchronously (using our requestDeviceAsync helper).
1227+
auto deviceFuture = requestDeviceAsync(ctx.adapter, devDescriptor);
1228+
try {
1229+
ctx.device = wait(ctx, deviceFuture);
1230+
ctx.deviceStatus = WGPURequestDeviceStatus_Success;
1231+
} catch (const std::exception &ex) {
1232+
promise->set_exception(std::make_exception_ptr(ex));
1233+
return promise->get_future();
12041234
}
1205-
context.queue = wgpuDeviceGetQueue(context.device);
1206-
return context;
1235+
1236+
WGPULoggingCallbackInfo loggingCallbackInfo{
1237+
.nextInChain = nullptr,
1238+
.callback =
1239+
[](WGPULoggingType type, WGPUStringView message, void *userdata1,
1240+
void *userdata2) {
1241+
LOG(kDefLog, kError, "Device logging callback: %.*s",
1242+
static_cast<int>(message.length), message.data);
1243+
if (type == WGPULoggingType_Error) {
1244+
throw std::runtime_error("Device error logged.");
1245+
}
1246+
},
1247+
.userdata1 = nullptr,
1248+
.userdata2 = nullptr};
1249+
wgpuDeviceSetLoggingCallback(ctx.device, loggingCallbackInfo);
1250+
ctx.queue = wgpuDeviceGetQueue(ctx.device);
1251+
promise->set_value(std::move(ctx));
1252+
return promise->get_future();
12071253
}
1208-
#endif
1254+
1255+
/**
1256+
* @brief Synchronously creates a GPU context using the specified GPU index.
1257+
*
1258+
* This function calls the asynchronous createContextByGpuIdxAsync function to
1259+
* create a GPU context, then waits for its completion using
1260+
* waitForContextFuture. The returned Context holds handles to the WebGPU
1261+
* instance, adapter, device, and queue, and is used for subsequent GPU
1262+
* operations.
1263+
*
1264+
* @param gpuIdx The index of the GPU adapter to use.
1265+
* @param desc Instance descriptor for the WebGPU instance (optional)
1266+
* @param devDescriptor Device descriptor for the WebGPU device (optional)
1267+
* @return Context The fully initialized GPU context.
1268+
*
1269+
* @code
1270+
* Context ctx = createContextByGpuIdx(0);
1271+
* @endcode
1272+
*/
1273+
inline Context createContextByGpuIdx(int gpuIdx,
1274+
const WGPUInstanceDescriptor &desc = {},
1275+
const WGPUDeviceDescriptor &devDescriptor = {}) {
1276+
std::future<Context> contextFuture =
1277+
createContextByGpuIdxAsync(gpuIdx, desc, devDescriptor);
1278+
return waitForContextFuture<Context>(contextFuture);
1279+
}
1280+
1281+
#endif // USE_DAWN_API
1282+
#endif // __EMSCRIPTEN__
12091283

12101284
/**
12111285
* @brief Callback function invoked upon completion of an asynchronous GPU

0 commit comments

Comments
 (0)