Skip to content

Commit e907a00

Browse files
committed
1D workgroup specification overload for KernelCode constructor accepting totalWorkgroups, add webprint.h to gpu puzzles project
1 parent fce7022 commit e907a00

2 files changed

Lines changed: 93 additions & 16 deletions

File tree

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
#ifndef WEBPRINT_H
2+
#define WEBPRINT_H
3+
4+
5+
#include <emscripten/emscripten.h>
6+
7+
EM_JS(void, js_print, (const char *str), {
8+
if (typeof window != 'undefined' && window.customPrint) {
9+
window.customPrint(UTF8ToString(str));
10+
} else {
11+
console.log("window.customPrint is not defined.");
12+
console.log(UTF8ToString(str));
13+
}
14+
});
15+
16+
17+
// need to allow printf with variable arguments
18+
#pragma clang diagnostic push
19+
#pragma clang diagnostic ignored "-Wformat-security"
20+
template<typename... Args>
21+
void wprintf(const char *str, Args... args) {
22+
char buffer[1024];
23+
snprintf(buffer, sizeof(buffer), str, args...);
24+
js_print(buffer);
25+
}
26+
#pragma clang diagnostic pop
27+
28+
void printVec(const std::vector<float> &vec, const char *name = "") {
29+
char buffer[1024];
30+
size_t pos = 0;
31+
pos += snprintf(buffer + pos, sizeof(buffer) - pos, "[ ");
32+
for (size_t i = 0; i < vec.size(); ++i) {
33+
pos += snprintf(buffer + pos, sizeof(buffer) - pos, "%.1f", vec[i]);
34+
if (i != vec.size() - 1) {
35+
pos += snprintf(buffer + pos, sizeof(buffer) - pos, ", ");
36+
}
37+
}
38+
snprintf(buffer + pos, sizeof(buffer) - pos, " ]");
39+
wprintf("%s %s", name, buffer);
40+
}
41+
42+
43+
void printVecBuf(const std::vector<float> &vec, const char *name, char *buffer, size_t& pos) {
44+
pos += snprintf(buffer + pos, sizeof(buffer) - pos, "%s", name);
45+
pos += snprintf(buffer + pos, sizeof(buffer) - pos, "[ ");
46+
for (size_t i = 0; i < vec.size(); ++i) {
47+
pos += snprintf(buffer + pos, sizeof(buffer) - pos, "%2.0f", vec[i]);
48+
if (i != vec.size() - 1) {
49+
pos += snprintf(buffer + pos, sizeof(buffer) - pos, ", ");
50+
}
51+
}
52+
pos += snprintf(buffer + pos, sizeof(buffer) - pos, " ]\n\r");
53+
}
54+
55+
56+
#endif // WEBPRINT_H

gpu.h

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -307,38 +307,59 @@ struct KernelCode {
307307
}
308308

309309
/**
310-
* @brief Overload of the constructor to create a code object from a
311-
* template string and workgroup size. Unlike the main factory function,
312-
* this overload takes a single size_t workgroupSize parameter instead of a
313-
* 3D shape for the workgroup size and instantiates a 3D shape with the
314-
* workgroupSize in the x dimension and 1 in the y and z dimensions.
310+
* @brief Overload of the constructor to create a code object from a template
311+
* string and workgroup size. This overload takes a single size_t
312+
* workgroupSize parameter instead of a 3D shape for the workgroup size and
313+
* instantiates a 3D shape with the workgroupSize in the x dimension and 1 in
314+
* the y and z dimensions.
315+
*
316+
* @param[in] pData Shader template string with placeholders @param[in]
317+
* workgroupSize 3D Workgroup size
318+
* @param[in] precision Data type precision for the shader
319+
*
320+
* @code KernelCode code = {kPuzzle1, 256, kf32}; @endcode
321+
*/
322+
inline KernelCode(const std::string &pData, const Shape &workgroupSize =
323+
{256, 1, 1}, NumType precision = kf32) : data(pData),
324+
workgroupSize(workgroupSize), precision(precision) { if (precision == kf16) {
325+
data = "enable f16;\n" + data; } replaceAll(data, "{{workgroupSize}}",
326+
toString(workgroupSize)); replaceAll(data, "{{precision}}",
327+
toString(precision)); LOG(kDefLog, kInfo, "Shader code:\n%s",
328+
data.c_str()); }
329+
330+
331+
/**
332+
* @brief Overload of the constructor, adding totalWorkgroups parameter to
333+
* perform a string replacement for the total number of workgroups in the
334+
* kernel code.
315335
*
316336
* @param[in] pData Shader template string with placeholders
317-
* @param[in] workgroupSize Workgroup size in the x dimension
337+
* @param[in] workgroupSize 3D Workgroup size
318338
* @param[in] precision Data type precision for the shader
339+
* @param[in] totalWorkgroups Total number of workgroups in the kernel
319340
*
320341
* @code
321-
* KernelCode code = {kPuzzle1, 256, kf32};
342+
* KernelCode code = {kPuzzle1, {256, 1, 1}, kf32, {2, 2, 1}};
322343
* @endcode
323344
*/
324-
325345
inline KernelCode(const std::string &pData,
326-
const Shape &workgroupSize = {256, 1, 1},
327-
NumType precision = kf32)
346+
const Shape &workgroupSize,
347+
NumType precision,
348+
const Shape &totalWorkgroups)
328349
: data(pData), workgroupSize(workgroupSize), precision(precision) {
329350
if (precision == kf16) {
330351
data = "enable f16;\n" + data;
331352
}
332353
replaceAll(data, "{{workgroupSize}}", toString(workgroupSize));
333354
replaceAll(data, "{{precision}}", toString(precision));
355+
replaceAll(data, "{{totalWorkgroups}}", toString(totalWorkgroups));
334356
LOG(kDefLog, kInfo, "Shader code:\n%s", data.c_str());
335357
}
336358

337359

338360
/**
339-
* @brief Overload of the constructor, adding totalWorkgroups parameter to
340-
* perform a string replacement for the total number of workgroups in the
341-
* kernel code.
361+
* @brief Overload of the constructor, adding totalWorkgroups parameter as
362+
* well as the size_t 1D workgroupSize parameter.
342363
*
343364
* @param[in] pData Shader template string with placeholders
344365
* @param[in] workgroupSize Workgroup size in the x dimension
@@ -350,14 +371,14 @@ struct KernelCode {
350371
* @endcode
351372
*/
352373
inline KernelCode(const std::string &pData,
353-
const Shape &workgroupSize,
374+
const size_t &workgroupSize,
354375
NumType precision,
355376
const Shape &totalWorkgroups)
356-
: data(pData), workgroupSize(workgroupSize), precision(precision) {
377+
: data(pData), workgroupSize({workgroupSize, 1, 1}), precision(precision) {
357378
if (precision == kf16) {
358379
data = "enable f16;\n" + data;
359380
}
360-
replaceAll(data, "{{workgroupSize}}", toString(workgroupSize));
381+
replaceAll(data, "{{workgroupSize}}", toString({workgroupSize, 1, 1}));
361382
replaceAll(data, "{{precision}}", toString(precision));
362383
replaceAll(data, "{{totalWorkgroups}}", toString(totalWorkgroups));
363384
LOG(kDefLog, kInfo, "Shader code:\n%s", data.c_str());

0 commit comments

Comments
 (0)