1111
1212using namespace gpu ;
1313
14+ constexpr size_t kN = 100 ;
15+
1416EM_JS (void , js_print, (const char *str), {
1517 if (typeof window != ' undefined' && window.customPrint ) {
1618 window.customPrint (UTF8ToString (str));
@@ -20,60 +22,76 @@ EM_JS(void, js_print, (const char *str), {
2022 }
2123});
2224
23- constexpr size_t kN = 5000 ;
24-
25- extern " C" {
26-
27- EMSCRIPTEN_KEEPALIVE bool checkAnswer (std::array<float , kN > &outputArr) {
28- return outputArr[0 ] == 10 ;
29- // return false;
30- }
31-
32- EMSCRIPTEN_KEEPALIVE
33- void executeKernel (Context& ctx, const char *kernelCode, const Shape &wgSize,
34- const Shape &nWorkgroups,
35- std::array<float , kN > &outputArr) {
25+ template <size_t nInputs>
26+ struct HostSpec {
27+ const Shape wgSize;
28+ const Shape nWorkgroups;
29+ const std::string kernelCode;
30+ std::array<std::vector<float >, nInputs> inputs;
31+ };
3632
37- // TODO(avh): use puzzle dispatch from scaffold.h for host implementation
38- char buffer[1024 ]; // for printing
39- constexpr size_t N = 5000 ;
40- std::array<float , N> inputArr;
41- for (int i = 0 ; i < N; ++i) {
42- inputArr[i] = static_cast <float >(i);
33+ template <size_t nInputs>
34+ void executeKernel (Context& ctx,
35+ const HostSpec<nInputs>& spec,
36+ float * outputPtr, size_t outputSize) {
37+ std::array<Tensor, nInputs + 1 > bindingsArr; // + 1 for output binding
38+ for (size_t inputIndex = 0 ; inputIndex < nInputs; ++inputIndex) {
39+ bindingsArr[inputIndex] = createTensor (ctx, Shape{spec.inputs [inputIndex].size ()}, kf32, spec.inputs [inputIndex].data ());
4340 }
44- Tensor input = createTensor (ctx, Shape{N}, kf32, inputArr.data ());
45- Tensor output = createTensor (ctx, Shape{N}, kf32);
41+ Tensor output = createTensor (ctx, Shape{outputSize}, kf32);
42+ bindingsArr[nInputs] = output;
43+ Bindings bindings{bindingsArr};
4644 std::promise<void > promise;
4745 std::future<void > future = promise.get_future ();
48- Kernel op = createKernel (ctx, {kernelCode, wgSize, kf32},
49- Bindings{input, output}, nWorkgroups);
50-
46+ Kernel op = createKernel (ctx, {spec.kernelCode , spec.wgSize , kf32},
47+ bindings, spec.nWorkgroups );
5148 dispatchKernel (ctx, op, promise);
5249 wait (ctx, future);
53- toCPU (ctx, output, outputArr.data (), sizeof (outputArr));
54- for (int i = 0 ; i < 10 ; ++i) {
55- snprintf (buffer, sizeof (buffer), " [%d] kernel(%.1f) = %.4f" , i,
56- inputArr[i], outputArr[i]);
57- js_print (buffer);
50+ toCPU (ctx, output, outputPtr, outputSize * sizeof (float ));
51+ }
52+
53+ extern " C" {
54+
55+ void generatePreamble (size_t nInputs, Shape& wgSize, Shape& nWorkgroups, const char * out, size_t outSize) {
56+ std::string result = " " ;
57+ for (size_t i = 0 ; i < nInputs; ++i) {
58+ result += " @group(0) @binding(" + std::to_string (i) + " ) var input" + std::to_string (i) + " : array;\n " ;
5859 }
59- js_print (" ..." );
60- for (int i = N - 10 ; i < N; ++i) {
61- snprintf (buffer, sizeof (buffer), " [%d] kernel(%.1f) = %.4f" , i,
62- inputArr[i], outputArr[i]);
63- js_print (buffer);
60+ result += " @group(0) @binding(" + std::to_string (nInputs) + " ) var output : array;\n " ;
61+ result += " @compute @workgroup_size(" + std::to_string (wgSize[0 ]) + " , " + std::to_string (wgSize[1 ]) + " , " + std::to_string (wgSize[2 ]) + " )\n " ;
62+ std::strncpy (const_cast <char *>(out), result.c_str (), outSize);
63+ }
64+
65+
66+ EMSCRIPTEN_KEEPALIVE
67+ void runCheck (const char *kernelCode, const Shape &wgSize,
68+ const Shape &nWorkgroups) {
69+ Context ctx = createContext ({});
70+ std::array<float , kN > output;
71+ std::vector<float > input (N);
72+ for (int i = 0 ; i < kN ; ++i) {
73+ input[i] = static_cast <float >(i);
6474 }
65- snprintf (buffer, sizeof (buffer), " Computed %zu values" , N);
66- js_print (buffer);
67- } // executeKernel
75+ HostSpec<1 > spec = {
76+ wgSize,
77+ nWorkgroups,
78+ kernelCode,
79+ std::array<std::vector<float >, 1 > {input}
80+ };
81+ executeKernel<1 >(ctx, spec, output.data (), kN );
82+ }
6883
6984EMSCRIPTEN_KEEPALIVE
70- bool runCheck (const char *kernelCode, const Shape &wgSize,
85+ bool evaluate (const char *kernelCode, const Shape &wgSize,
7186 const Shape &nWorkgroups) {
87+ char buffer[1024 ]; // for printing
88+
89+ snprintf (buffer, sizeof (buffer), " Evaluating kernel with workgroup size (%zu, %zu, %zu) and nWorkgroups (%zu, %zu, %zu)" ,
90+ wgSize[0 ], wgSize[1 ], wgSize[2 ], nWorkgroups[0 ], nWorkgroups[1 ], nWorkgroups[2 ]);
91+ js_print (buffer);
7292 Context ctx = createContext ({});
73- std::array<float , kN > outputArr;
74- executeKernel (ctx, kernelCode, wgSize, nWorkgroups, outputArr);
7593 TestCases testCases = createTestCases ();
76- return evaluate (ctx, testCases, std::string ( kernelCode) , 0 );
94+ return evaluate (ctx, testCases, kernelCode, 0 );
7795}
7896
7997} // extern "C"
@@ -89,20 +107,19 @@ EMSCRIPTEN_BINDINGS(module) {
89107 emscripten::register_vector<std::vector<float >>(" VectorFloat" );
90108 emscripten::register_vector<std::vector<int >>(" VectorInt" );
91109
110+
92111 emscripten::function (
93- " runCheck " ,
112+ " evaluate " ,
94113 emscripten::optional_override (
95114 [](const std::string &kernelCode, const std::array<size_t , 3 > &wgSize,
96115 const std::array<size_t , 3 > &nWorkgroups) {
97- return runCheck (kernelCode.c_str (),
116+ return evaluate (kernelCode.c_str (),
98117 Shape{static_cast <size_t >(wgSize[0 ]),
99118 static_cast <size_t >(wgSize[1 ]),
100119 static_cast <size_t >(wgSize[2 ])},
101120 Shape{static_cast <size_t >(nWorkgroups[0 ]),
102121 static_cast <size_t >(nWorkgroups[1 ]),
103122 static_cast <size_t >(nWorkgroups[2 ])});
104123 }));
105-
106- emscripten::function (" checkAnswer" , &checkAnswer);
107124}
108125#endif
0 commit comments