@@ -307,38 +307,59 @@ struct KernelCode {
307307 }
308308
309309 /* *
310- * @brief Overload of the constructor to create a code object from a
311- * template string and workgroup size. Unlike the main factory function,
312- * this overload takes a single size_t workgroupSize parameter instead of a
313- * 3D shape for the workgroup size and instantiates a 3D shape with the
314- * workgroupSize in the x dimension and 1 in the y and z dimensions.
310+ * @brief Overload of the constructor to create a code object from a template
311+ * string and workgroup size. This overload takes a single size_t
312+ * workgroupSize parameter instead of a 3D shape for the workgroup size and
313+ * instantiates a 3D shape with the workgroupSize in the x dimension and 1 in
314+ * the y and z dimensions.
315+ *
316+ * @param[in] pData Shader template string with placeholders @param[in]
317+ * workgroupSize 3D Workgroup size
318+ * @param[in] precision Data type precision for the shader
319+ *
320+ * @code KernelCode code = {kPuzzle1, 256, kf32}; @endcode
321+ */
322+ inline KernelCode (const std::string &pData, const Shape &workgroupSize =
323+ {256 , 1 , 1 }, NumType precision = kf32) : data(pData),
324+ workgroupSize(workgroupSize), precision(precision) { if (precision == kf16) {
325+ data = " enable f16;\n " + data; } replaceAll (data, " {{workgroupSize}}" ,
326+ toString (workgroupSize)); replaceAll (data, " {{precision}}" ,
327+ toString (precision)); LOG (kDefLog , kInfo , " Shader code:\n %s" ,
328+ data.c_str ()); }
329+
330+
331+ /* *
332+ * @brief Overload of the constructor, adding totalWorkgroups parameter to
333+ * perform a string replacement for the total number of workgroups in the
334+ * kernel code.
315335 *
316336 * @param[in] pData Shader template string with placeholders
317- * @param[in] workgroupSize Workgroup size in the x dimension
337+ * @param[in] workgroupSize 3D Workgroup size
318338 * @param[in] precision Data type precision for the shader
339+ * @param[in] totalWorkgroups Total number of workgroups in the kernel
319340 *
320341 * @code
321- * KernelCode code = {kPuzzle1, 256, kf32};
342+ * KernelCode code = {kPuzzle1, { 256, 1, 1}, kf32, {2, 2, 1} };
322343 * @endcode
323344 */
324-
325345 inline KernelCode (const std::string &pData,
326- const Shape &workgroupSize = {256 , 1 , 1 },
327- NumType precision = kf32)
346+ const Shape &workgroupSize,
347+ NumType precision,
348+ const Shape &totalWorkgroups)
328349 : data(pData), workgroupSize(workgroupSize), precision(precision) {
329350 if (precision == kf16) {
330351 data = " enable f16;\n " + data;
331352 }
332353 replaceAll (data, " {{workgroupSize}}" , toString (workgroupSize));
333354 replaceAll (data, " {{precision}}" , toString (precision));
355+ replaceAll (data, " {{totalWorkgroups}}" , toString (totalWorkgroups));
334356 LOG (kDefLog , kInfo , " Shader code:\n %s" , data.c_str ());
335357 }
336358
337359
338360 /* *
339- * @brief Overload of the constructor, adding totalWorkgroups parameter to
340- * perform a string replacement for the total number of workgroups in the
341- * kernel code.
361+ * @brief Overload of the constructor, adding totalWorkgroups parameter as
362+ * well as the size_t 1D workgroupSize parameter.
342363 *
343364 * @param[in] pData Shader template string with placeholders
344365 * @param[in] workgroupSize Workgroup size in the x dimension
@@ -350,14 +371,14 @@ struct KernelCode {
350371 * @endcode
351372 */
352373 inline KernelCode (const std::string &pData,
353- const Shape &workgroupSize,
374+ const size_t &workgroupSize,
354375 NumType precision,
355376 const Shape &totalWorkgroups)
356- : data(pData), workgroupSize(workgroupSize), precision(precision) {
377+ : data(pData), workgroupSize({ workgroupSize, 1 , 1 } ), precision(precision) {
357378 if (precision == kf16) {
358379 data = " enable f16;\n " + data;
359380 }
360- replaceAll (data, " {{workgroupSize}}" , toString (workgroupSize));
381+ replaceAll (data, " {{workgroupSize}}" , toString ({ workgroupSize, 1 , 1 } ));
361382 replaceAll (data, " {{precision}}" , toString (precision));
362383 replaceAll (data, " {{totalWorkgroups}}" , toString (totalWorkgroups));
363384 LOG (kDefLog , kInfo , " Shader code:\n %s" , data.c_str ());
0 commit comments