From 472d64c75a5eb5230a4c972c133574049fd9e14a Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Thu, 4 Jan 2024 13:46:51 -0800 Subject: [PATCH 01/99] basic infrastructure, dpas version is working --- include/bfloat16.hpp | 125 +++++++++ include/util.hpp | 23 ++ samples/99_matrixexperiments/CMakeLists.txt | 11 + samples/99_matrixexperiments/main.cpp | 264 ++++++++++++++++++ .../99_matrixexperiments/matrix_kernels.cl | 86 ++++++ samples/CMakeLists.txt | 2 + 6 files changed, 511 insertions(+) create mode 100644 include/bfloat16.hpp create mode 100644 samples/99_matrixexperiments/CMakeLists.txt create mode 100644 samples/99_matrixexperiments/main.cpp create mode 100644 samples/99_matrixexperiments/matrix_kernels.cl diff --git a/include/bfloat16.hpp b/include/bfloat16.hpp new file mode 100644 index 00000000..5e9541bd --- /dev/null +++ b/include/bfloat16.hpp @@ -0,0 +1,125 @@ +#pragma once + +#include +#include + +class bfloat16; + +class bfloat16 { + using StorageType = uint16_t; + StorageType value; + + static StorageType from_float(const float &a) { + if (std::isnan(a)) + return 0xffc1; + union { + uint32_t intStorage; + float floatValue; + }; + floatValue = a; + // Do RNE and truncate + uint32_t roundingBias = ((intStorage >> 16) & 0x1) + 0x00007FFF; + return static_cast((intStorage + roundingBias) >> 16); + } + + static float to_float(const StorageType &a) { + union { + uint32_t intStorage; + float floatValue; + }; + intStorage = a << 16; + return floatValue; + } + +public: + bfloat16() = default; + bfloat16(const bfloat16 &) = default; + ~bfloat16() = default; + + // Implicit conversion from float to bfloat16 + bfloat16(const float &a) { value = from_float(a); } + + bfloat16 &operator=(const float &rhs) { + value = from_float(rhs); + return *this; + } + + // Implicit conversion from bfloat16 to float + operator float() const { return to_float(value); } + + // Logical operators (!,||,&&) are covered if we can cast to bool + explicit operator bool() { return 
to_float(value) != 0.0f; } + + // Unary minus operator overloading + friend bfloat16 operator-(bfloat16 &lhs) { + return -to_float(lhs.value); + } + + // Increment and decrement operators overloading +#define OP(op) \ + friend bfloat16 &operator op(bfloat16 &lhs) { \ + float f = to_float(lhs.value); \ + lhs.value = from_float(op f); \ + return lhs; \ + } \ + friend bfloat16 operator op(bfloat16 &lhs, int) { \ + bfloat16 old = lhs; \ + operator op(lhs); \ + return old; \ + } + OP(++) + OP(--) +#undef OP + + // Assignment operators overloading +#define OP(op) \ + friend bfloat16 &operator op(bfloat16 &lhs, const bfloat16 &rhs) { \ + float f = static_cast(lhs); \ + f op static_cast(rhs); \ + return lhs = f; \ + } \ + template \ + friend bfloat16 &operator op(bfloat16 &lhs, const T &rhs) { \ + float f = static_cast(lhs); \ + f op static_cast(rhs); \ + return lhs = f; \ + } \ + template friend T &operator op(T &lhs, const bfloat16 &rhs) { \ + float f = static_cast(lhs); \ + f op static_cast(rhs); \ + return lhs = f; \ + } + OP(+=) + OP(-=) + OP(*=) + OP(/=) +#undef OP + +// Binary operators overloading +#define OP(type, op) \ + friend type operator op(const bfloat16 &lhs, const bfloat16 &rhs) { \ + return type{static_cast(lhs) op static_cast(rhs)}; \ + } \ + template \ + friend type operator op(const bfloat16 &lhs, const T &rhs) { \ + return type{static_cast(lhs) op static_cast(rhs)}; \ + } \ + template \ + friend type operator op(const T &lhs, const bfloat16 &rhs) { \ + return type{static_cast(lhs) op static_cast(rhs)}; \ + } + OP(bfloat16, +) + OP(bfloat16, -) + OP(bfloat16, *) + OP(bfloat16, /) + OP(bool, ==) + OP(bool, !=) + OP(bool, <) + OP(bool, >) + OP(bool, <=) + OP(bool, >=) +#undef OP + + // Bitwise(|,&,~,^), modulo(%) and shift(<<,>>) operations are not supported + // for floating-point types. 
+}; \ No newline at end of file diff --git a/include/util.hpp b/include/util.hpp index eb8a99e1..aaf28ce3 100644 --- a/include/util.hpp +++ b/include/util.hpp @@ -6,6 +6,8 @@ #pragma once #include + +#include #include static cl_uint getDeviceOpenCLVersion( @@ -67,3 +69,24 @@ static bool checkDeviceForExtension( return supported; } + +static std::string readStringFromFile( + const std::string& filename ) +{ + std::ifstream is(filename, std::ios::binary); + if (!is.good()) { + printf("Couldn't open file '%s'!\n", filename.c_str()); + return ""; + } + + size_t filesize = 0; + is.seekg(0, std::ios::end); + filesize = (size_t)is.tellg(); + is.seekg(0, std::ios::beg); + + std::string source{ + std::istreambuf_iterator(is), + std::istreambuf_iterator() }; + + return source; +} diff --git a/samples/99_matrixexperiments/CMakeLists.txt b/samples/99_matrixexperiments/CMakeLists.txt new file mode 100644 index 00000000..7f2a696d --- /dev/null +++ b/samples/99_matrixexperiments/CMakeLists.txt @@ -0,0 +1,11 @@ +# Copyright (c) 2019-2024 Ben Ashbaugh +# +# SPDX-License-Identifier: MIT + +add_opencl_sample( + TEST + NUMBER 05 + TARGET matrixexperiments + VERSION 120 + SOURCES main.cpp + KERNELS matrix_kernels.cl) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp new file mode 100644 index 00000000..60420738 --- /dev/null +++ b/samples/99_matrixexperiments/main.cpp @@ -0,0 +1,264 @@ +/* +// Copyright (c) 2019-2024 Ben Ashbaugh +// +// SPDX-License-Identifier: MIT +*/ + +#include + +#include + +#include +#include +#include +#include + +#include "bfloat16.hpp" +#include "util.hpp" + +using test_clock = std::chrono::high_resolution_clock; + +bool validate = false; +int testIterations = 16; +float threshold = 0.01f; + +template +static void fill_matrix(std::vector& M, size_t numRows, size_t numCols) +{ +#if 1 + std::random_device dev; + std::mt19937 rng(dev()); + std::uniform_real_distribution dist(-1.0, 1.0); + std::generate(std::begin(M), 
std::end(M), [&]{ return dist(rng); }); +#else + for (size_t r = 0; r < numRows; r++) { + for (size_t c = 0; c < numCols; c++) { + M[r * numCols + c] = c; //1.0f; // + (float)r / numRows + (float)c / numCols; + } + } +#endif +} + +template +static void vnni_matrix( + std::vector &dst, const std::vector &src, + size_t numRows, size_t numCols, size_t factor) +{ + for (size_t r = 0; r < numRows / factor; r++) { + for (size_t c = 0; c < numCols; c++) { + for (size_t k = 0; k < factor; k++) { + dst[r * numCols * factor + c * factor + k] = + src[(r * factor + k) * numCols + c]; + } + } + } +} + +template +static void compute_reference( + std::vector& C, + const std::vector& A, const std::vector& B, + size_t M, size_t N, size_t K) +{ + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + DstT sum = 0; + for (size_t k = 0; k < K; k++) { + sum = std::fma(static_cast(A[m * K + k]), + static_cast(B[k * N + n]), sum); + } + C[m * N + n] = sum; + } + } +} + +template +int check_results(const std::vector& C, + const std::vector& C_ref) +{ + float err = 0.f; + for (int i = 0; i < C.size(); ++i) { + float localErr = std::fabs(C[i] - C_ref[i]) / + std::max(std::fabs(C[i]), + std::fabs(C_ref[i])); + err = std::max(localErr, err); + if (localErr >= threshold) { + std::cerr << "Error at index " << i << " (local error " << localErr + << "): Wanted " << C_ref[i] << ", got " << C[i] + << std::endl; + break; + } + } + + return err < 0.001f; +} + +template +static void go_naive( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); + + cl::Kernel kernel{program, "bfloat16_naive"}; + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + 
auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M}); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration elapsed_seconds = end - start; + best = std::min(best, elapsed_seconds.count()); + } + printf("Finished in %f seconds\n", best); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data()); + check_results(C_check, C_ref); + printf(" done!\n"); + } +} + +template +static void go_dpas_basic( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); + + cl::Kernel kernel{program, "bfloat16_dpas_basic"}; + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M}); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration elapsed_seconds = end - start; + best = std::min(best, elapsed_seconds.count()); + } + printf("Finished in %f seconds\n", best); + + if (validate) { + printf("Checking results... 
"); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data()); + check_results(C_check, C_ref); + printf(" done!\n"); + } +} + +int main( + int argc, + char** argv ) +{ + int platformIndex = 0; + int deviceIndex = 0; + + std::string fileName("matrix_kernels.cl"); + std::string buildOptions; + size_t matrixSize = 512; + + { + popl::OptionParser op("Supported Options"); + op.add>("p", "platform", "Platform Index", platformIndex, &platformIndex); + op.add>("d", "device", "Device Index", deviceIndex, &deviceIndex); + op.add>("", "file", "Kernel File Name", fileName, &fileName); + op.add>("", "options", "Program Build Options", buildOptions, &buildOptions); + op.add>("m", "matrixsize", "Matrix Size", matrixSize, &matrixSize); + op.add>("i", "iterations", "Test Iterations", testIterations, &testIterations); + op.add("", "validate", "Validate Results", &validate); + op.add>("", "threshold", "Local Error Threshold", threshold, &threshold); + bool printUsage = false; + try { + op.parse(argc, argv); + } catch (std::exception& e) { + fprintf(stderr, "Error: %s\n\n", e.what()); + printUsage = true; + } + + if (printUsage || !op.unknown_options().empty() || !op.non_option_args().empty()) { + fprintf(stderr, + "Usage: matrixexperiments [options]\n" + "%s", op.help().c_str()); + return -1; + } + } + + std::vector platforms; + cl::Platform::get(&platforms); + + printf("Running on platform: %s\n", + platforms[platformIndex].getInfo().c_str() ); + + std::vector devices; + platforms[platformIndex].getDevices(CL_DEVICE_TYPE_ALL, &devices); + + cl::Device& device = devices[deviceIndex]; + printf("Running on device: %s\n", + device.getInfo().c_str() ); + + cl::Context context{device}; + cl::CommandQueue queue{context, device}; + + printf("Reading program source from file: %s\n", fileName.c_str() ); + std::string kernelString = readStringFromFile(fileName.c_str()); + + printf("Building program with build 
options: %s\n", + buildOptions.empty() ? "(none)" : buildOptions.c_str() ); + cl::Program program{ context, kernelString }; + program.build(buildOptions.c_str()); + for( auto& device : program.getInfo() ) + { + printf("Program build log for device %s:\n", + device.getInfo().c_str() ); + printf("%s\n", + program.getBuildInfo(device).c_str() ); + } + + std::vector A(matrixSize * matrixSize); + std::vector B(matrixSize * matrixSize); + std::vector B_vnni(matrixSize * matrixSize); + + std::vector C(matrixSize * matrixSize); + std::vector C_ref(matrixSize * matrixSize); + + printf("Initializing source matrices...\n"); + fill_matrix(A, matrixSize, matrixSize); + fill_matrix(B, matrixSize, matrixSize); + + if (validate) { + printf("Computing reference...\n"); + compute_reference(C_ref, A, B, matrixSize, matrixSize, matrixSize); + } + + printf("Creating source buffers...\n"); + cl::Buffer Abuf{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, A.size() * sizeof(A[0]), A.data()}; + cl::Buffer Bbuf{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, B.size() * sizeof(B[0]), B.data()}; + cl::Buffer Cbuf{context, CL_MEM_WRITE_ONLY, C.size() * sizeof(C[0])}; + + printf("Running tests...\n"); + + go_naive(context, program, queue, Cbuf, Abuf, Bbuf, matrixSize, matrixSize, matrixSize, C_ref); + + go_dpas_basic(context, program, queue, Cbuf, Abuf, Bbuf, matrixSize, matrixSize, matrixSize, C_ref); + + printf("Done.\n"); + + return 0; +} diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl new file mode 100644 index 00000000..904432b7 --- /dev/null +++ b/samples/99_matrixexperiments/matrix_kernels.cl @@ -0,0 +1,86 @@ +float bfloat16_to_float(ushort u) +{ +#if defined(cl_intel_bfloat16_conversions) + return intel_convert_as_bfloat16_float(u); +#else + return as_float(u << 16); +#endif +} + +kernel void bfloat16_naive(global float* C, global ushort* A, global ushort* B, int K) +{ + const int N = get_global_size(0); + int m = 
get_global_id(1); + int n = get_global_id(0); + + float sum = 0; + for (int k = 0; k < K; k++) { + sum += bfloat16_to_float(A[m * K + k]) * bfloat16_to_float(B[k * N + n]); + } + + C[m * N + n] = sum; +} + +#if defined(cl_intel_subgroup_matrix_multiply_accumulate) + +// M rows x K columns +static int __load_a_row_major_bf16_m1(global ushort* A, int rowStart, int colStart, int stride) +{ + int ret; + + int offset = rowStart * stride + colStart + get_sub_group_local_id() * 2; + + ret = as_int(vload2(0, A + offset)); + + return ret; +} + +// K rows x N columns: +// Each work-item loads K values and converts to VNNI. +// Stride is in units of elements. +static int8 __load_b_row_major_bf16_k16(global ushort* B, int rowStart, int colStart, int stride) +{ + int8 ret; + + int offset = rowStart * stride + colStart + get_sub_group_local_id(); + +// Note: this could probably use block loads? +#define B_ROWDATA(_k) B[(rowStart + _k) * stride + colStart + get_sub_group_local_id()] + ret.s0 = as_int((ushort2)(B_ROWDATA( 0), B_ROWDATA( 1))); + ret.s1 = as_int((ushort2)(B_ROWDATA( 2), B_ROWDATA( 3))); + ret.s2 = as_int((ushort2)(B_ROWDATA( 4), B_ROWDATA( 5))); + ret.s3 = as_int((ushort2)(B_ROWDATA( 6), B_ROWDATA( 7))); + ret.s4 = as_int((ushort2)(B_ROWDATA( 8), B_ROWDATA( 9))); + ret.s5 = as_int((ushort2)(B_ROWDATA(10), B_ROWDATA(11))); + ret.s6 = as_int((ushort2)(B_ROWDATA(12), B_ROWDATA(13))); + ret.s7 = as_int((ushort2)(B_ROWDATA(14), B_ROWDATA(15))); +#undef B_ROWDATA + + return ret; +} + +__attribute__((intel_reqd_sub_group_size(8))) +__attribute__((reqd_work_group_size(8, 1, 1))) +kernel void bfloat16_dpas_basic(global float* C, global ushort* A, global ushort* B, int K) +{ + const int N = get_global_size(0); + int m = get_group_id(1); + int n = get_group_id(0) * get_local_size(0); + + float sum = 0; + for (int k = 0; k < K; k += 16) { + int aData = __load_a_row_major_bf16_m1(A, m, k, K); + int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); + sum = 
intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum); + } + + C[m * N + n + get_sub_group_local_id()] = sum; +} + +#else + +#pragma message("cl_intel_subgroup_matrix_multiply_accumulate is unsupported!") + +kernel void bfloat16_dpas_basic(global float* C, global ushort* A, global ushort* B, int K) {} + +#endif \ No newline at end of file diff --git a/samples/CMakeLists.txt b/samples/CMakeLists.txt index ad326fd5..6e9995f1 100644 --- a/samples/CMakeLists.txt +++ b/samples/CMakeLists.txt @@ -88,3 +88,5 @@ if(BUILD_EXTENSION_SAMPLES) add_subdirectory( 13_mutablecommandbuffers ) add_subdirectory( 14_ooqcommandbuffers ) endif() + +add_subdirectory( 99_matrixexperiments ) \ No newline at end of file From 16c343ca04082deba7f016dbdff68e3e7f636bc1 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Thu, 4 Jan 2024 13:54:37 -0800 Subject: [PATCH 02/99] improved address arithmetic --- .../99_matrixexperiments/matrix_kernels.cl | 37 +++++++++++++------ 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl index 904432b7..1646f33b 100644 --- a/samples/99_matrixexperiments/matrix_kernels.cl +++ b/samples/99_matrixexperiments/matrix_kernels.cl @@ -44,17 +44,32 @@ static int8 __load_b_row_major_bf16_k16(global ushort* B, int rowStart, int colS int offset = rowStart * stride + colStart + get_sub_group_local_id(); -// Note: this could probably use block loads? 
-#define B_ROWDATA(_k) B[(rowStart + _k) * stride + colStart + get_sub_group_local_id()] - ret.s0 = as_int((ushort2)(B_ROWDATA( 0), B_ROWDATA( 1))); - ret.s1 = as_int((ushort2)(B_ROWDATA( 2), B_ROWDATA( 3))); - ret.s2 = as_int((ushort2)(B_ROWDATA( 4), B_ROWDATA( 5))); - ret.s3 = as_int((ushort2)(B_ROWDATA( 6), B_ROWDATA( 7))); - ret.s4 = as_int((ushort2)(B_ROWDATA( 8), B_ROWDATA( 9))); - ret.s5 = as_int((ushort2)(B_ROWDATA(10), B_ROWDATA(11))); - ret.s6 = as_int((ushort2)(B_ROWDATA(12), B_ROWDATA(13))); - ret.s7 = as_int((ushort2)(B_ROWDATA(14), B_ROWDATA(15))); -#undef B_ROWDATA + // Note: this could probably use block loads? + ushort row0 = B[offset]; offset += stride; + ushort row1 = B[offset]; offset += stride; + ushort row2 = B[offset]; offset += stride; + ushort row3 = B[offset]; offset += stride; + ushort row4 = B[offset]; offset += stride; + ushort row5 = B[offset]; offset += stride; + ushort row6 = B[offset]; offset += stride; + ushort row7 = B[offset]; offset += stride; + ushort row8 = B[offset]; offset += stride; + ushort row9 = B[offset]; offset += stride; + ushort row10 = B[offset]; offset += stride; + ushort row11 = B[offset]; offset += stride; + ushort row12 = B[offset]; offset += stride; + ushort row13 = B[offset]; offset += stride; + ushort row14 = B[offset]; offset += stride; + ushort row15 = B[offset]; offset += stride; + + ret.s0 = as_int((ushort2)(row0, row1 )); + ret.s1 = as_int((ushort2)(row2, row3 )); + ret.s2 = as_int((ushort2)(row4, row5 )); + ret.s3 = as_int((ushort2)(row6, row7 )); + ret.s4 = as_int((ushort2)(row8, row9 )); + ret.s5 = as_int((ushort2)(row10, row11)); + ret.s6 = as_int((ushort2)(row12, row13)); + ret.s7 = as_int((ushort2)(row14, row15)); return ret; } From 1da0c2c909532b4f60c01acbf4c419d5f7091720 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Thu, 4 Jan 2024 16:27:39 -0800 Subject: [PATCH 03/99] added vnni versions --- samples/99_matrixexperiments/main.cpp | 375 ++++++++++++++++-- 
.../99_matrixexperiments/matrix_kernels.cl | 300 ++++++++++++-- 2 files changed, 610 insertions(+), 65 deletions(-) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index 60420738..d3808be5 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -18,6 +18,7 @@ using test_clock = std::chrono::high_resolution_clock; +bool fixedData = false; bool validate = false; int testIterations = 16; float threshold = 0.01f; @@ -25,18 +26,18 @@ float threshold = 0.01f; template static void fill_matrix(std::vector& M, size_t numRows, size_t numCols) { -#if 1 - std::random_device dev; - std::mt19937 rng(dev()); - std::uniform_real_distribution dist(-1.0, 1.0); - std::generate(std::begin(M), std::end(M), [&]{ return dist(rng); }); -#else - for (size_t r = 0; r < numRows; r++) { - for (size_t c = 0; c < numCols; c++) { - M[r * numCols + c] = c; //1.0f; // + (float)r / numRows + (float)c / numCols; + if (fixedData) { + for (size_t r = 0; r < numRows; r++) { + for (size_t c = 0; c < numCols; c++) { + M[r * numCols + c] = r + c; + } } + } else { + std::random_device dev; + std::mt19937 rng(dev()); + std::uniform_real_distribution dist(-1.0, 1.0); + std::generate(std::begin(M), std::end(M), [&]{ return dist(rng); }); } -#endif } template @@ -100,7 +101,7 @@ static void go_naive( size_t M, size_t N, size_t K, const std::vector& C_ref) { - printf("%s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); + printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); cl::Kernel kernel{program, "bfloat16_naive"}; kernel.setArg(0, C); @@ -117,7 +118,8 @@ static void go_naive( std::chrono::duration elapsed_seconds = end - start; best = std::min(best, elapsed_seconds.count()); } - printf("Finished in %f seconds\n", best); + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Finished in %f seconds (%f gops)\n", best, gops); if (validate) { printf("Checking results... 
"); fflush(stdout); @@ -129,37 +131,322 @@ static void go_naive( } template -static void go_dpas_basic( +static void go_dpas_rowmajor_m1( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, const std::vector& C_ref) { - printf("%s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); + printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); + + cl::Kernel kernel{program, "bfloat16_dpas_rowmajor_m1"}; + if (kernel()) { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M}); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration elapsed_seconds = end - start; + best = std::min(best, elapsed_seconds.count()); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Finished in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... 
"); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data()); + check_results(C_check, C_ref); + printf(" done!\n"); + } + } else { + printf("unsupported.\n"); + } +} - cl::Kernel kernel{program, "bfloat16_dpas_basic"}; - kernel.setArg(0, C); - kernel.setArg(1, A); - kernel.setArg(2, B); - kernel.setArg(3, static_cast(K)); +template +static void go_dpas_rowmajor_m2( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); + + cl::Kernel kernel{program, "bfloat16_dpas_rowmajor_m2"}; + if (kernel()) { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M/2}); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration elapsed_seconds = end - start; + best = std::min(best, elapsed_seconds.count()); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Finished in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... 
"); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data()); + check_results(C_check, C_ref); + printf(" done!\n"); + } + } else { + printf("unsupported.\n"); + } +} - float best = 999.0f; - for (int test = 0; test < testIterations; test++) { - auto start = test_clock::now(); - queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M}); - queue.finish(); - auto end = test_clock::now(); - std::chrono::duration elapsed_seconds = end - start; - best = std::min(best, elapsed_seconds.count()); +template +static void go_dpas_rowmajor_m4( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); + + cl::Kernel kernel{program, "bfloat16_dpas_rowmajor_m4"}; + if (kernel()) { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M/4}); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration elapsed_seconds = end - start; + best = std::min(best, elapsed_seconds.count()); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Finished in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data()); + check_results(C_check, C_ref); + printf(" done!\n"); + } + } else { + printf("unsupported.\n"); } - printf("Finished in %f seconds\n", best); +} - if (validate) { - printf("Checking results... 
"); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data()); - check_results(C_check, C_ref); - printf(" done!\n"); +template +static void go_dpas_rowmajor_m8( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); + + cl::Kernel kernel{program, "bfloat16_dpas_rowmajor_m8"}; + if (kernel()) { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M/8}); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration elapsed_seconds = end - start; + best = std::min(best, elapsed_seconds.count()); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Finished in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... 
"); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data()); + check_results(C_check, C_ref); + printf(" done!\n"); + } + } else { + printf("unsupported.\n"); + } +} + +template +static void go_dpas_vnni_m1( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); + + cl::Kernel kernel{program, "bfloat16_dpas_vnni_m1"}; + if (kernel()) { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M}); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration elapsed_seconds = end - start; + best = std::min(best, elapsed_seconds.count()); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Finished in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... 
"); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data()); + check_results(C_check, C_ref); + printf(" done!\n"); + } + } else { + printf("unsupported.\n"); + } +} + +template +static void go_dpas_vnni_m2( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); + + cl::Kernel kernel{program, "bfloat16_dpas_vnni_m2"}; + if (kernel()) { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M/2}); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration elapsed_seconds = end - start; + best = std::min(best, elapsed_seconds.count()); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Finished in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... 
"); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data()); + check_results(C_check, C_ref); + printf(" done!\n"); + } + } else { + printf("unsupported.\n"); + } +} + +template +static void go_dpas_vnni_m4( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); + + cl::Kernel kernel{program, "bfloat16_dpas_vnni_m4"}; + if (kernel()) { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M/4}); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration elapsed_seconds = end - start; + best = std::min(best, elapsed_seconds.count()); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Finished in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... 
"); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data()); + check_results(C_check, C_ref); + printf(" done!\n"); + } + } else { + printf("unsupported.\n"); + } +} + +template +static void go_dpas_vnni_m8( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); + + cl::Kernel kernel{program, "bfloat16_dpas_vnni_m8"}; + if (kernel()) { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M/8}); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration elapsed_seconds = end - start; + best = std::min(best, elapsed_seconds.count()); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Finished in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... 
"); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data()); + check_results(C_check, C_ref); + printf(" done!\n"); + } + } else { + printf("unsupported.\n"); } } @@ -183,6 +470,7 @@ int main( op.add>("m", "matrixsize", "Matrix Size", matrixSize, &matrixSize); op.add>("i", "iterations", "Test Iterations", testIterations, &testIterations); op.add("", "validate", "Validate Results", &validate); + op.add("", "fixed", "Use Fixed Data", &fixedData); op.add>("", "threshold", "Local Error Threshold", threshold, &threshold); bool printUsage = false; try { @@ -213,6 +501,11 @@ int main( printf("Running on device: %s\n", device.getInfo().c_str() ); + printf("Config:\n"); + printf("\tTest Iterations: %d\n", testIterations); + printf("\tValidating data?: %s\n", validate ? "true" : "false"); + printf("\tFixed data?: %s\n", fixedData ? "true" : "false"); + cl::Context context{device}; cl::CommandQueue queue{context, device}; @@ -242,6 +535,8 @@ int main( fill_matrix(A, matrixSize, matrixSize); fill_matrix(B, matrixSize, matrixSize); + vnni_matrix(B_vnni, B, matrixSize, matrixSize, 2); + if (validate) { printf("Computing reference...\n"); compute_reference(C_ref, A, B, matrixSize, matrixSize, matrixSize); @@ -250,13 +545,21 @@ int main( printf("Creating source buffers...\n"); cl::Buffer Abuf{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, A.size() * sizeof(A[0]), A.data()}; cl::Buffer Bbuf{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, B.size() * sizeof(B[0]), B.data()}; + cl::Buffer Bbuf_vnni{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, B_vnni.size() * sizeof(B_vnni[0]), B_vnni.data()}; cl::Buffer Cbuf{context, CL_MEM_WRITE_ONLY, C.size() * sizeof(C[0])}; printf("Running tests...\n"); go_naive(context, program, queue, Cbuf, Abuf, Bbuf, matrixSize, matrixSize, matrixSize, C_ref); - - go_dpas_basic(context, program, queue, Cbuf, Abuf, Bbuf, matrixSize, matrixSize, matrixSize, C_ref); + 
go_dpas_rowmajor_m1(context, program, queue, Cbuf, Abuf, Bbuf, matrixSize, matrixSize, matrixSize, C_ref); + go_dpas_rowmajor_m2(context, program, queue, Cbuf, Abuf, Bbuf, matrixSize, matrixSize, matrixSize, C_ref); + go_dpas_rowmajor_m4(context, program, queue, Cbuf, Abuf, Bbuf, matrixSize, matrixSize, matrixSize, C_ref); + go_dpas_rowmajor_m8(context, program, queue, Cbuf, Abuf, Bbuf, matrixSize, matrixSize, matrixSize, C_ref); + + go_dpas_vnni_m1(context, program, queue, Cbuf, Abuf, Bbuf_vnni, matrixSize, matrixSize, matrixSize, C_ref); + go_dpas_vnni_m2(context, program, queue, Cbuf, Abuf, Bbuf_vnni, matrixSize, matrixSize, matrixSize, C_ref); + go_dpas_vnni_m4(context, program, queue, Cbuf, Abuf, Bbuf_vnni, matrixSize, matrixSize, matrixSize, C_ref); + go_dpas_vnni_m8(context, program, queue, Cbuf, Abuf, Bbuf_vnni, matrixSize, matrixSize, matrixSize, C_ref); printf("Done.\n"); diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl index 1646f33b..ed4900f5 100644 --- a/samples/99_matrixexperiments/matrix_kernels.cl +++ b/samples/99_matrixexperiments/matrix_kernels.cl @@ -24,13 +24,63 @@ kernel void bfloat16_naive(global float* C, global ushort* A, global ushort* B, #if defined(cl_intel_subgroup_matrix_multiply_accumulate) // M rows x K columns -static int __load_a_row_major_bf16_m1(global ushort* A, int rowStart, int colStart, int stride) +static int __load_a_row_major_bf16_k16_m1_x8(global ushort* A, int rowStart, int colStart, int stride) { int ret; - int offset = rowStart * stride + colStart + get_sub_group_local_id() * 2; + global uint* A_ui = (global uint*)A; + int offset_ui = rowStart * stride / 2 + colStart / 2; + ret = intel_sub_group_block_read(A_ui + offset_ui); - ret = as_int(vload2(0, A + offset)); + return ret; +} + +// M rows x K columns +static int2 __load_a_row_major_bf16_k16_m2_x8(global ushort* A, int rowStart, int colStart, int stride) +{ + int2 ret; + + global uint* A_ui = (global 
uint*)A; + int offset_ui = rowStart * stride / 2 + colStart / 2; + + ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + + return ret; +} + +// M rows x K columns +static int4 __load_a_row_major_bf16_k16_m4_x8(global ushort* A, int rowStart, int colStart, int stride) +{ + int4 ret; + + global uint* A_ui = (global uint*)A; + int offset_ui = rowStart * stride / 2 + colStart / 2; + + ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s2 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s3 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + + return ret; +} + +// M rows x K columns +static int8 __load_a_row_major_bf16_k16_m8_x8(global ushort* A, int rowStart, int colStart, int stride) +{ + int8 ret; + + global uint* A_ui = (global uint*)A; + int offset_ui = rowStart * stride / 2 + colStart / 2; + + ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s2 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s3 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s4 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s5 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s6 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s7 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; return ret; } @@ -42,25 +92,24 @@ static int8 __load_b_row_major_bf16_k16(global ushort* B, int rowStart, int colS { int8 ret; - int offset = rowStart * stride + colStart + get_sub_group_local_id(); - - // Note: this could probably use block loads? 
- ushort row0 = B[offset]; offset += stride; - ushort row1 = B[offset]; offset += stride; - ushort row2 = B[offset]; offset += stride; - ushort row3 = B[offset]; offset += stride; - ushort row4 = B[offset]; offset += stride; - ushort row5 = B[offset]; offset += stride; - ushort row6 = B[offset]; offset += stride; - ushort row7 = B[offset]; offset += stride; - ushort row8 = B[offset]; offset += stride; - ushort row9 = B[offset]; offset += stride; - ushort row10 = B[offset]; offset += stride; - ushort row11 = B[offset]; offset += stride; - ushort row12 = B[offset]; offset += stride; - ushort row13 = B[offset]; offset += stride; - ushort row14 = B[offset]; offset += stride; - ushort row15 = B[offset]; offset += stride; + int offset = rowStart * stride + colStart; + + ushort row0 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row1 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row2 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row3 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row4 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row5 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row6 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row7 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row8 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row9 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row10 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row11 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row12 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row13 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row14 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row15 = intel_sub_group_block_read_us(B + offset); offset 
+= stride; ret.s0 = as_int((ushort2)(row0, row1 )); ret.s1 = as_int((ushort2)(row2, row3 )); @@ -74,9 +123,82 @@ static int8 __load_b_row_major_bf16_k16(global ushort* B, int rowStart, int colS return ret; } +// K rows x N columns: +// Each work-item loads K values that has already been converted to VNNI. +// Stride is in units of elements. +static int8 __load_b_vnni_bf16_k16(global ushort* B, int rowStart, int colStart, int stride) +{ + int8 ret; + + global uint* B_ui = (global uint*)B; + int offset_ui = rowStart / 2 * stride + colStart; + + ret.s0 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s1 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s2 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s3 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s4 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s5 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s6 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s7 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + + return ret; +} + +static void __store_c_row_major_fp32_m1(global float* C, float v, int rowStart, int colStart, int stride) +{ + global uint* C_ui = (global uint*)C; + uint v_ui = as_uint(v); + + int offset = rowStart * stride + colStart; + + intel_sub_group_block_write(C_ui + offset, v_ui); offset += stride; +} + +static void __store_c_row_major_fp32_m2(global float* C, float2 v, int rowStart, int colStart, int stride) +{ + global uint* C_ui = (global uint*)C; + uint2 v_ui = as_uint2(v); + + int offset = rowStart * stride + colStart; + + intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride; +} + +static void __store_c_row_major_fp32_m4(global float* C, float4 v, int rowStart, int colStart, int stride) +{ + global uint* C_ui 
= (global uint*)C; + uint4 v_ui = as_uint4(v); + + int offset = rowStart * stride + colStart; + + intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s2); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s3); offset += stride; +} + +static void __store_c_row_major_fp32_m8(global float* C, float8 v, int rowStart, int colStart, int stride) +{ + global uint* C_ui = (global uint*)C; + uint8 v_ui = as_uint8(v); + + int offset = rowStart * stride + colStart; + + intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s2); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s3); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s4); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s5); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s6); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s7); offset += stride; +} + __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) -kernel void bfloat16_dpas_basic(global float* C, global ushort* A, global ushort* B, int K) +kernel void bfloat16_dpas_rowmajor_m1(global float* C, global ushort* A, global ushort* B, int K) { const int N = get_global_size(0); int m = get_group_id(1); @@ -84,18 +206,138 @@ kernel void bfloat16_dpas_basic(global float* C, global ushort* A, global ushort float sum = 0; for (int k = 0; k < K; k += 16) { - int aData = __load_a_row_major_bf16_m1(A, m, k, K); + int aData = __load_a_row_major_bf16_k16_m1_x8(A, m, k, K); int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum); } - C[m * N + n + get_sub_group_local_id()] = sum; + 
__store_c_row_major_fp32_m1(C, sum, m, n, N); } -#else +__attribute__((intel_reqd_sub_group_size(8))) +__attribute__((reqd_work_group_size(8, 1, 1))) +kernel void bfloat16_dpas_rowmajor_m2(global float* C, global ushort* A, global ushort* B, int K) +{ + const int N = get_global_size(0); + int m = get_group_id(1) * 2; + int n = get_group_id(0) * get_local_size(0); + + float2 sum = 0; + for (int k = 0; k < K; k += 16) { + int2 aData = __load_a_row_major_bf16_k16_m2_x8(A, m, k, K); + int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); + sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum); + } + + __store_c_row_major_fp32_m2(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(8))) +__attribute__((reqd_work_group_size(8, 1, 1))) +kernel void bfloat16_dpas_rowmajor_m4(global float* C, global ushort* A, global ushort* B, int K) +{ + const int N = get_global_size(0); + int m = get_group_id(1) * 4; + int n = get_group_id(0) * get_local_size(0); + + float4 sum = 0; + for (int k = 0; k < K; k += 16) { + int4 aData = __load_a_row_major_bf16_k16_m4_x8(A, m, k, K); + int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); + sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum); + } + + __store_c_row_major_fp32_m4(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(8))) +__attribute__((reqd_work_group_size(8, 1, 1))) +kernel void bfloat16_dpas_rowmajor_m8(global float* C, global ushort* A, global ushort* B, int K) +{ + const int N = get_global_size(0); + int m = get_group_id(1) * 8; + int n = get_group_id(0) * get_local_size(0); + + float8 sum = 0; + for (int k = 0; k < K; k += 16) { + int8 aData = __load_a_row_major_bf16_k16_m8_x8(A, m, k, K); + int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); + sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum); + } + + __store_c_row_major_fp32_m8(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(8))) +__attribute__((reqd_work_group_size(8, 1, 1))) +kernel 
void bfloat16_dpas_vnni_m1(global float* C, global ushort* A, global ushort* B, int K) +{ + const int N = get_global_size(0); + int m = get_group_id(1); + int n = get_group_id(0) * get_local_size(0); -#pragma message("cl_intel_subgroup_matrix_multiply_accumulate is unsupported!") + float sum = 0; + for (int k = 0; k < K; k += 16) { + int aData = __load_a_row_major_bf16_k16_m1_x8(A, m, k, K); + int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); + sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum); + } + + __store_c_row_major_fp32_m1(C, sum, m, n, N); +} -kernel void bfloat16_dpas_basic(global float* C, global ushort* A, global ushort* B, int K) {} +__attribute__((intel_reqd_sub_group_size(8))) +__attribute__((reqd_work_group_size(8, 1, 1))) +kernel void bfloat16_dpas_vnni_m2(global float* C, global ushort* A, global ushort* B, int K) +{ + const int N = get_global_size(0); + int m = get_group_id(1) * 2; + int n = get_group_id(0) * get_local_size(0); + + float2 sum = 0; + for (int k = 0; k < K; k += 16) { + int2 aData = __load_a_row_major_bf16_k16_m2_x8(A, m, k, K); + int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); + sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum); + } + + __store_c_row_major_fp32_m2(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(8))) +__attribute__((reqd_work_group_size(8, 1, 1))) +kernel void bfloat16_dpas_vnni_m4(global float* C, global ushort* A, global ushort* B, int K) +{ + const int N = get_global_size(0); + int m = get_group_id(1) * 4; + int n = get_group_id(0) * get_local_size(0); + + float4 sum = 0; + for (int k = 0; k < K; k += 16) { + int4 aData = __load_a_row_major_bf16_k16_m4_x8(A, m, k, K); + int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); + sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum); + } + + __store_c_row_major_fp32_m4(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(8))) +__attribute__((reqd_work_group_size(8, 1, 1))) +kernel void 
bfloat16_dpas_vnni_m8(global float* C, global ushort* A, global ushort* B, int K) +{ + const int N = get_global_size(0); + int m = get_group_id(1) * 8; + int n = get_group_id(0) * get_local_size(0); + + float8 sum = 0; + for (int k = 0; k < K; k += 16) { + int8 aData = __load_a_row_major_bf16_k16_m8_x8(A, m, k, K); + int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); + sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum); + } + + __store_c_row_major_fp32_m8(C, sum, m, n, N); +} -#endif \ No newline at end of file +#endif // defined(cl_intel_subgroup_matrix_multiply_accumulate) \ No newline at end of file From 11e0eef0937a8e009bad9c09d3a66b4a62e4f34c Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Thu, 4 Jan 2024 16:41:03 -0800 Subject: [PATCH 04/99] cleanup --- samples/99_matrixexperiments/main.cpp | 74 ++++++++----------- .../99_matrixexperiments/matrix_kernels.cl | 4 +- 2 files changed, 34 insertions(+), 44 deletions(-) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index d3808be5..26d177f1 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -79,9 +79,9 @@ int check_results(const std::vector& C, { float err = 0.f; for (int i = 0; i < C.size(); ++i) { - float localErr = std::fabs(C[i] - C_ref[i]) / - std::max(std::fabs(C[i]), - std::fabs(C_ref[i])); + auto localErr = std::fabs(C[i] - C_ref[i]) / + std::max(std::fabs(C[i]), + std::fabs(C_ref[i])); err = std::max(localErr, err); if (localErr >= threshold) { std::cerr << "Error at index " << i << " (local error " << localErr @@ -94,12 +94,11 @@ int check_results(const std::vector& C, return err < 0.001f; } -template static void go_naive( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, - const std::vector& C_ref) + const std::vector& C_ref) { printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); 
fflush(stdout); @@ -123,19 +122,18 @@ static void go_naive( if (validate) { printf("Checking results... "); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data()); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); check_results(C_check, C_ref); printf(" done!\n"); } } -template static void go_dpas_rowmajor_m1( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, - const std::vector& C_ref) + const std::vector& C_ref) { printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); @@ -160,8 +158,8 @@ static void go_dpas_rowmajor_m1( if (validate) { printf("Checking results... "); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data()); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); check_results(C_check, C_ref); printf(" done!\n"); } @@ -170,12 +168,11 @@ static void go_dpas_rowmajor_m1( } } -template static void go_dpas_rowmajor_m2( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, - const std::vector& C_ref) + const std::vector& C_ref) { printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); @@ -200,8 +197,8 @@ static void go_dpas_rowmajor_m2( if (validate) { printf("Checking results... 
"); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data()); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); check_results(C_check, C_ref); printf(" done!\n"); } @@ -210,12 +207,11 @@ static void go_dpas_rowmajor_m2( } } -template static void go_dpas_rowmajor_m4( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, - const std::vector& C_ref) + const std::vector& C_ref) { printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); @@ -240,8 +236,8 @@ static void go_dpas_rowmajor_m4( if (validate) { printf("Checking results... "); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data()); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); check_results(C_check, C_ref); printf(" done!\n"); } @@ -250,12 +246,11 @@ static void go_dpas_rowmajor_m4( } } -template static void go_dpas_rowmajor_m8( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, - const std::vector& C_ref) + const std::vector& C_ref) { printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); @@ -280,8 +275,8 @@ static void go_dpas_rowmajor_m8( if (validate) { printf("Checking results... 
"); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data()); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); check_results(C_check, C_ref); printf(" done!\n"); } @@ -290,12 +285,11 @@ static void go_dpas_rowmajor_m8( } } -template static void go_dpas_vnni_m1( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, - const std::vector& C_ref) + const std::vector& C_ref) { printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); @@ -320,8 +314,8 @@ static void go_dpas_vnni_m1( if (validate) { printf("Checking results... "); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data()); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); check_results(C_check, C_ref); printf(" done!\n"); } @@ -330,12 +324,11 @@ static void go_dpas_vnni_m1( } } -template static void go_dpas_vnni_m2( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, - const std::vector& C_ref) + const std::vector& C_ref) { printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); @@ -360,8 +353,8 @@ static void go_dpas_vnni_m2( if (validate) { printf("Checking results... 
"); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data()); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); check_results(C_check, C_ref); printf(" done!\n"); } @@ -370,12 +363,11 @@ static void go_dpas_vnni_m2( } } -template static void go_dpas_vnni_m4( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, - const std::vector& C_ref) + const std::vector& C_ref) { printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); @@ -400,8 +392,8 @@ static void go_dpas_vnni_m4( if (validate) { printf("Checking results... "); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data()); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); check_results(C_check, C_ref); printf(" done!\n"); } @@ -410,12 +402,11 @@ static void go_dpas_vnni_m4( } } -template static void go_dpas_vnni_m8( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, - const std::vector& C_ref) + const std::vector& C_ref) { printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); @@ -440,8 +431,8 @@ static void go_dpas_vnni_m8( if (validate) { printf("Checking results... 
"); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(T), C_check.data()); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); check_results(C_check, C_ref); printf(" done!\n"); } @@ -450,9 +441,7 @@ static void go_dpas_vnni_m8( } } -int main( - int argc, - char** argv ) +int main(int argc, char** argv) { int platformIndex = 0; int deviceIndex = 0; @@ -551,6 +540,7 @@ int main( printf("Running tests...\n"); go_naive(context, program, queue, Cbuf, Abuf, Bbuf, matrixSize, matrixSize, matrixSize, C_ref); + go_dpas_rowmajor_m1(context, program, queue, Cbuf, Abuf, Bbuf, matrixSize, matrixSize, matrixSize, C_ref); go_dpas_rowmajor_m2(context, program, queue, Cbuf, Abuf, Bbuf, matrixSize, matrixSize, matrixSize, C_ref); go_dpas_rowmajor_m4(context, program, queue, Cbuf, Abuf, Bbuf, matrixSize, matrixSize, matrixSize, C_ref); diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl index ed4900f5..76ce6753 100644 --- a/samples/99_matrixexperiments/matrix_kernels.cl +++ b/samples/99_matrixexperiments/matrix_kernels.cl @@ -15,7 +15,7 @@ kernel void bfloat16_naive(global float* C, global ushort* A, global ushort* B, float sum = 0; for (int k = 0; k < K; k++) { - sum += bfloat16_to_float(A[m * K + k]) * bfloat16_to_float(B[k * N + n]); + sum = fma(bfloat16_to_float(A[m * K + k]), bfloat16_to_float(B[k * N + n]), sum); } C[m * N + n] = sum; @@ -340,4 +340,4 @@ kernel void bfloat16_dpas_vnni_m8(global float* C, global ushort* A, global usho __store_c_row_major_fp32_m8(C, sum, m, n, N); } -#endif // defined(cl_intel_subgroup_matrix_multiply_accumulate) \ No newline at end of file +#endif // defined(cl_intel_subgroup_matrix_multiply_accumulate) From d637ee6455578c6079407f4cdbb9102c340f7808 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Fri, 5 Jan 2024 13:23:26 -0800 Subject: [PATCH 
05/99] host code cleanup --- samples/99_matrixexperiments/main.cpp | 328 +++++--------------------- 1 file changed, 64 insertions(+), 264 deletions(-) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index 26d177f1..662457d0 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include "bfloat16.hpp" #include "util.hpp" @@ -23,6 +24,28 @@ bool validate = false; int testIterations = 16; float threshold = 0.01f; +std::string makeTestName( + const std::string &func, + int tM, int tN, int tK, + size_t M, size_t N, size_t K) +{ + std::ostringstream ret; + ret << func; + ret << ""; + ret << " (M=" << M << ", N=" << N << ", K=" << K << ")"; + return ret.str(); +} + +std::string makeTestName( + const std::string &func, + size_t M, size_t N, size_t K) +{ + std::ostringstream ret; + ret << func; + ret << " (M=" << M << ", N=" << N << ", K=" << K << ")"; + return ret.str(); +} + template static void fill_matrix(std::vector& M, size_t numRows, size_t numCols) { @@ -100,7 +123,7 @@ static void go_naive( size_t M, size_t N, size_t K, const std::vector& C_ref) { - printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); + printf("%80s: ", makeTestName(__FUNCTION__, M, N, K).c_str()); fflush(stdout); cl::Kernel kernel{program, "bfloat16_naive"}; kernel.setArg(0, C); @@ -108,6 +131,8 @@ static void go_naive( kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); + queue.enqueueFillBuffer(C, 0, 0, C_ref.size()); + float best = 999.0f; for (int test = 0; test < testIterations; test++) { auto start = test_clock::now(); @@ -118,7 +143,7 @@ static void go_naive( best = std::min(best, elapsed_seconds.count()); } auto gops = 2.0 * M * N * K / best / 1e9; - printf("Finished in %f seconds (%f gops)\n", best, gops); + printf("Best in %f seconds (%f gops)\n", best, gops); if (validate) { printf("Checking results... 
"); fflush(stdout); @@ -129,15 +154,17 @@ static void go_naive( } } -static void go_dpas_rowmajor_m1( +template +static void go_dpas_rowmajor( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, const std::vector& C_ref) { - printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, M, N, K).c_str()); fflush(stdout); - cl::Kernel kernel{program, "bfloat16_dpas_rowmajor_m1"}; + std::string kernelName = "bfloat16_dpas_rowmajor_m" + std::to_string(tM); + cl::Kernel kernel{program, kernelName.c_str()}; if (kernel()) { kernel.setArg(0, C); kernel.setArg(1, A); @@ -147,14 +174,14 @@ static void go_dpas_rowmajor_m1( float best = 999.0f; for (int test = 0; test < testIterations; test++) { auto start = test_clock::now(); - queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M}); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M/tM}); queue.finish(); auto end = test_clock::now(); std::chrono::duration elapsed_seconds = end - start; best = std::min(best, elapsed_seconds.count()); } auto gops = 2.0 * M * N * K / best / 1e9; - printf("Finished in %f seconds (%f gops)\n", best, gops); + printf("Best in %f seconds (%f gops)\n", best, gops); if (validate) { printf("Checking results... 
"); fflush(stdout); @@ -168,266 +195,36 @@ static void go_dpas_rowmajor_m1( } } -static void go_dpas_rowmajor_m2( +template +static void go_dpas_vnni( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, const std::vector& C_ref) { - printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, M, N, K).c_str()); fflush(stdout); - cl::Kernel kernel{program, "bfloat16_dpas_rowmajor_m2"}; + std::string kernelName = "bfloat16_dpas_vnni_m" + std::to_string(tM); + cl::Kernel kernel{program, kernelName.c_str()}; if (kernel()) { kernel.setArg(0, C); kernel.setArg(1, A); kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); - float best = 999.0f; - for (int test = 0; test < testIterations; test++) { - auto start = test_clock::now(); - queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M/2}); - queue.finish(); - auto end = test_clock::now(); - std::chrono::duration elapsed_seconds = end - start; - best = std::min(best, elapsed_seconds.count()); - } - auto gops = 2.0 * M * N * K / best / 1e9; - printf("Finished in %f seconds (%f gops)\n", best, gops); - - if (validate) { - printf("Checking results... 
"); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); - check_results(C_check, C_ref); - printf(" done!\n"); - } - } else { - printf("unsupported.\n"); - } -} - -static void go_dpas_rowmajor_m4( - cl::Context& context, cl::Program& program, cl::CommandQueue& queue, - cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, - size_t M, size_t N, size_t K, - const std::vector& C_ref) -{ - printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); - - cl::Kernel kernel{program, "bfloat16_dpas_rowmajor_m4"}; - if (kernel()) { - kernel.setArg(0, C); - kernel.setArg(1, A); - kernel.setArg(2, B); - kernel.setArg(3, static_cast(K)); + queue.enqueueFillBuffer(C, 0, 0, C_ref.size()); float best = 999.0f; for (int test = 0; test < testIterations; test++) { auto start = test_clock::now(); - queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M/4}); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M/tM}); queue.finish(); auto end = test_clock::now(); std::chrono::duration elapsed_seconds = end - start; best = std::min(best, elapsed_seconds.count()); } auto gops = 2.0 * M * N * K / best / 1e9; - printf("Finished in %f seconds (%f gops)\n", best, gops); - - if (validate) { - printf("Checking results... 
"); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); - check_results(C_check, C_ref); - printf(" done!\n"); - } - } else { - printf("unsupported.\n"); - } -} - -static void go_dpas_rowmajor_m8( - cl::Context& context, cl::Program& program, cl::CommandQueue& queue, - cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, - size_t M, size_t N, size_t K, - const std::vector& C_ref) -{ - printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); - - cl::Kernel kernel{program, "bfloat16_dpas_rowmajor_m8"}; - if (kernel()) { - kernel.setArg(0, C); - kernel.setArg(1, A); - kernel.setArg(2, B); - kernel.setArg(3, static_cast(K)); - - float best = 999.0f; - for (int test = 0; test < testIterations; test++) { - auto start = test_clock::now(); - queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M/8}); - queue.finish(); - auto end = test_clock::now(); - std::chrono::duration elapsed_seconds = end - start; - best = std::min(best, elapsed_seconds.count()); - } - auto gops = 2.0 * M * N * K / best / 1e9; - printf("Finished in %f seconds (%f gops)\n", best, gops); - - if (validate) { - printf("Checking results... 
"); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); - check_results(C_check, C_ref); - printf(" done!\n"); - } - } else { - printf("unsupported.\n"); - } -} - -static void go_dpas_vnni_m1( - cl::Context& context, cl::Program& program, cl::CommandQueue& queue, - cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, - size_t M, size_t N, size_t K, - const std::vector& C_ref) -{ - printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); - - cl::Kernel kernel{program, "bfloat16_dpas_vnni_m1"}; - if (kernel()) { - kernel.setArg(0, C); - kernel.setArg(1, A); - kernel.setArg(2, B); - kernel.setArg(3, static_cast(K)); - - float best = 999.0f; - for (int test = 0; test < testIterations; test++) { - auto start = test_clock::now(); - queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M}); - queue.finish(); - auto end = test_clock::now(); - std::chrono::duration elapsed_seconds = end - start; - best = std::min(best, elapsed_seconds.count()); - } - auto gops = 2.0 * M * N * K / best / 1e9; - printf("Finished in %f seconds (%f gops)\n", best, gops); - - if (validate) { - printf("Checking results... 
"); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); - check_results(C_check, C_ref); - printf(" done!\n"); - } - } else { - printf("unsupported.\n"); - } -} - -static void go_dpas_vnni_m2( - cl::Context& context, cl::Program& program, cl::CommandQueue& queue, - cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, - size_t M, size_t N, size_t K, - const std::vector& C_ref) -{ - printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); - - cl::Kernel kernel{program, "bfloat16_dpas_vnni_m2"}; - if (kernel()) { - kernel.setArg(0, C); - kernel.setArg(1, A); - kernel.setArg(2, B); - kernel.setArg(3, static_cast(K)); - - float best = 999.0f; - for (int test = 0; test < testIterations; test++) { - auto start = test_clock::now(); - queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M/2}); - queue.finish(); - auto end = test_clock::now(); - std::chrono::duration elapsed_seconds = end - start; - best = std::min(best, elapsed_seconds.count()); - } - auto gops = 2.0 * M * N * K / best / 1e9; - printf("Finished in %f seconds (%f gops)\n", best, gops); - - if (validate) { - printf("Checking results... 
"); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); - check_results(C_check, C_ref); - printf(" done!\n"); - } - } else { - printf("unsupported.\n"); - } -} - -static void go_dpas_vnni_m4( - cl::Context& context, cl::Program& program, cl::CommandQueue& queue, - cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, - size_t M, size_t N, size_t K, - const std::vector& C_ref) -{ - printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); - - cl::Kernel kernel{program, "bfloat16_dpas_vnni_m4"}; - if (kernel()) { - kernel.setArg(0, C); - kernel.setArg(1, A); - kernel.setArg(2, B); - kernel.setArg(3, static_cast(K)); - - float best = 999.0f; - for (int test = 0; test < testIterations; test++) { - auto start = test_clock::now(); - queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M/4}); - queue.finish(); - auto end = test_clock::now(); - std::chrono::duration elapsed_seconds = end - start; - best = std::min(best, elapsed_seconds.count()); - } - auto gops = 2.0 * M * N * K / best / 1e9; - printf("Finished in %f seconds (%f gops)\n", best, gops); - - if (validate) { - printf("Checking results... 
"); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); - check_results(C_check, C_ref); - printf(" done!\n"); - } - } else { - printf("unsupported.\n"); - } -} - -static void go_dpas_vnni_m8( - cl::Context& context, cl::Program& program, cl::CommandQueue& queue, - cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, - size_t M, size_t N, size_t K, - const std::vector& C_ref) -{ - printf("%40s (M=%zu, N = %zu, K = %zu): ", __FUNCTION__, M, N, K); fflush(stdout); - - cl::Kernel kernel{program, "bfloat16_dpas_vnni_m8"}; - if (kernel()) { - kernel.setArg(0, C); - kernel.setArg(1, A); - kernel.setArg(2, B); - kernel.setArg(3, static_cast(K)); - - float best = 999.0f; - for (int test = 0; test < testIterations; test++) { - auto start = test_clock::now(); - queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M/8}); - queue.finish(); - auto end = test_clock::now(); - std::chrono::duration elapsed_seconds = end - start; - best = std::min(best, elapsed_seconds.count()); - } - auto gops = 2.0 * M * N * K / best / 1e9; - printf("Finished in %f seconds (%f gops)\n", best, gops); + printf("Best in %f seconds (%f gops)\n", best, gops); if (validate) { printf("Checking results... 
"); fflush(stdout); @@ -513,43 +310,46 @@ int main(int argc, char** argv) program.getBuildInfo(device).c_str() ); } - std::vector A(matrixSize * matrixSize); - std::vector B(matrixSize * matrixSize); - std::vector B_vnni(matrixSize * matrixSize); + const auto M = matrixSize; + const auto N = matrixSize; + const auto K = matrixSize; + + std::vector A(M * K); + std::vector B(K * N); + std::vector B_vnni(K * N); - std::vector C(matrixSize * matrixSize); - std::vector C_ref(matrixSize * matrixSize); + std::vector C_ref(M * N); printf("Initializing source matrices...\n"); - fill_matrix(A, matrixSize, matrixSize); - fill_matrix(B, matrixSize, matrixSize); + fill_matrix(A, M, K); + fill_matrix(B, K, N); - vnni_matrix(B_vnni, B, matrixSize, matrixSize, 2); + vnni_matrix(B_vnni, B, K, N, 2); if (validate) { printf("Computing reference...\n"); - compute_reference(C_ref, A, B, matrixSize, matrixSize, matrixSize); + compute_reference(C_ref, A, B, M, N, K); } printf("Creating source buffers...\n"); cl::Buffer Abuf{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, A.size() * sizeof(A[0]), A.data()}; cl::Buffer Bbuf{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, B.size() * sizeof(B[0]), B.data()}; cl::Buffer Bbuf_vnni{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, B_vnni.size() * sizeof(B_vnni[0]), B_vnni.data()}; - cl::Buffer Cbuf{context, CL_MEM_WRITE_ONLY, C.size() * sizeof(C[0])}; + cl::Buffer Cbuf{context, CL_MEM_WRITE_ONLY, C_ref.size() * sizeof(C_ref[0])}; printf("Running tests...\n"); - go_naive(context, program, queue, Cbuf, Abuf, Bbuf, matrixSize, matrixSize, matrixSize, C_ref); + go_naive(context, program, queue, Cbuf, Abuf, Bbuf, M, N, K, C_ref); - go_dpas_rowmajor_m1(context, program, queue, Cbuf, Abuf, Bbuf, matrixSize, matrixSize, matrixSize, C_ref); - go_dpas_rowmajor_m2(context, program, queue, Cbuf, Abuf, Bbuf, matrixSize, matrixSize, matrixSize, C_ref); - go_dpas_rowmajor_m4(context, program, queue, Cbuf, Abuf, Bbuf, matrixSize, matrixSize, matrixSize, 
C_ref); - go_dpas_rowmajor_m8(context, program, queue, Cbuf, Abuf, Bbuf, matrixSize, matrixSize, matrixSize, C_ref); + go_dpas_rowmajor<1, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf, M, N, K, C_ref); + go_dpas_rowmajor<2, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf, M, N, K, C_ref); + go_dpas_rowmajor<4, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf, M, N, K, C_ref); + go_dpas_rowmajor<8, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf, M, N, K, C_ref); - go_dpas_vnni_m1(context, program, queue, Cbuf, Abuf, Bbuf_vnni, matrixSize, matrixSize, matrixSize, C_ref); - go_dpas_vnni_m2(context, program, queue, Cbuf, Abuf, Bbuf_vnni, matrixSize, matrixSize, matrixSize, C_ref); - go_dpas_vnni_m4(context, program, queue, Cbuf, Abuf, Bbuf_vnni, matrixSize, matrixSize, matrixSize, C_ref); - go_dpas_vnni_m8(context, program, queue, Cbuf, Abuf, Bbuf_vnni, matrixSize, matrixSize, matrixSize, C_ref); + go_dpas_vnni<1, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf_vnni, M, N, K, C_ref); + go_dpas_vnni<2, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf_vnni, M, N, K, C_ref); + go_dpas_vnni<4, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf_vnni, M, N, K, C_ref); + go_dpas_vnni<8, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf_vnni, M, N, K, C_ref); printf("Done.\n"); From 52b9550042b1754421ac84e50bb70155dae4e937 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Sat, 6 Jan 2024 11:14:38 -0800 Subject: [PATCH 06/99] add SIMD16 versions and emulation --- include/CL/opencl.hpp | 5 + samples/99_matrixexperiments/main.cpp | 88 +++- .../99_matrixexperiments/matrix_kernels.cl | 395 +++++++++++++++++- 3 files changed, 443 insertions(+), 45 deletions(-) diff --git a/include/CL/opencl.hpp b/include/CL/opencl.hpp index 1c43ae0e..c14c81c0 100644 --- a/include/CL/opencl.hpp +++ b/include/CL/opencl.hpp @@ -1654,6 +1654,11 @@ CL_HPP_DECLARE_PARAM_TRAITS_(cl_device_info, CL_DEVICE_NUM_THREADS_PER_EU_INTEL, CL_HPP_DECLARE_PARAM_TRAITS_(cl_device_info, 
CL_DEVICE_FEATURE_CAPABILITIES_INTEL, cl_device_feature_capabilities_intel) #endif // cl_intel_device_attribute_query +#if defined(cl_intel_required_subgroup_size) +CL_HPP_DECLARE_PARAM_TRAITS_(cl_device_info, CL_DEVICE_SUB_GROUP_SIZES_INTEL, cl::vector) +CL_HPP_DECLARE_PARAM_TRAITS_(cl_kernel_work_group_info, CL_KERNEL_SPILL_MEM_SIZE_INTEL, cl_ulong) +#endif // cl_intel_required_subgroup_size + // Convenience functions template diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index 662457d0..0d78b421 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -10,6 +10,7 @@ #include #include +#include #include #include #include @@ -21,6 +22,7 @@ using test_clock = std::chrono::high_resolution_clock; bool fixedData = false; bool validate = false; +bool emulate = false; int testIterations = 16; float threshold = 0.01f; @@ -46,6 +48,16 @@ std::string makeTestName( return ret.str(); } +static size_t findMinSubGroupSize(cl::Device& device) +{ + auto s = device.getInfo(); + auto it = std::min_element(std::begin(s), std::end(s)); + if (it != std::end(s)) { + return *it; + } + return 0; +} + template static void fill_matrix(std::vector& M, size_t numRows, size_t numCols) { @@ -163,7 +175,9 @@ static void go_dpas_rowmajor( { printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, M, N, K).c_str()); fflush(stdout); - std::string kernelName = "bfloat16_dpas_rowmajor_m" + std::to_string(tM); + std::string kernelName = "bfloat16_dpas_rowmajor"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); cl::Kernel kernel{program, kernelName.c_str()}; if (kernel()) { kernel.setArg(0, C); @@ -204,7 +218,9 @@ static void go_dpas_vnni( { printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, M, N, K).c_str()); fflush(stdout); - std::string kernelName = "bfloat16_dpas_vnni_m" + std::to_string(tM); + std::string kernelName = "bfloat16_dpas_vnni"; + kernelName += "_m" + 
std::to_string(tM); + kernelName += "_n" + std::to_string(tN); cl::Kernel kernel{program, kernelName.c_str()}; if (kernel()) { kernel.setArg(0, C); @@ -257,6 +273,7 @@ int main(int argc, char** argv) op.add>("i", "iterations", "Test Iterations", testIterations, &testIterations); op.add("", "validate", "Validate Results", &validate); op.add("", "fixed", "Use Fixed Data", &fixedData); + op.add("", "emulate", "Unconditionally Emulate dpas", &emulate); op.add>("", "threshold", "Local Error Threshold", threshold, &threshold); bool printUsage = false; try { @@ -287,10 +304,27 @@ int main(int argc, char** argv) printf("Running on device: %s\n", device.getInfo().c_str() ); + bool emulate_tN8 = true; + bool emulate_tN16 = true; + if (!emulate && checkDeviceForExtension(device, "cl_intel_subgroup_matrix_multiply_accumulate")) { + auto minSubGroupSize = findMinSubGroupSize(device); + printf("Found support for cl_intel_subgroup_matrix_multiply_accumulate, min sub-group size is: %zu\n", minSubGroupSize); + switch(minSubGroupSize) { + case 8: emulate_tN8 = false; break; + case 16: emulate_tN16 = false; break; + default: break; + } + } + + buildOptions += " -DEMULATE_tN8=" + std::to_string(emulate_tN8); + buildOptions += " -DEMULATE_tN16=" + std::to_string(emulate_tN16); + printf("Config:\n"); printf("\tTest Iterations: %d\n", testIterations); printf("\tValidating data?: %s\n", validate ? "true" : "false"); printf("\tFixed data?: %s\n", fixedData ? "true" : "false"); + printf("\tEmulate dpas for tN=8?: %s\n", emulate_tN8 ? "true" : "false"); + printf("\tEmulate dpas for tN=16?: %s\n", emulate_tN16 ? 
"true" : "false"); cl::Context context{device}; cl::CommandQueue queue{context, device}; @@ -314,42 +348,52 @@ int main(int argc, char** argv) const auto N = matrixSize; const auto K = matrixSize; - std::vector A(M * K); - std::vector B(K * N); - std::vector B_vnni(K * N); + std::vector A_vec(M * K); + std::vector B_vec(K * N); + std::vector Bvnni_vec(K * N); std::vector C_ref(M * N); printf("Initializing source matrices...\n"); - fill_matrix(A, M, K); - fill_matrix(B, K, N); + fill_matrix(A_vec, M, K); + fill_matrix(B_vec, K, N); - vnni_matrix(B_vnni, B, K, N, 2); + vnni_matrix(Bvnni_vec, B_vec, K, N, 2); if (validate) { printf("Computing reference...\n"); - compute_reference(C_ref, A, B, M, N, K); + compute_reference(C_ref, A_vec, B_vec, M, N, K); } printf("Creating source buffers...\n"); - cl::Buffer Abuf{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, A.size() * sizeof(A[0]), A.data()}; - cl::Buffer Bbuf{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, B.size() * sizeof(B[0]), B.data()}; - cl::Buffer Bbuf_vnni{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, B_vnni.size() * sizeof(B_vnni[0]), B_vnni.data()}; - cl::Buffer Cbuf{context, CL_MEM_WRITE_ONLY, C_ref.size() * sizeof(C_ref[0])}; + cl::Buffer A{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, A_vec.size() * sizeof(A_vec[0]), A_vec.data()}; + cl::Buffer B{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, B_vec.size() * sizeof(B_vec[0]), B_vec.data()}; + cl::Buffer Bvnni{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, Bvnni_vec.size() * sizeof(Bvnni_vec[0]), Bvnni_vec.data()}; + cl::Buffer C{context, CL_MEM_WRITE_ONLY, C_ref.size() * sizeof(C_ref[0])}; printf("Running tests...\n"); - go_naive(context, program, queue, Cbuf, Abuf, Bbuf, M, N, K, C_ref); + go_naive(context, program, queue, C, A, B, M, N, K, C_ref); + + go_dpas_rowmajor<1, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_rowmajor<2, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_rowmajor<4, 8, 
16>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_rowmajor<8, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + + go_dpas_vnni<1, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_vnni<2, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_vnni<4, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_vnni<8, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_rowmajor<1, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf, M, N, K, C_ref); - go_dpas_rowmajor<2, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf, M, N, K, C_ref); - go_dpas_rowmajor<4, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf, M, N, K, C_ref); - go_dpas_rowmajor<8, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf, M, N, K, C_ref); + go_dpas_rowmajor<1, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_rowmajor<2, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_rowmajor<4, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_rowmajor<8, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_vnni<1, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf_vnni, M, N, K, C_ref); - go_dpas_vnni<2, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf_vnni, M, N, K, C_ref); - go_dpas_vnni<4, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf_vnni, M, N, K, C_ref); - go_dpas_vnni<8, 8, 16>(context, program, queue, Cbuf, Abuf, Bbuf_vnni, M, N, K, C_ref); + go_dpas_vnni<1, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_vnni<2, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_vnni<4, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_vnni<8, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); printf("Done.\n"); diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl index 76ce6753..71210216 100644 --- 
a/samples/99_matrixexperiments/matrix_kernels.cl +++ b/samples/99_matrixexperiments/matrix_kernels.cl @@ -1,4 +1,16 @@ -float bfloat16_to_float(ushort u) +#if EMULATE_tN8 == 0 +#define mat_mul_x8 intel_sub_group_bf16_bf16_matrix_mad_k16 +#else +#define mat_mul_x8 my_sub_group_bf16_bf16_matrix_mad_k16 +#endif + +#if EMULATE_tN16 == 0 +#define mat_mul_x16 intel_sub_group_bf16_bf16_matrix_mad_k16 +#else +#define mat_mul_x16 my_sub_group_bf16_bf16_matrix_mad_k16 +#endif + +float bf16_to_fp32(ushort u) { #if defined(cl_intel_bfloat16_conversions) return intel_convert_as_bfloat16_float(u); @@ -15,15 +27,146 @@ kernel void bfloat16_naive(global float* C, global ushort* A, global ushort* B, float sum = 0; for (int k = 0; k < K; k++) { - sum = fma(bfloat16_to_float(A[m * K + k]), bfloat16_to_float(B[k * N + n]), sum); + sum = fma(bf16_to_fp32(A[m * K + k]), bf16_to_fp32(B[k * N + n]), sum); } C[m * N + n] = sum; } -#if defined(cl_intel_subgroup_matrix_multiply_accumulate) +#if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size) + +#define OVLD __attribute__((overloadable)) + +// SIMD8 versions: +static float OVLD my_sub_group_bf16_bf16_matrix_mad_k16(int a, int8 b, float acc) +{ + float res = acc; + + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 0)).x), bf16_to_fp32(as_ushort2(b.s0).x), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 0)).y), bf16_to_fp32(as_ushort2(b.s0).y), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 1)).x), bf16_to_fp32(as_ushort2(b.s1).x), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 1)).y), bf16_to_fp32(as_ushort2(b.s1).y), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 2)).x), bf16_to_fp32(as_ushort2(b.s2).x), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 2)).y), bf16_to_fp32(as_ushort2(b.s2).y), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 3)).x),
bf16_to_fp32(as_ushort2(b.s3).x), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 3)).y), bf16_to_fp32(as_ushort2(b.s3).y), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 4)).x), bf16_to_fp32(as_ushort2(b.s4).x), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 4)).y), bf16_to_fp32(as_ushort2(b.s4).y), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 5)).x), bf16_to_fp32(as_ushort2(b.s5).x), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 5)).y), bf16_to_fp32(as_ushort2(b.s5).y), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 6)).x), bf16_to_fp32(as_ushort2(b.s6).x), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 6)).y), bf16_to_fp32(as_ushort2(b.s6).y), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 7)).x), bf16_to_fp32(as_ushort2(b.s7).x), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 7)).y), bf16_to_fp32(as_ushort2(b.s7).y), res); + + return res; +} + +static float2 OVLD my_sub_group_bf16_bf16_matrix_mad_k16(int2 a, int8 b, float2 acc) +{ + float2 res; + + res.s0 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0); + res.s1 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1); + + return res; +} + +static float4 OVLD my_sub_group_bf16_bf16_matrix_mad_k16(int4 a, int8 b, float4 acc) +{ + float4 res; + + res.s0 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0); + res.s1 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1); + res.s2 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s2, b, acc.s2); + res.s3 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s3, b, acc.s3); + + return res; +} + +static float8 OVLD my_sub_group_bf16_bf16_matrix_mad_k16(int8 a, int8 b, float8 acc) +{ + float8 res; + + res.s0 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0); + res.s1 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1); + res.s2 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s2, b, acc.s2); + res.s3 = 
my_sub_group_bf16_bf16_matrix_mad_k16(a.s3, b, acc.s3); + res.s4 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s4, b, acc.s4); + res.s5 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s5, b, acc.s5); + res.s6 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s6, b, acc.s6); + res.s7 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s7, b, acc.s7); + + return res; +} + +// SIMD16 versions: +static float OVLD my_sub_group_bf16_bf16_matrix_mad_k16(short a, int8 b, float acc) +{ + float res = acc; + + res = fma(bf16_to_fp32(sub_group_broadcast(a, 0)), bf16_to_fp32(as_ushort2(b.s0).x), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 1)), bf16_to_fp32(as_ushort2(b.s0).y), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 2)), bf16_to_fp32(as_ushort2(b.s1).x), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 3)), bf16_to_fp32(as_ushort2(b.s1).y), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 4)), bf16_to_fp32(as_ushort2(b.s2).x), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 5)), bf16_to_fp32(as_ushort2(b.s2).y), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 6)), bf16_to_fp32(as_ushort2(b.s3).x), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 7)), bf16_to_fp32(as_ushort2(b.s3).y), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 8)), bf16_to_fp32(as_ushort2(b.s4).x), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 9)), bf16_to_fp32(as_ushort2(b.s4).y), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 10)), bf16_to_fp32(as_ushort2(b.s5).x), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 11)), bf16_to_fp32(as_ushort2(b.s5).y), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 12)), bf16_to_fp32(as_ushort2(b.s6).x), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 13)), bf16_to_fp32(as_ushort2(b.s6).y), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 14)), bf16_to_fp32(as_ushort2(b.s7).x), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 15)), bf16_to_fp32(as_ushort2(b.s7).y), res); + + return res; +} + +static 
float2 OVLD my_sub_group_bf16_bf16_matrix_mad_k16(short2 a, int8 b, float2 acc) +{ + float2 res; + + res.s0 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0); + res.s1 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1); + + return res; +} + +static float4 OVLD my_sub_group_bf16_bf16_matrix_mad_k16(short4 a, int8 b, float4 acc) +{ + float4 res; + + res.s0 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0); + res.s1 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1); + res.s2 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s2, b, acc.s2); + res.s3 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s3, b, acc.s3); + + return res; +} + +static float8 OVLD my_sub_group_bf16_bf16_matrix_mad_k16(short8 a, int8 b, float8 acc) +{ + float8 res; + + res.s0 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0); + res.s1 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1); + res.s2 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s2, b, acc.s2); + res.s3 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s3, b, acc.s3); + res.s4 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s4, b, acc.s4); + res.s5 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s5, b, acc.s5); + res.s6 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s6, b, acc.s6); + res.s7 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s7, b, acc.s7); + + return res; +} + +#undef OVLD // M rows x K columns +// This is the SIMD8 version, where each work-item loads two values. static int __load_a_row_major_bf16_k16_m1_x8(global ushort* A, int rowStart, int colStart, int stride) { int ret; @@ -32,10 +175,11 @@ static int __load_a_row_major_bf16_k16_m1_x8(global ushort* A, int rowStart, int int offset_ui = rowStart * stride / 2 + colStart / 2; ret = intel_sub_group_block_read(A_ui + offset_ui); - return ret; + return ret; } // M rows x K columns +// This is the SIMD8 version, where each work-item loads two values. 
static int2 __load_a_row_major_bf16_k16_m2_x8(global ushort* A, int rowStart, int colStart, int stride) { int2 ret; @@ -46,10 +190,11 @@ static int2 __load_a_row_major_bf16_k16_m2_x8(global ushort* A, int rowStart, in ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; - return ret; + return ret; } // M rows x K columns +// This is the SIMD8 version, where each work-item loads two values. static int4 __load_a_row_major_bf16_k16_m4_x8(global ushort* A, int rowStart, int colStart, int stride) { int4 ret; @@ -66,6 +211,7 @@ static int4 __load_a_row_major_bf16_k16_m4_x8(global ushort* A, int rowStart, in } // M rows x K columns +// This is the SIMD8 version, where each work-item loads two values. static int8 __load_a_row_major_bf16_k16_m8_x8(global ushort* A, int rowStart, int colStart, int stride) { int8 ret; @@ -82,7 +228,66 @@ static int8 __load_a_row_major_bf16_k16_m8_x8(global ushort* A, int rowStart, in ret.s6 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; ret.s7 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; - return ret; + return ret; +} + +// M rows x K columns +// This is the SIMD16 version, where each work-item loads one value. +static short __load_a_row_major_bf16_k16_m1_x16(global ushort* A, int rowStart, int colStart, int stride) +{ + ushort ret; + + int offset = rowStart * stride + colStart; + ret = intel_sub_group_block_read_us(A + offset); + + return as_short(ret); +} + +// M rows x K columns +// This is the SIMD16 version, where each work-item loads one value.
+static short2 __load_a_row_major_bf16_k16_m2_x16(global ushort* A, int rowStart, int colStart, int stride) +{ + ushort2 ret; + + int offset = rowStart * stride + colStart; + ret.s0 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s1 = intel_sub_group_block_read_us(A + offset); offset += stride; + + return as_short2(ret); +} + +// M rows x K columns +// This is the SIMD16 version, where each work-item loads one value. +static short4 __load_a_row_major_bf16_k16_m4_x16(global ushort* A, int rowStart, int colStart, int stride) +{ + ushort4 ret; + + int offset = rowStart * stride + colStart; + ret.s0 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s1 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s2 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s3 = intel_sub_group_block_read_us(A + offset); offset += stride; + + return as_short4(ret); +} + +// M rows x K columns +// This is the SIMD16 version, where each work-item loads one value.
+static short8 __load_a_row_major_bf16_k16_m8_x16(global ushort* A, int rowStart, int colStart, int stride) +{ + ushort8 ret; + + int offset = rowStart * stride + colStart; + ret.s0 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s1 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s2 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s3 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s4 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s5 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s6 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s7 = intel_sub_group_block_read_us(A + offset); offset += stride; + + return as_short8(ret); } // K rows x N columns: @@ -198,7 +403,7 @@ static void __store_c_row_major_fp32_m8(global float* C, float8 v, int rowStart, __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) -kernel void bfloat16_dpas_rowmajor_m1(global float* C, global ushort* A, global ushort* B, int K) +kernel void bfloat16_dpas_rowmajor_m1_n8(global float* C, global ushort* A, global ushort* B, int K) { const int N = get_global_size(0); int m = get_group_id(1); @@ -208,7 +413,7 @@ kernel void bfloat16_dpas_rowmajor_m1(global float* C, global ushort* A, global for (int k = 0; k < K; k += 16) { int aData = __load_a_row_major_bf16_k16_m1_x8(A, m, k, K); int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); - sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum); + sum = mat_mul_x8(aData, bData, sum); } __store_c_row_major_fp32_m1(C, sum, m, n, N); @@ -216,7 +421,7 @@ kernel void bfloat16_dpas_rowmajor_m1(global float* C, global ushort* A, global __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) -kernel void bfloat16_dpas_rowmajor_m2(global float* C, global ushort* A, global ushort* B, int K) +kernel void 
bfloat16_dpas_rowmajor_m2_n8(global float* C, global ushort* A, global ushort* B, int K) { const int N = get_global_size(0); int m = get_group_id(1) * 2; @@ -226,7 +431,7 @@ kernel void bfloat16_dpas_rowmajor_m2(global float* C, global ushort* A, global for (int k = 0; k < K; k += 16) { int2 aData = __load_a_row_major_bf16_k16_m2_x8(A, m, k, K); int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); - sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum); + sum = mat_mul_x8(aData, bData, sum); } __store_c_row_major_fp32_m2(C, sum, m, n, N); @@ -234,7 +439,7 @@ kernel void bfloat16_dpas_rowmajor_m2(global float* C, global ushort* A, global __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) -kernel void bfloat16_dpas_rowmajor_m4(global float* C, global ushort* A, global ushort* B, int K) +kernel void bfloat16_dpas_rowmajor_m4_n8(global float* C, global ushort* A, global ushort* B, int K) { const int N = get_global_size(0); int m = get_group_id(1) * 4; @@ -244,7 +449,7 @@ kernel void bfloat16_dpas_rowmajor_m4(global float* C, global ushort* A, global for (int k = 0; k < K; k += 16) { int4 aData = __load_a_row_major_bf16_k16_m4_x8(A, m, k, K); int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); - sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum); + sum = mat_mul_x8(aData, bData, sum); } __store_c_row_major_fp32_m4(C, sum, m, n, N); @@ -252,7 +457,7 @@ kernel void bfloat16_dpas_rowmajor_m4(global float* C, global ushort* A, global __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) -kernel void bfloat16_dpas_rowmajor_m8(global float* C, global ushort* A, global ushort* B, int K) +kernel void bfloat16_dpas_rowmajor_m8_n8(global float* C, global ushort* A, global ushort* B, int K) { const int N = get_global_size(0); int m = get_group_id(1) * 8; @@ -262,7 +467,79 @@ kernel void bfloat16_dpas_rowmajor_m8(global float* C, global ushort* A, global for (int k = 0; k 
< K; k += 16) { int8 aData = __load_a_row_major_bf16_k16_m8_x8(A, m, k, K); int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); - sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum); + sum = mat_mul_x8(aData, bData, sum); + } + + __store_c_row_major_fp32_m8(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) +__attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_rowmajor_m1_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + const int N = get_global_size(0); + int m = get_group_id(1); + int n = get_group_id(0) * get_local_size(0); + + float sum = 0; + for (int k = 0; k < K; k += 16) { + short aData = __load_a_row_major_bf16_k16_m1_x16(A, m, k, K); + int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); + sum = mat_mul_x16(aData, bData, sum); + } + + __store_c_row_major_fp32_m1(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) +__attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_rowmajor_m2_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + const int N = get_global_size(0); + int m = get_group_id(1) * 2; + int n = get_group_id(0) * get_local_size(0); + + float2 sum = 0; + for (int k = 0; k < K; k += 16) { + short2 aData = __load_a_row_major_bf16_k16_m2_x16(A, m, k, K); + int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); + sum = mat_mul_x16(aData, bData, sum); + } + + __store_c_row_major_fp32_m2(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) +__attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_rowmajor_m4_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + const int N = get_global_size(0); + int m = get_group_id(1) * 4; + int n = get_group_id(0) * get_local_size(0); + + float4 sum = 0; + for (int k = 0; k < K; k += 16) { + short4 aData = __load_a_row_major_bf16_k16_m4_x16(A, m, k, K); + int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); + sum = mat_mul_x16(aData, 
bData, sum); + } + + __store_c_row_major_fp32_m4(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) +__attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_rowmajor_m8_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + const int N = get_global_size(0); + int m = get_group_id(1) * 8; + int n = get_group_id(0) * get_local_size(0); + + float8 sum = 0; + for (int k = 0; k < K; k += 16) { + short8 aData = __load_a_row_major_bf16_k16_m8_x16(A, m, k, K); + int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); + sum = mat_mul_x16(aData, bData, sum); } __store_c_row_major_fp32_m8(C, sum, m, n, N); @@ -270,7 +547,7 @@ kernel void bfloat16_dpas_rowmajor_m8(global float* C, global ushort* A, global __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) -kernel void bfloat16_dpas_vnni_m1(global float* C, global ushort* A, global ushort* B, int K) +kernel void bfloat16_dpas_vnni_m1_n8(global float* C, global ushort* A, global ushort* B, int K) { const int N = get_global_size(0); int m = get_group_id(1); @@ -280,7 +557,7 @@ kernel void bfloat16_dpas_vnni_m1(global float* C, global ushort* A, global usho for (int k = 0; k < K; k += 16) { int aData = __load_a_row_major_bf16_k16_m1_x8(A, m, k, K); int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); - sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum); + sum = mat_mul_x8(aData, bData, sum); } __store_c_row_major_fp32_m1(C, sum, m, n, N); @@ -288,7 +565,7 @@ kernel void bfloat16_dpas_vnni_m1(global float* C, global ushort* A, global usho __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) -kernel void bfloat16_dpas_vnni_m2(global float* C, global ushort* A, global ushort* B, int K) +kernel void bfloat16_dpas_vnni_m2_n8(global float* C, global ushort* A, global ushort* B, int K) { const int N = get_global_size(0); int m = get_group_id(1) * 2; @@ -298,7 +575,7 @@ kernel void 
bfloat16_dpas_vnni_m2(global float* C, global ushort* A, global usho for (int k = 0; k < K; k += 16) { int2 aData = __load_a_row_major_bf16_k16_m2_x8(A, m, k, K); int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); - sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum); + sum = mat_mul_x8(aData, bData, sum); } __store_c_row_major_fp32_m2(C, sum, m, n, N); @@ -306,7 +583,7 @@ kernel void bfloat16_dpas_vnni_m2(global float* C, global ushort* A, global usho __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) -kernel void bfloat16_dpas_vnni_m4(global float* C, global ushort* A, global ushort* B, int K) +kernel void bfloat16_dpas_vnni_m4_n8(global float* C, global ushort* A, global ushort* B, int K) { const int N = get_global_size(0); int m = get_group_id(1) * 4; @@ -316,7 +593,7 @@ kernel void bfloat16_dpas_vnni_m4(global float* C, global ushort* A, global usho for (int k = 0; k < K; k += 16) { int4 aData = __load_a_row_major_bf16_k16_m4_x8(A, m, k, K); int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); - sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum); + sum = mat_mul_x8(aData, bData, sum); } __store_c_row_major_fp32_m4(C, sum, m, n, N); @@ -324,7 +601,7 @@ kernel void bfloat16_dpas_vnni_m4(global float* C, global ushort* A, global usho __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) -kernel void bfloat16_dpas_vnni_m8(global float* C, global ushort* A, global ushort* B, int K) +kernel void bfloat16_dpas_vnni_m8_n8(global float* C, global ushort* A, global ushort* B, int K) { const int N = get_global_size(0); int m = get_group_id(1) * 8; @@ -334,10 +611,82 @@ kernel void bfloat16_dpas_vnni_m8(global float* C, global ushort* A, global usho for (int k = 0; k < K; k += 16) { int8 aData = __load_a_row_major_bf16_k16_m8_x8(A, m, k, K); int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); - sum = intel_sub_group_bf16_bf16_matrix_mad_k16(aData, bData, sum); + sum = 
mat_mul_x8(aData, bData, sum); + } + + __store_c_row_major_fp32_m8(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) +__attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_vnni_m1_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + const int N = get_global_size(0); + int m = get_group_id(1); + int n = get_group_id(0) * get_local_size(0); + + float sum = 0; + for (int k = 0; k < K; k += 16) { + short aData = __load_a_row_major_bf16_k16_m1_x16(A, m, k, K); + int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); + sum = mat_mul_x16(aData, bData, sum); + } + + __store_c_row_major_fp32_m1(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) +__attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_vnni_m2_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + const int N = get_global_size(0); + int m = get_group_id(1) * 2; + int n = get_group_id(0) * get_local_size(0); + + float2 sum = 0; + for (int k = 0; k < K; k += 16) { + short2 aData = __load_a_row_major_bf16_k16_m2_x16(A, m, k, K); + int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); + sum = mat_mul_x16(aData, bData, sum); + } + + __store_c_row_major_fp32_m2(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) +__attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_vnni_m4_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + const int N = get_global_size(0); + int m = get_group_id(1) * 4; + int n = get_group_id(0) * get_local_size(0); + + float4 sum = 0; + for (int k = 0; k < K; k += 16) { + short4 aData = __load_a_row_major_bf16_k16_m4_x16(A, m, k, K); + int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); + sum = mat_mul_x16(aData, bData, sum); + } + + __store_c_row_major_fp32_m4(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) +__attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_vnni_m8_n16(global float* C, global 
ushort* A, global ushort* B, int K) +{ + const int N = get_global_size(0); + int m = get_group_id(1) * 8; + int n = get_group_id(0) * get_local_size(0); + + float8 sum = 0; + for (int k = 0; k < K; k += 16) { + short8 aData = __load_a_row_major_bf16_k16_m8_x16(A, m, k, K); + int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); + sum = mat_mul_x16(aData, bData, sum); } __store_c_row_major_fp32_m8(C, sum, m, n, N); } -#endif // defined(cl_intel_subgroup_matrix_multiply_accumulate) +#endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size) From 074e0a5e984ce55d94ed15f172e26782974e1564 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Sat, 6 Jan 2024 11:32:31 -0800 Subject: [PATCH 07/99] add support for PVC, which does not support SIMD8 --- samples/99_matrixexperiments/main.cpp | 7 +++++-- samples/99_matrixexperiments/matrix_kernels.cl | 8 ++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index 0d78b421..32deef14 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -304,18 +304,21 @@ int main(int argc, char** argv) printf("Running on device: %s\n", device.getInfo().c_str() ); + auto minSubGroupSize = findMinSubGroupSize(device); + + bool has_simd8 = minSubGroupSize == 8; bool emulate_tN8 = true; bool emulate_tN16 = true; if (!emulate && checkDeviceForExtension(device, "cl_intel_subgroup_matrix_multiply_accumulate")) { - auto minSubGroupSize = findMinSubGroupSize(device); printf("Found support for cl_intel_subgroup_matrix_multiply_accumulate, min sub-group size is: %zu\n", minSubGroupSize); switch(minSubGroupSize) { - case 8: emulate_tN8 = false; break; + case 8: emulate_tN8 = false; break; case 16: emulate_tN16 = false; break; default: break; } } + buildOptions += " -DHAS_SIMD8=" + std::to_string(has_simd8); buildOptions += " -DEMULATE_tN8=" + 
std::to_string(emulate_tN8); buildOptions += " -DEMULATE_tN16=" + std::to_string(emulate_tN16); diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl index 71210216..009225a7 100644 --- a/samples/99_matrixexperiments/matrix_kernels.cl +++ b/samples/99_matrixexperiments/matrix_kernels.cl @@ -401,6 +401,8 @@ static void __store_c_row_major_fp32_m8(global float* C, float8 v, int rowStart, intel_sub_group_block_write(C_ui + offset, v_ui.s7); offset += stride; } +#if HAS_SIMD8 + __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) kernel void bfloat16_dpas_rowmajor_m1_n8(global float* C, global ushort* A, global ushort* B, int K) @@ -473,6 +475,8 @@ kernel void bfloat16_dpas_rowmajor_m8_n8(global float* C, global ushort* A, glob __store_c_row_major_fp32_m8(C, sum, m, n, N); } +#endif // HAS_SIMD8 + __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_rowmajor_m1_n16(global float* C, global ushort* A, global ushort* B, int K) @@ -545,6 +549,8 @@ kernel void bfloat16_dpas_rowmajor_m8_n16(global float* C, global ushort* A, glo __store_c_row_major_fp32_m8(C, sum, m, n, N); } +#if HAS_SIMD8 + __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) kernel void bfloat16_dpas_vnni_m1_n8(global float* C, global ushort* A, global ushort* B, int K) @@ -617,6 +623,8 @@ kernel void bfloat16_dpas_vnni_m8_n8(global float* C, global ushort* A, global u __store_c_row_major_fp32_m8(C, sum, m, n, N); } +#endif // HAS_SIMD8 + __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_vnni_m1_n16(global float* C, global ushort* A, global ushort* B, int K) From 1ca8f73c322fb72b17fff3075ae2e0074279026d Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Mon, 8 Jan 2024 11:00:21 -0800 Subject: [PATCH 08/99] fix warning --- 
samples/99_matrixexperiments/main.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index 32deef14..5d72ecec 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -64,7 +64,7 @@ static void fill_matrix(std::vector& M, size_t numRows, size_t numCols) if (fixedData) { for (size_t r = 0; r < numRows; r++) { for (size_t c = 0; c < numCols; c++) { - M[r * numCols + c] = r + c; + M[r * numCols + c] = static_cast(r + c); } } } else { From f2b00f3f6599dc67fc385cda07326d766badf476 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Mon, 8 Jan 2024 19:01:13 -0800 Subject: [PATCH 09/99] add 2D block read variants --- samples/99_matrixexperiments/CMakeLists.txt | 2 +- samples/99_matrixexperiments/main.cpp | 91 +++++++-- .../99_matrixexperiments/matrix_kernels.cl | 180 +++++++++++++++++- 3 files changed, 248 insertions(+), 25 deletions(-) diff --git a/samples/99_matrixexperiments/CMakeLists.txt b/samples/99_matrixexperiments/CMakeLists.txt index 7f2a696d..6020ec83 100644 --- a/samples/99_matrixexperiments/CMakeLists.txt +++ b/samples/99_matrixexperiments/CMakeLists.txt @@ -4,7 +4,7 @@ add_opencl_sample( TEST - NUMBER 05 + NUMBER 99 TARGET matrixexperiments VERSION 120 SOURCES main.cpp diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index 5d72ecec..8c299a7d 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -64,6 +64,7 @@ static void fill_matrix(std::vector& M, size_t numRows, size_t numCols) if (fixedData) { for (size_t r = 0; r < numRows; r++) { for (size_t c = 0; c < numCols; c++) { + //M[r * numCols + c] = 1.0f; M[r * numCols + c] = static_cast(r + c); } } @@ -254,6 +255,49 @@ static void go_dpas_vnni( } } +template +static void go_dpas_blockread_rowmajor( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, 
cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "bfloat16_dpas_blockread_rowmajor"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel()) { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M/tM}); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration elapsed_seconds = end - start; + best = std::min(best, elapsed_seconds.count()); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... 
"); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(C_check, C_ref); + printf(" done!\n"); + } + } else { + printf("unsupported.\n"); + } +} + int main(int argc, char** argv) { int platformIndex = 0; @@ -376,27 +420,32 @@ int main(int argc, char** argv) printf("Running tests...\n"); - go_naive(context, program, queue, C, A, B, M, N, K, C_ref); - - go_dpas_rowmajor<1, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor<2, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor<4, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor<8, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - - go_dpas_vnni<1, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni<2, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni<4, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni<8, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - - go_dpas_rowmajor<1, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor<2, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor<4, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor<8, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - - go_dpas_vnni<1, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni<2, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni<4, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni<8, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + //go_naive(context, program, queue, C, A, B, M, N, K, C_ref); + // + //go_dpas_rowmajor<1, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + //go_dpas_rowmajor<2, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + 
//go_dpas_rowmajor<4, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + //go_dpas_rowmajor<8, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + // + //go_dpas_vnni<1, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + //go_dpas_vnni<2, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + //go_dpas_vnni<4, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + //go_dpas_vnni<8, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + // + //go_dpas_rowmajor<1, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + //go_dpas_rowmajor<2, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + //go_dpas_rowmajor<4, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + //go_dpas_rowmajor<8, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + // + //go_dpas_vnni<1, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + //go_dpas_vnni<2, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + //go_dpas_vnni<4, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + //go_dpas_vnni<8, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + + go_dpas_blockread_rowmajor<1, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_blockread_rowmajor<2, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_blockread_rowmajor<4, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_blockread_rowmajor<8, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); printf("Done.\n"); diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl index 009225a7..56086a84 100644 --- a/samples/99_matrixexperiments/matrix_kernels.cl +++ b/samples/99_matrixexperiments/matrix_kernels.cl @@ -1,3 +1,5 @@ +#define OVLD __attribute__((overloadable)) + #if EMULATE_tn8 == 0 #define mat_mul_x8 intel_sub_group_bf16_bf16_matrix_mad_k16 #else @@ -35,7 +37,8 @@ kernel void 
bfloat16_naive(global float* C, global ushort* A, global ushort* B, #if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size) -#define OVLD __attribute__((overloadable)) +// These are non-block read versions. +// They work on DG2 and PVC, and on other devices when emulated. // SIMD8 versions: static float OVLD my_sub_group_bf16_bf16_matrix_mad_k16(int a, int8 b, float acc) @@ -163,8 +166,6 @@ static float8 OVLD my_sub_group_bf16_bf16_matrix_mad_k16(short8 a, int8 b, float return res; } -#undef OVLD - // M rows x K columns // This is the SIMD8 version, where each work-item loads two values. static int __load_a_row_major_bf16_k16_m1_x8(global ushort* A, int rowStart, int colStart, int stride) @@ -697,4 +698,177 @@ kernel void bfloat16_dpas_vnni_m8_n16(global float* C, global ushort* A, global __store_c_row_major_fp32_m8(C, sum, m, n, N); } +#ifdef cl_intel_subgroup_extended_block_read + +// Note for 2D block reads: +// - the tile width and height are encoded into the function name. +// - base_address is the byte address. Must be 64B aligned. +// - width is the width of the entire matrix, in bytes. Must be >= 64B. Must be 4B aligned. +// - height is the height of the entire matrix, or equivalently the number of rows. +// - pitch is the number of bytes between rows of the entire matrix. Must be >= 64B. Must be a multiple of 8 bytes. +// - coord is the number of elements (x coord) and row (y coord) to read from. X coord must be a multiple of 4 for 1B data and 2 for 2B data. 
+ +// Built-in functions are: + +// #ifdef cl_intel_subgroup_extended_block_read +// ushort2 intel_subgroup_block_read_u8_m1k32v2(__global void *base_address, int width, int height, int pitch, int2 coord); +// ushort4 intel_subgroup_block_read_u8_m2k32v2(__global void *base_address, int width, int height, int pitch, int2 coord); +// ushort8 intel_subgroup_block_read_u8_m4k32v2(__global void *base_address, int width, int height, int pitch, int2 coord); +// ushort16 intel_subgroup_block_read_u8_m8k32v2(__global void *base_address, int width, int height, int pitch, int2 coord); +// ushort2 intel_subgroup_block_read_u16_m1k16v2(__global void *base_address, int width, int height, int pitch, int2 coord); +// ushort4 intel_subgroup_block_read_u16_m2k16v2(__global void *base_address, int width, int height, int pitch, int2 coord); +// ushort8 intel_subgroup_block_read_u16_m4k16v2(__global void *base_address, int width, int height, int pitch, int2 coord); +// ushort16 intel_subgroup_block_read_u16_m8k16v2(__global void *base_address, int width, int height, int pitch, int2 coord); +// uint8 intel_subgroup_block_read_transform_u8_k32(__global void *base_address, int width, int height, int pitch, int2 coord); +// uint8 intel_subgroup_block_read_transform_u16_k16(__global void *base_address, int width, int height, int pitch, int2 coord); +// uint8 intel_subgroup_block_read_transpose_u32_k8(__global void *base_address, int width, int height, int pitch, int2 coord); +// ulong4 intel_subgroup_block_read_transpose_u64_k4(__global void *base_address, int width, int height, int pitch, int2 coord); +// #endif //defined(cl_intel_subgroup_extended_block_read) + + +// For intrinsics, the pattern is: +// - prefix: __builtin_IB_subgroup_block_read_flat or __builtin_IB_subgroup_block_write_flat +// - operation (optional): _transpose or _transform +// - for no transpose or transform: +// - type / elements size: _u8 or _u16 or _u32 or _u64 +// - number of tile rows: _m32 or _m16 or _m8 or _m4 
or _m2 or _m1 +// - tile width: _k64 or _k32 or _k16 or _k8 +// - number of tiles: _v2 or _v1 +// - for transpose: +// - type / element size: _u64 or _u32 +// - number of tile rows: subgroup size (16) +// - tile width: _k4 (for _u64) or _k8 (for _u32) +// - number of tiles: 1 +// - for transform: +// - type / element size: _u16 or _u8 +// - number of tile rows: _k32 (for _u8) or _k16 (for _u16) +// - tile width: subgroup size (16) +// - number of tiles: 1 + +// Define additional "non-vector" block read and writes. These are supported by the hardware but are not in the headers: + +ushort __builtin_IB_subgroup_block_read_flat_u16_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort2 __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort4 __builtin_IB_subgroup_block_read_flat_u16_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort8 __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); + +void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint data); +void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint2 data); +void __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint4 data); +void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint8 data); + +ushort intel_subgroup_block_read_u16_m1k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + return 
__builtin_IB_subgroup_block_read_flat_u16_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); +} +ushort2 intel_subgroup_block_read_u16_m2k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + return __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); +} +ushort4 intel_subgroup_block_read_u16_m4k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + return __builtin_IB_subgroup_block_read_flat_u16_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); +} +ushort8 intel_subgroup_block_read_u16_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + return __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); +} + +void intel_subgroup_block_write_u32_m1k16v1(__global void* base_address, int width, int height, int pitch, int2 coord, uint data) +{ + __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); +} +void intel_subgroup_block_write_u32_m2k16v1(__global void* base_address, int width, int height, int pitch, int2 coord, uint2 data) +{ + __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); +} +void intel_subgroup_block_write_u32_m4k16v1(__global void* base_address, int width, int height, int pitch, int2 coord, uint4 data) +{ + __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); +} +void intel_subgroup_block_write_u32_m8k16v1(__global void* base_address, int width, int height, int pitch, int2 coord, uint8 data) +{ + __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); +} + +__attribute__((intel_reqd_sub_group_size(16))) 
+__attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_blockread_rowmajor_m1_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + const int M = get_global_size(1); + const int N = get_global_size(0); + int m = get_group_id(1); + int n = get_group_id(0) * get_local_size(0); + + float sum = 0; + for (int k = 0; k < K; k += 16) { + short aData = as_short(intel_subgroup_block_read_u16_m1k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); + sum = mat_mul_x16(aData, bData, sum); + } + + intel_subgroup_block_write_u32_m1k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum)); +} + +__attribute__((intel_reqd_sub_group_size(16))) +__attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_blockread_rowmajor_m2_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + const int M = get_global_size(1) * 2; + const int N = get_global_size(0); + int m = get_group_id(1) * 2; + int n = get_group_id(0) * get_local_size(0); + + float2 sum = 0; + for (int k = 0; k < K; k += 16) { + short2 aData = as_short2(intel_subgroup_block_read_u16_m2k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); + sum = mat_mul_x16(aData, bData, sum); + } + + intel_subgroup_block_write_u32_m2k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum)); +} + +__attribute__((intel_reqd_sub_group_size(16))) +__attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_blockread_rowmajor_m4_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + const int M = get_global_size(1) * 4; + const int N = get_global_size(0); + int m = get_group_id(1) * 4; + int n = get_group_id(0) * get_local_size(0); + + 
float4 sum = 0; + for (int k = 0; k < K; k += 16) { + short4 aData = as_short4(intel_subgroup_block_read_u16_m4k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); + sum = mat_mul_x16(aData, bData, sum); + } + + intel_subgroup_block_write_u32_m4k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum)); +} + +__attribute__((intel_reqd_sub_group_size(16))) +__attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_blockread_rowmajor_m8_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + const int M = get_global_size(1) * 8; + const int N = get_global_size(0); + int m = get_group_id(1) * 8; + int n = get_group_id(0) * get_local_size(0); + + float8 sum = 0; + for (int k = 0; k < K; k += 16) { + short8 aData = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); + sum = mat_mul_x16(aData, bData, sum); + } + + intel_subgroup_block_write_u32_m8k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); +} + +#endif // cl_intel_subgroup_extended_block_read + #endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size) + +#undef OVLD From 0bb5529e567aa2a528d9032df77c2466c66125ee Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Mon, 8 Jan 2024 19:02:04 -0800 Subject: [PATCH 10/99] reenable all variants --- samples/99_matrixexperiments/main.cpp | 42 +++++++++++++-------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index 8c299a7d..4fd60b88 100644 --- a/samples/99_matrixexperiments/main.cpp +++ 
b/samples/99_matrixexperiments/main.cpp @@ -420,27 +420,27 @@ int main(int argc, char** argv) printf("Running tests...\n"); - //go_naive(context, program, queue, C, A, B, M, N, K, C_ref); - // - //go_dpas_rowmajor<1, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - //go_dpas_rowmajor<2, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - //go_dpas_rowmajor<4, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - //go_dpas_rowmajor<8, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - // - //go_dpas_vnni<1, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - //go_dpas_vnni<2, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - //go_dpas_vnni<4, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - //go_dpas_vnni<8, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - // - //go_dpas_rowmajor<1, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - //go_dpas_rowmajor<2, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - //go_dpas_rowmajor<4, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - //go_dpas_rowmajor<8, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - // - //go_dpas_vnni<1, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - //go_dpas_vnni<2, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - //go_dpas_vnni<4, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - //go_dpas_vnni<8, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_naive(context, program, queue, C, A, B, M, N, K, C_ref); + + go_dpas_rowmajor<1, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_rowmajor<2, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_rowmajor<4, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_rowmajor<8, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + + go_dpas_vnni<1, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + 
go_dpas_vnni<2, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_vnni<4, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_vnni<8, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + + go_dpas_rowmajor<1, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_rowmajor<2, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_rowmajor<4, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_rowmajor<8, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + + go_dpas_vnni<1, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_vnni<2, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_vnni<4, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_vnni<8, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); go_dpas_blockread_rowmajor<1, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); go_dpas_blockread_rowmajor<2, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); From 7b89cfe1c60af25be196655ea02f697a6098e3c5 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Mon, 8 Jan 2024 21:39:24 -0800 Subject: [PATCH 11/99] add vnni block read variants --- samples/99_matrixexperiments/main.cpp | 55 +++++++++++- .../99_matrixexperiments/matrix_kernels.cl | 83 +++++++++++++++++++ 2 files changed, 136 insertions(+), 2 deletions(-) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index 4fd60b88..aeb132c5 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -20,6 +20,7 @@ using test_clock = std::chrono::high_resolution_clock; +bool identityData = false; bool fixedData = false; bool validate = false; bool emulate = false; @@ -61,10 +62,11 @@ static size_t findMinSubGroupSize(cl::Device& device) template static void fill_matrix(std::vector& M, size_t numRows, size_t numCols) { - if (fixedData) { + if 
(identityData) { + std::generate(std::begin(M), std::end(M), [&]{ return 1.0f; }); + } else if (fixedData) { for (size_t r = 0; r < numRows; r++) { for (size_t c = 0; c < numCols; c++) { - //M[r * numCols + c] = 1.0f; M[r * numCols + c] = static_cast(r + c); } } @@ -298,6 +300,49 @@ static void go_dpas_blockread_rowmajor( } } +template +static void go_dpas_blockread_vnni( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "bfloat16_dpas_blockread_vnni"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel()) { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M/tM}); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration elapsed_seconds = end - start; + best = std::min(best, elapsed_seconds.count()); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... 
"); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(C_check, C_ref); + printf(" done!\n"); + } + } else { + printf("unsupported.\n"); + } +} + int main(int argc, char** argv) { int platformIndex = 0; @@ -316,6 +361,7 @@ int main(int argc, char** argv) op.add>("m", "matrixsize", "Matrix Size", matrixSize, &matrixSize); op.add>("i", "iterations", "Test Iterations", testIterations, &testIterations); op.add("", "validate", "Validate Results", &validate); + op.add("", "identity", "Use Identity Data", &identityData); op.add("", "fixed", "Use Fixed Data", &fixedData); op.add("", "emulate", "Unconditionally Emulate dpas", &emulate); op.add>("", "threshold", "Local Error Threshold", threshold, &threshold); @@ -447,6 +493,11 @@ int main(int argc, char** argv) go_dpas_blockread_rowmajor<4, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); go_dpas_blockread_rowmajor<8, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_blockread_vnni<1, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_blockread_vnni<2, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_blockread_vnni<4, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_blockread_vnni<8, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + printf("Done.\n"); return 0; diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl index 56086a84..53c618f9 100644 --- a/samples/99_matrixexperiments/matrix_kernels.cl +++ b/samples/99_matrixexperiments/matrix_kernels.cl @@ -752,6 +752,8 @@ ushort2 __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(long baseoffset, int w ushort4 __builtin_IB_subgroup_block_read_flat_u16_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); ushort8 
__builtin_IB_subgroup_block_read_flat_u16_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); + void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint data); void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint2 data); void __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint4 data); @@ -774,6 +776,11 @@ ushort8 intel_subgroup_block_read_u16_m8k16(const __global void *base_address, i return __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } +uint8 intel_subgroup_block_read_u32_m8k16(const __global void* base_address, int width, int height, int pitch, int2 coord) +{ + return __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); +} + void intel_subgroup_block_write_u32_m1k16v1(__global void* base_address, int width, int height, int pitch, int2 coord, uint data) { __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); @@ -867,6 +874,82 @@ kernel void bfloat16_dpas_blockread_rowmajor_m8_n16(global float* C, global usho intel_subgroup_block_write_u32_m8k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); } +__attribute__((intel_reqd_sub_group_size(16))) +__attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_blockread_vnni_m1_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + const int M = get_global_size(1); + const int N = 
get_global_size(0); + int m = get_group_id(1); + int n = get_group_id(0) * get_local_size(0); + + float sum = 0; + for (int k = 0; k < K; k += 16) { + short aData = as_short(intel_subgroup_block_read_u16_m1k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); + sum = mat_mul_x16(aData, bData, sum); + } + + intel_subgroup_block_write_u32_m1k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum)); +} + +__attribute__((intel_reqd_sub_group_size(16))) +__attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_blockread_vnni_m2_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + const int M = get_global_size(1) * 2; + const int N = get_global_size(0); + int m = get_group_id(1) * 2; + int n = get_group_id(0) * get_local_size(0); + + float2 sum = 0; + for (int k = 0; k < K; k += 16) { + short2 aData = as_short2(intel_subgroup_block_read_u16_m2k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); + sum = mat_mul_x16(aData, bData, sum); + } + + intel_subgroup_block_write_u32_m2k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum)); +} + +__attribute__((intel_reqd_sub_group_size(16))) +__attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_blockread_vnni_m4_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + const int M = get_global_size(1) * 4; + const int N = get_global_size(0); + int m = get_group_id(1) * 4; + int n = get_group_id(0) * get_local_size(0); + + float4 sum = 0; + for (int k = 0; k < K; k += 16) { + short4 aData = as_short4(intel_subgroup_block_read_u16_m4k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * 
sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); + sum = mat_mul_x16(aData, bData, sum); + } + + intel_subgroup_block_write_u32_m4k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum)); +} + +__attribute__((intel_reqd_sub_group_size(16))) +__attribute__((reqd_work_group_size(16, 1, 1))) +kernel void bfloat16_dpas_blockread_vnni_m8_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + const int M = get_global_size(1) * 8; + const int N = get_global_size(0); + int m = get_group_id(1) * 8; + int n = get_group_id(0) * get_local_size(0); + + float8 sum = 0; + for (int k = 0; k < K; k += 16) { + short8 aData = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); + sum = mat_mul_x16(aData, bData, sum); + } + + intel_subgroup_block_write_u32_m8k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); +} + #endif // cl_intel_subgroup_extended_block_read #endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size) From ca4b3cd433dd626f6bdfd4bf097fc9fd0472bdb5 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Mon, 8 Jan 2024 21:49:03 -0800 Subject: [PATCH 12/99] fix typo in emulation path --- samples/99_matrixexperiments/matrix_kernels.cl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl index 53c618f9..ed3413fe 100644 --- a/samples/99_matrixexperiments/matrix_kernels.cl +++ b/samples/99_matrixexperiments/matrix_kernels.cl @@ -1,15 +1,15 @@ #define OVLD __attribute__((overloadable)) -#if EMULATE_tn8 == 0 -#define mat_mul_x8 intel_sub_group_bf16_bf16_matrix_mad_k16 -#else +#if EMULATE_tN8 #define mat_mul_x8 my_sub_group_bf16_bf16_matrix_mad_k16 +#else 
+#define mat_mul_x8 intel_sub_group_bf16_bf16_matrix_mad_k16 #endif -#if EMULATE_tN16 == 0 -#define mat_mul_x16 intel_sub_group_bf16_bf16_matrix_mad_k16 -#else +#if EMULATE_tN16 #define mat_mul_x16 my_sub_group_bf16_bf16_matrix_mad_k16 +#else +#define mat_mul_x16 intel_sub_group_bf16_bf16_matrix_mad_k16 #endif float bf16_to_fp32(ushort u) From 3ffcf3e51dacb12c3768b3810aca4e422800d930 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Fri, 12 Jan 2024 13:10:48 -0800 Subject: [PATCH 13/99] start to add block tiled versions --- samples/99_matrixexperiments/main.cpp | 151 +++++++++++++++--- .../99_matrixexperiments/matrix_kernels.cl | 70 ++++++++ 2 files changed, 202 insertions(+), 19 deletions(-) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index aeb132c5..4912d368 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -27,6 +27,16 @@ bool emulate = false; int testIterations = 16; float threshold = 0.01f; +std::string makeTestName( + const std::string &func, + size_t M, size_t N, size_t K) +{ + std::ostringstream ret; + ret << func; + ret << " (M=" << M << ", N=" << N << ", K=" << K << ")"; + return ret.str(); +} + std::string makeTestName( const std::string &func, int tM, int tN, int tK, @@ -41,10 +51,13 @@ std::string makeTestName( std::string makeTestName( const std::string &func, + int tM, int tN, int tK, + int MM, int NN, size_t M, size_t N, size_t K) { std::ostringstream ret; ret << func; + ret << ""; ret << " (M=" << M << ", N=" << N << ", K=" << K << ")"; return ret.str(); } @@ -112,24 +125,28 @@ static void compute_reference( } template -int check_results(const std::vector& C, - const std::vector& C_ref) +void check_results( + size_t M, + size_t N, + const std::vector& C, + const std::vector& C_ref) { float err = 0.f; - for (int i = 0; i < C.size(); ++i) { - auto localErr = std::fabs(C[i] - C_ref[i]) / - std::max(std::fabs(C[i]), - std::fabs(C_ref[i])); - err = 
std::max(localErr, err); - if (localErr >= threshold) { - std::cerr << "Error at index " << i << " (local error " << localErr - << "): Wanted " << C_ref[i] << ", got " << C[i] - << std::endl; - break; + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + auto index = m * N + n; + auto localErr = std::fabs(C[index] - C_ref[index]) / + std::max(std::fabs(C[index]), + std::fabs(C_ref[index])); + err = std::max(localErr, err); + if (localErr >= threshold) { + std::cerr << "Error at m = " << m << ", n = " << n + << ": (local error " << localErr << "): Wanted " + << C_ref[index] << ", got " << C[index] << std::endl; + return; + } } } - - return err < 0.001f; } static void go_naive( @@ -164,7 +181,7 @@ static void go_naive( printf("Checking results... "); fflush(stdout); std::vector C_check(C_ref.size()); queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); - check_results(C_check, C_ref); + check_results(M, N, C_check, C_ref); printf(" done!\n"); } } @@ -204,7 +221,52 @@ static void go_dpas_rowmajor( printf("Checking results...
"); fflush(stdout); std::vector C_check(C_ref.size()); queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); - check_results(C_check, C_ref); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } else { + printf("unsupported.\n"); + } +} + +template +static void go_dpas_rowmajor_x( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, MM, NN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "bfloat16_dpas_rowmajor"; + kernelName += "_m" + std::to_string(tM); + kernelName += "x" + std::to_string(MM); + kernelName += "_n" + std::to_string(tN); + kernelName += "x" + std::to_string(NN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel()) { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N/NN, M/tM/MM}); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration elapsed_seconds = end - start; + best = std::min(best, elapsed_seconds.count()); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); printf(" done!\n"); } } else { @@ -249,7 +311,54 @@ static void go_dpas_vnni( printf("Checking results...
"); fflush(stdout); std::vector C_check(C_ref.size()); queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); - check_results(C_check, C_ref); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } else { + printf("unsupported.\n"); + } +} + +template +static void go_dpas_vnni_x( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, MM, NN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "bfloat16_dpas_vnni"; + kernelName += "_m" + std::to_string(tM); + kernelName += "x" + std::to_string(MM); + kernelName += "_n" + std::to_string(tN); + kernelName += "x" + std::to_string(NN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel()) { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N/NN, M/tM/MM}); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration elapsed_seconds = end - start; + best = std::min(best, elapsed_seconds.count()); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); printf(" done!\n"); } } else { @@ -292,7 +401,7 @@ static void go_dpas_blockread_rowmajor( printf("Checking results...
"); fflush(stdout); std::vector C_check(C_ref.size()); queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); - check_results(C_check, C_ref); + check_results(M, N, C_check, C_ref); printf(" done!\n"); } } else { @@ -335,7 +444,7 @@ static void go_dpas_blockread_vnni( printf("Checking results... "); fflush(stdout); std::vector C_check(C_ref.size()); queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); - check_results(C_check, C_ref); + check_results(M, N, C_check, C_ref); printf(" done!\n"); } } else { @@ -473,11 +582,15 @@ int main(int argc, char** argv) go_dpas_rowmajor<4, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); go_dpas_rowmajor<8, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_rowmajor_x<8, 8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_vnni<1, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); go_dpas_vnni<2, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); go_dpas_vnni<4, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); go_dpas_vnni<8, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_vnni_x<8, 8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_rowmajor<1, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); go_dpas_rowmajor<2, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); go_dpas_rowmajor<4, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl index ed3413fe..d2b26bae 100644 --- a/samples/99_matrixexperiments/matrix_kernels.cl +++ b/samples/99_matrixexperiments/matrix_kernels.cl @@ -476,6 +476,41 @@ kernel void bfloat16_dpas_rowmajor_m8_n8(global float* C, global ushort* A, glob __store_c_row_major_fp32_m8(C, sum, m, n, N); } +__attribute__((intel_reqd_sub_group_size(8))) 
+__attribute__((reqd_work_group_size(8, 1, 1))) +kernel void bfloat16_dpas_rowmajor_m8x2_n8x1(global float* C, global ushort* A, global ushort* B, int K) +{ + #define MM 2 + + const int N = get_global_size(0); + int m = get_group_id(1) * 8 * MM; + int n = get_group_id(0) * get_local_size(0); + + float8 sum[MM]; + for (int mm = 0; mm < MM; mm++) { + sum[mm] = 0; + } + + for (int k = 0; k < K; k += 16) { + int8 aData[MM]; + for (int mm = 0; mm < MM; mm++) { + aData[mm] = __load_a_row_major_bf16_k16_m8_x8(A, m + mm * 8, k, K); + } + + int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); + + for (int mm = 0; mm < MM; mm++) { + sum[mm] = mat_mul_x8(aData[mm], bData, sum[mm]); + } + } + + for (int mm = 0; mm < MM; mm++) { + __store_c_row_major_fp32_m8(C, sum[mm], m + mm * 8, n, N); + } + + #undef MM +} + #endif // HAS_SIMD8 __attribute__((intel_reqd_sub_group_size(16))) @@ -624,6 +659,41 @@ kernel void bfloat16_dpas_vnni_m8_n8(global float* C, global ushort* A, global u __store_c_row_major_fp32_m8(C, sum, m, n, N); } +__attribute__((intel_reqd_sub_group_size(8))) +__attribute__((reqd_work_group_size(8, 1, 1))) +kernel void bfloat16_dpas_vnni_m8x2_n8x1(global float* C, global ushort* A, global ushort* B, int K) +{ + #define MM 2 + + const int N = get_global_size(0); + int m = get_group_id(1) * 8 * MM; + int n = get_group_id(0) * get_local_size(0); + + float8 sum[MM]; + for (int mm = 0; mm < MM; mm++) { + sum[mm] = 0; + } + + for (int k = 0; k < K; k += 16) { + int8 aData[MM]; + for (int mm = 0; mm < MM; mm++) { + aData[mm] = __load_a_row_major_bf16_k16_m8_x8(A, m + mm * 8, k, K); + } + + int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); + + for (int mm = 0; mm < MM; mm++) { + sum[mm] = mat_mul_x8(aData[mm], bData, sum[mm]); + } + } + + for (int mm = 0; mm < MM; mm++) { + __store_c_row_major_fp32_m8(C, sum[mm], m + mm * 8, n, N); + } + + #undef MM +} + #endif // HAS_SIMD8 __attribute__((intel_reqd_sub_group_size(16))) From b469713d043af9afe589275c6e37a2c867b2882c Mon Sep 
17 00:00:00 2001 From: Ben Ashbaugh Date: Fri, 12 Jan 2024 14:04:11 -0800 Subject: [PATCH 14/99] improve block tiled versions --- samples/99_matrixexperiments/main.cpp | 3 + .../99_matrixexperiments/matrix_kernels.cl | 155 ++++++++++++++++-- 2 files changed, 145 insertions(+), 13 deletions(-) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index 4912d368..cbcf9b2c 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -583,6 +583,8 @@ int main(int argc, char** argv) go_dpas_rowmajor<8, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); go_dpas_rowmajor_x<8, 8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_rowmajor_x<8, 8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_rowmajor_x<8, 8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); go_dpas_vnni<1, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); go_dpas_vnni<2, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); @@ -590,6 +592,7 @@ int main(int argc, char** argv) go_dpas_vnni<8, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); go_dpas_vnni_x<8, 8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_vnni_x<8, 8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); go_dpas_rowmajor<1, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); go_dpas_rowmajor<2, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl index d2b26bae..8fbed3a2 100644 --- a/samples/99_matrixexperiments/matrix_kernels.cl +++ b/samples/99_matrixexperiments/matrix_kernels.cl @@ -480,35 +480,162 @@ __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) kernel void bfloat16_dpas_rowmajor_m8x2_n8x1(global float* C, global ushort* A, global ushort* B, int K) { + #define tM 8 
+ #define tN 8 + #define tK 16 + #define MM 2 + #define NN 1 - const int N = get_global_size(0); - int m = get_group_id(1) * 8 * MM; - int n = get_group_id(0) * get_local_size(0); + const int N = get_global_size(0) * NN; + const int m = get_group_id(1) * tM * MM; + const int n = get_group_id(0) * tN * NN; - float8 sum[MM]; + float8 sum[MM][NN]; for (int mm = 0; mm < MM; mm++) { - sum[mm] = 0; + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = 0; + } } - for (int k = 0; k < K; k += 16) { + for (int k = 0; k < K; k += tK) { int8 aData[MM]; for (int mm = 0; mm < MM; mm++) { - aData[mm] = __load_a_row_major_bf16_k16_m8_x8(A, m + mm * 8, k, K); + aData[mm] = __load_a_row_major_bf16_k16_m8_x8(A, m + mm * tM, k, K); } - int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); + int8 bData[NN]; + for (int nn = 0; nn < NN; nn++) { + bData[nn] = __load_b_row_major_bf16_k16(B, k, n + nn * tN, N); + } for (int mm = 0; mm < MM; mm++) { - sum[mm] = mat_mul_x8(aData[mm], bData, sum[mm]); + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = mat_mul_x8(aData[mm], bData[nn], sum[mm][nn]); + } } } for (int mm = 0; mm < MM; mm++) { - __store_c_row_major_fp32_m8(C, sum[mm], m + mm * 8, n, N); + for (int nn = 0; nn < NN; nn++) { + __store_c_row_major_fp32_m8(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + } + } + + #undef tM + #undef tN + #undef tK + + #undef MM + #undef NN +} + +__attribute__((intel_reqd_sub_group_size(8))) +__attribute__((reqd_work_group_size(8, 1, 1))) +kernel void bfloat16_dpas_rowmajor_m8x1_n8x2(global float* C, global ushort* A, global ushort* B, int K) +{ + #define tM 8 + #define tN 8 + #define tK 16 + + #define MM 1 + #define NN 2 + + const int N = get_global_size(0) * NN; + const int m = get_group_id(1) * tM * MM; + const int n = get_group_id(0) * tN * NN; + + float8 sum[MM][NN]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = 0; + } + } + + for (int k = 0; k < K; k += tK) { + int8 aData[MM]; + for (int mm = 0; mm < MM; 
mm++) { + aData[mm] = __load_a_row_major_bf16_k16_m8_x8(A, m + mm * tM, k, K); + } + + int8 bData[NN]; + for (int nn = 0; nn < NN; nn++) { + bData[nn] = __load_b_row_major_bf16_k16(B, k, n + nn * tN, N); + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = mat_mul_x8(aData[mm], bData[nn], sum[mm][nn]); + } + } } + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + __store_c_row_major_fp32_m8(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + } + } + + #undef tM + #undef tN + #undef tK + + #undef MM + #undef NN +} + +__attribute__((intel_reqd_sub_group_size(8))) +__attribute__((reqd_work_group_size(8, 1, 1))) +kernel void bfloat16_dpas_rowmajor_m8x2_n8x2(global float* C, global ushort* A, global ushort* B, int K) +{ + #define tM 8 + #define tN 8 + #define tK 16 + + #define MM 2 + #define NN 2 + + const int N = get_global_size(0) * NN; + const int m = get_group_id(1) * tM * MM; + const int n = get_group_id(0) * tN * NN; + + float8 sum[MM][NN]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = 0; + } + } + + for (int k = 0; k < K; k += tK) { + int8 aData[MM]; + for (int mm = 0; mm < MM; mm++) { + aData[mm] = __load_a_row_major_bf16_k16_m8_x8(A, m + mm * tM, k, K); + } + + int8 bData[NN]; + for (int nn = 0; nn < NN; nn++) { + bData[nn] = __load_b_row_major_bf16_k16(B, k, n + nn * tN, N); + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = mat_mul_x8(aData[mm], bData[nn], sum[mm][nn]); + } + } + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + __store_c_row_major_fp32_m8(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + } + } + + #undef tM + #undef tN + #undef tK + #undef MM + #undef NN } #endif // HAS_SIMD8 @@ -663,10 +790,11 @@ __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) kernel void bfloat16_dpas_vnni_m8x2_n8x1(global float* C, global ushort* A, global 
ushort* B, int K) { + #define tM 8 #define MM 2 const int N = get_global_size(0); - int m = get_group_id(1) * 8 * MM; + int m = get_group_id(1) * tM * MM; int n = get_group_id(0) * get_local_size(0); float8 sum[MM]; @@ -677,7 +805,7 @@ kernel void bfloat16_dpas_vnni_m8x2_n8x1(global float* C, global ushort* A, glob for (int k = 0; k < K; k += 16) { int8 aData[MM]; for (int mm = 0; mm < MM; mm++) { - aData[mm] = __load_a_row_major_bf16_k16_m8_x8(A, m + mm * 8, k, K); + aData[mm] = __load_a_row_major_bf16_k16_m8_x8(A, m + mm * tM, k, K); } int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); @@ -688,9 +816,10 @@ kernel void bfloat16_dpas_vnni_m8x2_n8x1(global float* C, global ushort* A, glob } for (int mm = 0; mm < MM; mm++) { - __store_c_row_major_fp32_m8(C, sum[mm], m + mm * 8, n, N); + __store_c_row_major_fp32_m8(C, sum[mm], m + mm * tM, n, N); } + #undef tM #undef MM } From 098f339a0f559514f73a97b6ef5e207e57fc9647 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Fri, 12 Jan 2024 14:47:07 -0800 Subject: [PATCH 15/99] more improvements --- samples/99_matrixexperiments/main.cpp | 1 + .../99_matrixexperiments/matrix_kernels.cl | 451 +++++++++++------- 2 files changed, 279 insertions(+), 173 deletions(-) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index cbcf9b2c..4d9f32d5 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -593,6 +593,7 @@ int main(int argc, char** argv) go_dpas_vnni_x<8, 8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); go_dpas_vnni_x<8, 8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_vnni_x<8, 8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); go_dpas_rowmajor<1, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); go_dpas_rowmajor<2, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); diff --git a/samples/99_matrixexperiments/matrix_kernels.cl 
b/samples/99_matrixexperiments/matrix_kernels.cl index 8fbed3a2..5ed2cde1 100644 --- a/samples/99_matrixexperiments/matrix_kernels.cl +++ b/samples/99_matrixexperiments/matrix_kernels.cl @@ -1,5 +1,8 @@ #define OVLD __attribute__((overloadable)) +// For all bfloat16 kernels we have tK == 16: +#define tK 16 + #if EMULATE_tN8 #define mat_mul_x8 my_sub_group_bf16_bf16_matrix_mad_k16 #else @@ -24,8 +27,8 @@ float bf16_to_fp32(ushort u) kernel void bfloat16_naive(global float* C, global ushort* A, global ushort* B, int K) { const int N = get_global_size(0); - int m = get_global_id(1); - int n = get_global_id(0); + const int m = get_global_id(1); + const int n = get_global_id(0); float sum = 0; for (int k = 0; k < K; k++) { @@ -404,16 +407,19 @@ static void __store_c_row_major_fp32_m8(global float* C, float8 v, int rowStart, #if HAS_SIMD8 -__attribute__((intel_reqd_sub_group_size(8))) -__attribute__((reqd_work_group_size(8, 1, 1))) +// For SIMD8 kernels, tN == 8: +#define tN 8 + +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void bfloat16_dpas_rowmajor_m1_n8(global float* C, global ushort* A, global ushort* B, int K) { + const int tM = 1; const int N = get_global_size(0); - int m = get_group_id(1); - int n = get_group_id(0) * get_local_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; float sum = 0; - for (int k = 0; k < K; k += 16) { + for (int k = 0; k < K; k += tK) { int aData = __load_a_row_major_bf16_k16_m1_x8(A, m, k, K); int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); sum = mat_mul_x8(aData, bData, sum); @@ -422,16 +428,16 @@ kernel void bfloat16_dpas_rowmajor_m1_n8(global float* C, global ushort* A, glob __store_c_row_major_fp32_m1(C, sum, m, n, N); } -__attribute__((intel_reqd_sub_group_size(8))) -__attribute__((reqd_work_group_size(8, 1, 1))) +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void 
bfloat16_dpas_rowmajor_m2_n8(global float* C, global ushort* A, global ushort* B, int K) { + const int tM = 2; const int N = get_global_size(0); - int m = get_group_id(1) * 2; - int n = get_group_id(0) * get_local_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; float2 sum = 0; - for (int k = 0; k < K; k += 16) { + for (int k = 0; k < K; k += tK) { int2 aData = __load_a_row_major_bf16_k16_m2_x8(A, m, k, K); int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); sum = mat_mul_x8(aData, bData, sum); @@ -440,16 +446,16 @@ kernel void bfloat16_dpas_rowmajor_m2_n8(global float* C, global ushort* A, glob __store_c_row_major_fp32_m2(C, sum, m, n, N); } -__attribute__((intel_reqd_sub_group_size(8))) -__attribute__((reqd_work_group_size(8, 1, 1))) +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void bfloat16_dpas_rowmajor_m4_n8(global float* C, global ushort* A, global ushort* B, int K) { + const int tM = 4; const int N = get_global_size(0); - int m = get_group_id(1) * 4; - int n = get_group_id(0) * get_local_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; float4 sum = 0; - for (int k = 0; k < K; k += 16) { + for (int k = 0; k < K; k += tK) { int4 aData = __load_a_row_major_bf16_k16_m4_x8(A, m, k, K); int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); sum = mat_mul_x8(aData, bData, sum); @@ -458,16 +464,16 @@ kernel void bfloat16_dpas_rowmajor_m4_n8(global float* C, global ushort* A, glob __store_c_row_major_fp32_m4(C, sum, m, n, N); } -__attribute__((intel_reqd_sub_group_size(8))) -__attribute__((reqd_work_group_size(8, 1, 1))) +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void bfloat16_dpas_rowmajor_m8_n8(global float* C, global ushort* A, global ushort* B, int K) { + const int tM = 8; const int N = get_global_size(0); - int m = get_group_id(1) * 8; - int n = get_group_id(0) * 
get_local_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; float8 sum = 0; - for (int k = 0; k < K; k += 16) { + for (int k = 0; k < K; k += tK) { int8 aData = __load_a_row_major_bf16_k16_m8_x8(A, m, k, K); int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); sum = mat_mul_x8(aData, bData, sum); @@ -476,17 +482,13 @@ kernel void bfloat16_dpas_rowmajor_m8_n8(global float* C, global ushort* A, glob __store_c_row_major_fp32_m8(C, sum, m, n, N); } -__attribute__((intel_reqd_sub_group_size(8))) -__attribute__((reqd_work_group_size(8, 1, 1))) +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void bfloat16_dpas_rowmajor_m8x2_n8x1(global float* C, global ushort* A, global ushort* B, int K) { - #define tM 8 - #define tN 8 - #define tK 16 - #define MM 2 #define NN 1 + const int tM = 8; const int N = get_global_size(0) * NN; const int m = get_group_id(1) * tM * MM; const int n = get_group_id(0) * tN * NN; @@ -522,25 +524,17 @@ kernel void bfloat16_dpas_rowmajor_m8x2_n8x1(global float* C, global ushort* A, } } - #undef tM - #undef tN - #undef tK - #undef MM #undef NN } -__attribute__((intel_reqd_sub_group_size(8))) -__attribute__((reqd_work_group_size(8, 1, 1))) +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void bfloat16_dpas_rowmajor_m8x1_n8x2(global float* C, global ushort* A, global ushort* B, int K) { - #define tM 8 - #define tN 8 - #define tK 16 - #define MM 1 #define NN 2 + const int tM = 8; const int N = get_global_size(0) * NN; const int m = get_group_id(1) * tM * MM; const int n = get_group_id(0) * tN * NN; @@ -575,26 +569,17 @@ kernel void bfloat16_dpas_rowmajor_m8x1_n8x2(global float* C, global ushort* A, __store_c_row_major_fp32_m8(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); } } - - #undef tM - #undef tN - #undef tK - #undef MM #undef NN } -__attribute__((intel_reqd_sub_group_size(8))) 
-__attribute__((reqd_work_group_size(8, 1, 1))) +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void bfloat16_dpas_rowmajor_m8x2_n8x2(global float* C, global ushort* A, global ushort* B, int K) { - #define tM 8 - #define tN 8 - #define tK 16 - #define MM 2 #define NN 2 + const int tM = 8; const int N = get_global_size(0) * NN; const int m = get_group_id(1) * tM * MM; const int n = get_group_id(0) * tN * NN; @@ -630,26 +615,27 @@ kernel void bfloat16_dpas_rowmajor_m8x2_n8x2(global float* C, global ushort* A, } } - #undef tM - #undef tN - #undef tK - #undef MM #undef NN } +#undef tN + #endif // HAS_SIMD8 -__attribute__((intel_reqd_sub_group_size(16))) -__attribute__((reqd_work_group_size(16, 1, 1))) +// For SIMD16 kernels, tN == 16: +#define tN 16 + +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void bfloat16_dpas_rowmajor_m1_n16(global float* C, global ushort* A, global ushort* B, int K) { + const int tM = 1; const int N = get_global_size(0); - int m = get_group_id(1); - int n = get_group_id(0) * get_local_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * get_local_size(0); float sum = 0; - for (int k = 0; k < K; k += 16) { + for (int k = 0; k < K; k += tK) { short aData = __load_a_row_major_bf16_k16_m1_x16(A, m, k, K); int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); sum = mat_mul_x16(aData, bData, sum); @@ -658,16 +644,16 @@ kernel void bfloat16_dpas_rowmajor_m1_n16(global float* C, global ushort* A, glo __store_c_row_major_fp32_m1(C, sum, m, n, N); } -__attribute__((intel_reqd_sub_group_size(16))) -__attribute__((reqd_work_group_size(16, 1, 1))) +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void bfloat16_dpas_rowmajor_m2_n16(global float* C, global ushort* A, global ushort* B, int K) { + const int tM = 2; const int N = get_global_size(0); - int m = 
get_group_id(1) * 2; - int n = get_group_id(0) * get_local_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * get_local_size(0); float2 sum = 0; - for (int k = 0; k < K; k += 16) { + for (int k = 0; k < K; k += tK) { short2 aData = __load_a_row_major_bf16_k16_m2_x16(A, m, k, K); int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); sum = mat_mul_x16(aData, bData, sum); @@ -676,16 +662,16 @@ kernel void bfloat16_dpas_rowmajor_m2_n16(global float* C, global ushort* A, glo __store_c_row_major_fp32_m2(C, sum, m, n, N); } -__attribute__((intel_reqd_sub_group_size(16))) -__attribute__((reqd_work_group_size(16, 1, 1))) +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void bfloat16_dpas_rowmajor_m4_n16(global float* C, global ushort* A, global ushort* B, int K) { + const int tM = 4; const int N = get_global_size(0); - int m = get_group_id(1) * 4; - int n = get_group_id(0) * get_local_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * get_local_size(0); float4 sum = 0; - for (int k = 0; k < K; k += 16) { + for (int k = 0; k < K; k += tK) { short4 aData = __load_a_row_major_bf16_k16_m4_x16(A, m, k, K); int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); sum = mat_mul_x16(aData, bData, sum); @@ -694,16 +680,16 @@ kernel void bfloat16_dpas_rowmajor_m4_n16(global float* C, global ushort* A, glo __store_c_row_major_fp32_m4(C, sum, m, n, N); } -__attribute__((intel_reqd_sub_group_size(16))) -__attribute__((reqd_work_group_size(16, 1, 1))) +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void bfloat16_dpas_rowmajor_m8_n16(global float* C, global ushort* A, global ushort* B, int K) { + const int tM = 8; const int N = get_global_size(0); - int m = get_group_id(1) * 8; - int n = get_group_id(0) * get_local_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * get_local_size(0); float8 sum = 
0; - for (int k = 0; k < K; k += 16) { + for (int k = 0; k < K; k += tK) { short8 aData = __load_a_row_major_bf16_k16_m8_x16(A, m, k, K); int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); sum = mat_mul_x16(aData, bData, sum); @@ -712,18 +698,23 @@ kernel void bfloat16_dpas_rowmajor_m8_n16(global float* C, global ushort* A, glo __store_c_row_major_fp32_m8(C, sum, m, n, N); } +#undef tN + #if HAS_SIMD8 -__attribute__((intel_reqd_sub_group_size(8))) -__attribute__((reqd_work_group_size(8, 1, 1))) +// For SIMD8 kernels, tN == 8: +#define tN 8 + +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void bfloat16_dpas_vnni_m1_n8(global float* C, global ushort* A, global ushort* B, int K) { + const int tM = 1; const int N = get_global_size(0); - int m = get_group_id(1); - int n = get_group_id(0) * get_local_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; float sum = 0; - for (int k = 0; k < K; k += 16) { + for (int k = 0; k < K; k += tK) { int aData = __load_a_row_major_bf16_k16_m1_x8(A, m, k, K); int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); sum = mat_mul_x8(aData, bData, sum); @@ -732,16 +723,16 @@ kernel void bfloat16_dpas_vnni_m1_n8(global float* C, global ushort* A, global u __store_c_row_major_fp32_m1(C, sum, m, n, N); } -__attribute__((intel_reqd_sub_group_size(8))) -__attribute__((reqd_work_group_size(8, 1, 1))) +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void bfloat16_dpas_vnni_m2_n8(global float* C, global ushort* A, global ushort* B, int K) { + const int tM = 2; const int N = get_global_size(0); - int m = get_group_id(1) * 2; - int n = get_group_id(0) * get_local_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; float2 sum = 0; - for (int k = 0; k < K; k += 16) { + for (int k = 0; k < K; k += tK) { int2 aData = __load_a_row_major_bf16_k16_m2_x8(A, m, k, K); int8 bData = 
__load_b_vnni_bf16_k16(B, k, n, N); sum = mat_mul_x8(aData, bData, sum); @@ -750,16 +741,16 @@ kernel void bfloat16_dpas_vnni_m2_n8(global float* C, global ushort* A, global u __store_c_row_major_fp32_m2(C, sum, m, n, N); } -__attribute__((intel_reqd_sub_group_size(8))) -__attribute__((reqd_work_group_size(8, 1, 1))) +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void bfloat16_dpas_vnni_m4_n8(global float* C, global ushort* A, global ushort* B, int K) { + const int tM = 4; const int N = get_global_size(0); - int m = get_group_id(1) * 4; - int n = get_group_id(0) * get_local_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; float4 sum = 0; - for (int k = 0; k < K; k += 16) { + for (int k = 0; k < K; k += tK) { int4 aData = __load_a_row_major_bf16_k16_m4_x8(A, m, k, K); int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); sum = mat_mul_x8(aData, bData, sum); @@ -768,16 +759,16 @@ kernel void bfloat16_dpas_vnni_m4_n8(global float* C, global ushort* A, global u __store_c_row_major_fp32_m4(C, sum, m, n, N); } -__attribute__((intel_reqd_sub_group_size(8))) -__attribute__((reqd_work_group_size(8, 1, 1))) +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void bfloat16_dpas_vnni_m8_n8(global float* C, global ushort* A, global ushort* B, int K) { + const int tM = 8; const int N = get_global_size(0); - int m = get_group_id(1) * 8; - int n = get_group_id(0) * get_local_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; float8 sum = 0; - for (int k = 0; k < K; k += 16) { + for (int k = 0; k < K; k += tK) { int8 aData = __load_a_row_major_bf16_k16_m8_x8(A, m, k, K); int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); sum = mat_mul_x8(aData, bData, sum); @@ -786,55 +777,161 @@ kernel void bfloat16_dpas_vnni_m8_n8(global float* C, global ushort* A, global u __store_c_row_major_fp32_m8(C, sum, m, n, N); } 
-__attribute__((intel_reqd_sub_group_size(8))) -__attribute__((reqd_work_group_size(8, 1, 1))) +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void bfloat16_dpas_vnni_m8x2_n8x1(global float* C, global ushort* A, global ushort* B, int K) { - #define tM 8 #define MM 2 + #define NN 1 - const int N = get_global_size(0); - int m = get_group_id(1) * tM * MM; - int n = get_group_id(0) * get_local_size(0); + const int tM = 8; + const int N = get_global_size(0) * NN; + const int m = get_group_id(1) * tM * MM; + const int n = get_group_id(0) * tN * NN; - float8 sum[MM]; + float8 sum[MM][NN]; for (int mm = 0; mm < MM; mm++) { - sum[mm] = 0; + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = 0; + } } - for (int k = 0; k < K; k += 16) { + for (int k = 0; k < K; k += tK) { int8 aData[MM]; for (int mm = 0; mm < MM; mm++) { aData[mm] = __load_a_row_major_bf16_k16_m8_x8(A, m + mm * tM, k, K); } - int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); + int8 bData[NN]; + for (int nn = 0; nn < NN; nn++) { + bData[nn] = __load_b_vnni_bf16_k16(B, k, n + nn * tN, N); + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = mat_mul_x8(aData[mm], bData[nn], sum[mm][nn]); + } + } + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + __store_c_row_major_fp32_m8(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + } + } + + #undef MM + #undef NN +} + +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +kernel void bfloat16_dpas_vnni_m8x1_n8x2(global float* C, global ushort* A, global ushort* B, int K) +{ + #define MM 1 + #define NN 2 + + const int tM = 8; + const int N = get_global_size(0) * NN; + const int m = get_group_id(1) * tM * MM; + const int n = get_group_id(0) * tN * NN; + + float8 sum[MM][NN]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = 0; + } + } + + for (int k = 0; k < K; k += tK) { + 
int8 aData[MM]; + for (int mm = 0; mm < MM; mm++) { + aData[mm] = __load_a_row_major_bf16_k16_m8_x8(A, m + mm * tM, k, K); + } + + int8 bData[NN]; + for (int nn = 0; nn < NN; nn++) { + bData[nn] = __load_b_vnni_bf16_k16(B, k, n + nn * tN, N); + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = mat_mul_x8(aData[mm], bData[nn], sum[mm][nn]); + } + } + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + __store_c_row_major_fp32_m8(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + } + } + + #undef MM + #undef NN +} + +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +kernel void bfloat16_dpas_vnni_m8x2_n8x2(global float* C, global ushort* A, global ushort* B, int K) +{ + #define MM 2 + #define NN 2 + + const int tM = 8; + const int N = get_global_size(0) * NN; + const int m = get_group_id(1) * tM * MM; + const int n = get_group_id(0) * tN * NN; + + float8 sum[MM][NN]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = 0; + } + } + for (int k = 0; k < K; k += tK) { + int8 aData[MM]; for (int mm = 0; mm < MM; mm++) { - sum[mm] = mat_mul_x8(aData[mm], bData, sum[mm]); + aData[mm] = __load_a_row_major_bf16_k16_m8_x8(A, m + mm * tM, k, K); + } + + int8 bData[NN]; + for (int nn = 0; nn < NN; nn++) { + bData[nn] = __load_b_vnni_bf16_k16(B, k, n + nn * tN, N); + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = mat_mul_x8(aData[mm], bData[nn], sum[mm][nn]); + } } } for (int mm = 0; mm < MM; mm++) { - __store_c_row_major_fp32_m8(C, sum[mm], m + mm * tM, n, N); + for (int nn = 0; nn < NN; nn++) { + __store_c_row_major_fp32_m8(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + } } - #undef tM #undef MM + #undef NN } +#undef tN + #endif // HAS_SIMD8 -__attribute__((intel_reqd_sub_group_size(16))) -__attribute__((reqd_work_group_size(16, 1, 1))) +// For SIMD16 kernels, tN == 16: +#define tN 16 + 
+__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void bfloat16_dpas_vnni_m1_n16(global float* C, global ushort* A, global ushort* B, int K) { + const int tM = 1; const int N = get_global_size(0); - int m = get_group_id(1); - int n = get_group_id(0) * get_local_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; float sum = 0; - for (int k = 0; k < K; k += 16) { + for (int k = 0; k < K; k += tK) { short aData = __load_a_row_major_bf16_k16_m1_x16(A, m, k, K); int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); sum = mat_mul_x16(aData, bData, sum); @@ -843,16 +940,16 @@ kernel void bfloat16_dpas_vnni_m1_n16(global float* C, global ushort* A, global __store_c_row_major_fp32_m1(C, sum, m, n, N); } -__attribute__((intel_reqd_sub_group_size(16))) -__attribute__((reqd_work_group_size(16, 1, 1))) +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void bfloat16_dpas_vnni_m2_n16(global float* C, global ushort* A, global ushort* B, int K) { + const int tM = 2; const int N = get_global_size(0); - int m = get_group_id(1) * 2; - int n = get_group_id(0) * get_local_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; float2 sum = 0; - for (int k = 0; k < K; k += 16) { + for (int k = 0; k < K; k += tK) { short2 aData = __load_a_row_major_bf16_k16_m2_x16(A, m, k, K); int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); sum = mat_mul_x16(aData, bData, sum); @@ -861,16 +958,16 @@ kernel void bfloat16_dpas_vnni_m2_n16(global float* C, global ushort* A, global __store_c_row_major_fp32_m2(C, sum, m, n, N); } -__attribute__((intel_reqd_sub_group_size(16))) -__attribute__((reqd_work_group_size(16, 1, 1))) +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void bfloat16_dpas_vnni_m4_n16(global float* C, global ushort* A, global ushort* B, int K) { + const int tM = 4; const int N 
= get_global_size(0); - int m = get_group_id(1) * 4; - int n = get_group_id(0) * get_local_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; float4 sum = 0; - for (int k = 0; k < K; k += 16) { + for (int k = 0; k < K; k += tK) { short4 aData = __load_a_row_major_bf16_k16_m4_x16(A, m, k, K); int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); sum = mat_mul_x16(aData, bData, sum); @@ -879,16 +976,16 @@ kernel void bfloat16_dpas_vnni_m4_n16(global float* C, global ushort* A, global __store_c_row_major_fp32_m4(C, sum, m, n, N); } -__attribute__((intel_reqd_sub_group_size(16))) -__attribute__((reqd_work_group_size(16, 1, 1))) +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void bfloat16_dpas_vnni_m8_n16(global float* C, global ushort* A, global ushort* B, int K) { + const int tM = 8; const int N = get_global_size(0); - int m = get_group_id(1) * 8; - int n = get_group_id(0) * get_local_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; float8 sum = 0; - for (int k = 0; k < K; k += 16) { + for (int k = 0; k < K; k += tK) { short8 aData = __load_a_row_major_bf16_k16_m8_x16(A, m, k, K); int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); sum = mat_mul_x16(aData, bData, sum); @@ -897,8 +994,13 @@ kernel void bfloat16_dpas_vnni_m8_n16(global float* C, global ushort* A, global __store_c_row_major_fp32_m8(C, sum, m, n, N); } +#undef tN + #ifdef cl_intel_subgroup_extended_block_read +// All of the block read kernels are SIMD16 kernels, tN == 16: +#define tN 16 + // Note for 2D block reads: // - the tile width and height is encoded into the function name. // - base_address is the byte address. Must be 64B aligned. 
@@ -997,17 +1099,17 @@ void intel_subgroup_block_write_u32_m8k16v1(__global void* base_address, int wid __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); } -__attribute__((intel_reqd_sub_group_size(16))) -__attribute__((reqd_work_group_size(16, 1, 1))) +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void bfloat16_dpas_blockread_rowmajor_m1_n16(global float* C, global ushort* A, global ushort* B, int K) { + const int tM = 1; const int M = get_global_size(1); const int N = get_global_size(0); - int m = get_group_id(1); - int n = get_group_id(0) * get_local_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; float sum = 0; - for (int k = 0; k < K; k += 16) { + for (int k = 0; k < K; k += tK) { short aData = as_short(intel_subgroup_block_read_u16_m1k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); sum = mat_mul_x16(aData, bData, sum); @@ -1016,17 +1118,17 @@ kernel void bfloat16_dpas_blockread_rowmajor_m1_n16(global float* C, global usho intel_subgroup_block_write_u32_m1k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum)); } -__attribute__((intel_reqd_sub_group_size(16))) -__attribute__((reqd_work_group_size(16, 1, 1))) +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void bfloat16_dpas_blockread_rowmajor_m2_n16(global float* C, global ushort* A, global ushort* B, int K) { - const int M = get_global_size(1) * 2; + const int tM = 2; + const int M = get_global_size(1) * tM; const int N = get_global_size(0); - int m = get_group_id(1) * 2; - int n = get_group_id(0) * get_local_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; float2 sum = 0; - for (int k = 0; k < 
K; k += 16) { + for (int k = 0; k < K; k += tK) { short2 aData = as_short2(intel_subgroup_block_read_u16_m2k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); sum = mat_mul_x16(aData, bData, sum); @@ -1035,17 +1137,17 @@ kernel void bfloat16_dpas_blockread_rowmajor_m2_n16(global float* C, global usho intel_subgroup_block_write_u32_m2k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum)); } -__attribute__((intel_reqd_sub_group_size(16))) -__attribute__((reqd_work_group_size(16, 1, 1))) +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void bfloat16_dpas_blockread_rowmajor_m4_n16(global float* C, global ushort* A, global ushort* B, int K) { - const int M = get_global_size(1) * 4; + const int tM = 4; + const int M = get_global_size(1) * tM; const int N = get_global_size(0); - int m = get_group_id(1) * 4; - int n = get_group_id(0) * get_local_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; float4 sum = 0; - for (int k = 0; k < K; k += 16) { + for (int k = 0; k < K; k += tK) { short4 aData = as_short4(intel_subgroup_block_read_u16_m4k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); sum = mat_mul_x16(aData, bData, sum); @@ -1054,17 +1156,17 @@ kernel void bfloat16_dpas_blockread_rowmajor_m4_n16(global float* C, global usho intel_subgroup_block_write_u32_m4k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum)); } -__attribute__((intel_reqd_sub_group_size(16))) -__attribute__((reqd_work_group_size(16, 1, 1))) +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void 
bfloat16_dpas_blockread_rowmajor_m8_n16(global float* C, global ushort* A, global ushort* B, int K) { - const int M = get_global_size(1) * 8; + const int tM = 8; + const int M = get_global_size(1) * tM; const int N = get_global_size(0); - int m = get_group_id(1) * 8; - int n = get_group_id(0) * get_local_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; float8 sum = 0; - for (int k = 0; k < K; k += 16) { + for (int k = 0; k < K; k += tK) { short8 aData = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); sum = mat_mul_x16(aData, bData, sum); @@ -1073,17 +1175,17 @@ kernel void bfloat16_dpas_blockread_rowmajor_m8_n16(global float* C, global usho intel_subgroup_block_write_u32_m8k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); } -__attribute__((intel_reqd_sub_group_size(16))) -__attribute__((reqd_work_group_size(16, 1, 1))) +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void bfloat16_dpas_blockread_vnni_m1_n16(global float* C, global ushort* A, global ushort* B, int K) { - const int M = get_global_size(1); + const int tM = 1; + const int M = get_global_size(1) * tM; const int N = get_global_size(0); - int m = get_group_id(1); - int n = get_group_id(0) * get_local_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; float sum = 0; - for (int k = 0; k < K; k += 16) { + for (int k = 0; k < K; k += tK) { short aData = as_short(intel_subgroup_block_read_u16_m1k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); sum = mat_mul_x16(aData, bData, sum); @@ -1092,17 +1194,17 @@ kernel void 
bfloat16_dpas_blockread_vnni_m1_n16(global float* C, global ushort* intel_subgroup_block_write_u32_m1k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum)); } -__attribute__((intel_reqd_sub_group_size(16))) -__attribute__((reqd_work_group_size(16, 1, 1))) +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void bfloat16_dpas_blockread_vnni_m2_n16(global float* C, global ushort* A, global ushort* B, int K) { - const int M = get_global_size(1) * 2; + const int tM = 2; + const int M = get_global_size(1) * tM; const int N = get_global_size(0); - int m = get_group_id(1) * 2; - int n = get_group_id(0) * get_local_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; float2 sum = 0; - for (int k = 0; k < K; k += 16) { + for (int k = 0; k < K; k += tK) { short2 aData = as_short2(intel_subgroup_block_read_u16_m2k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); sum = mat_mul_x16(aData, bData, sum); @@ -1111,17 +1213,17 @@ kernel void bfloat16_dpas_blockread_vnni_m2_n16(global float* C, global ushort* intel_subgroup_block_write_u32_m2k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum)); } -__attribute__((intel_reqd_sub_group_size(16))) -__attribute__((reqd_work_group_size(16, 1, 1))) +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void bfloat16_dpas_blockread_vnni_m4_n16(global float* C, global ushort* A, global ushort* B, int K) { - const int M = get_global_size(1) * 4; + const int tM = 4; + const int M = get_global_size(1) * tM; const int N = get_global_size(0); - int m = get_group_id(1) * 4; - int n = get_group_id(0) * get_local_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; float4 sum = 0; - for (int k = 0; k < K; k += 
16) { + for (int k = 0; k < K; k += tK) { short4 aData = as_short4(intel_subgroup_block_read_u16_m4k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); sum = mat_mul_x16(aData, bData, sum); @@ -1130,17 +1232,17 @@ kernel void bfloat16_dpas_blockread_vnni_m4_n16(global float* C, global ushort* intel_subgroup_block_write_u32_m4k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum)); } -__attribute__((intel_reqd_sub_group_size(16))) -__attribute__((reqd_work_group_size(16, 1, 1))) +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void bfloat16_dpas_blockread_vnni_m8_n16(global float* C, global ushort* A, global ushort* B, int K) { - const int M = get_global_size(1) * 8; + const int tM = 8; + const int M = get_global_size(1) * tM; const int N = get_global_size(0); - int m = get_group_id(1) * 8; - int n = get_group_id(0) * get_local_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; float8 sum = 0; - for (int k = 0; k < K; k += 16) { + for (int k = 0; k < K; k += tK) { short8 aData = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); sum = mat_mul_x16(aData, bData, sum); @@ -1149,8 +1251,11 @@ kernel void bfloat16_dpas_blockread_vnni_m8_n16(global float* C, global ushort* intel_subgroup_block_write_u32_m8k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); } +#undef tN + #endif // cl_intel_subgroup_extended_block_read #endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size) #undef OVLD +#undef tK From ce7866f489f46cc0a26c5e9a4a5d234de3241628 Mon Sep 17 00:00:00 2001 From: 
Ben Ashbaugh Date: Fri, 12 Jan 2024 15:20:38 -0800 Subject: [PATCH 16/99] add more block tiled variants --- samples/99_matrixexperiments/main.cpp | 30 +- .../99_matrixexperiments/matrix_kernels.cl | 288 +++++++++++++++++- 2 files changed, 300 insertions(+), 18 deletions(-) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index 4d9f32d5..bf1bfb53 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -230,7 +230,7 @@ static void go_dpas_rowmajor( } template -static void go_dpas_rowmajor_x( +static void go_dpas_rowmajor_tiled( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, @@ -238,10 +238,10 @@ static void go_dpas_rowmajor_x( { printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, MM, NN, M, N, K).c_str()); fflush(stdout); - std::string kernelName = "bfloat16_dpas_rowmajor"; + std::string kernelName = "bfloat16_dpas_rowmajor_tiled"; kernelName += "_m" + std::to_string(tM); - kernelName += "x" + std::to_string(MM); kernelName += "_n" + std::to_string(tN); + kernelName += "_" + std::to_string(MM); kernelName += "x" + std::to_string(NN); cl::Kernel kernel{program, kernelName.c_str()}; if (kernel()) { @@ -320,7 +320,7 @@ static void go_dpas_vnni( } template -static void go_dpas_vnni_x( +static void go_dpas_vnni_tiled( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, @@ -328,10 +328,10 @@ static void go_dpas_vnni_x( { printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, MM, NN, M, N, K).c_str()); fflush(stdout); - std::string kernelName = "bfloat16_dpas_vnni"; + std::string kernelName = "bfloat16_dpas_vnni_tiled"; kernelName += "_m" + std::to_string(tM); - kernelName += "x" + std::to_string(MM); kernelName += "_n" + std::to_string(tN); + kernelName += "_" + std::to_string(MM); kernelName += "x" 
+ std::to_string(NN); cl::Kernel kernel{program, kernelName.c_str()}; if (kernel()) { @@ -582,18 +582,24 @@ int main(int argc, char** argv) go_dpas_rowmajor<4, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); go_dpas_rowmajor<8, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor_x<8, 8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor_x<8, 8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor_x<8, 8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_rowmajor_tiled<8, 8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_rowmajor_tiled<8, 8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_rowmajor_tiled<8, 8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_rowmajor_tiled<8, 8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_rowmajor_tiled<8, 8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_rowmajor_tiled<8, 8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); go_dpas_vnni<1, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); go_dpas_vnni<2, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); go_dpas_vnni<4, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); go_dpas_vnni<8, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni_x<8, 8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni_x<8, 8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni_x<8, 8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_vnni_tiled<8, 8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_vnni_tiled<8, 8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_vnni_tiled<8, 8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_vnni_tiled<8, 8, 16, 4, 
2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_vnni_tiled<8, 8, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_vnni_tiled<8, 8, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); go_dpas_rowmajor<1, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); go_dpas_rowmajor<2, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl index 5ed2cde1..08a82f88 100644 --- a/samples/99_matrixexperiments/matrix_kernels.cl +++ b/samples/99_matrixexperiments/matrix_kernels.cl @@ -483,7 +483,7 @@ kernel void bfloat16_dpas_rowmajor_m8_n8(global float* C, global ushort* A, glob } __attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) -kernel void bfloat16_dpas_rowmajor_m8x2_n8x1(global float* C, global ushort* A, global ushort* B, int K) +kernel void bfloat16_dpas_rowmajor_tiled_m8_n8_2x1(global float* C, global ushort* A, global ushort* B, int K) { #define MM 2 #define NN 1 @@ -529,7 +529,7 @@ kernel void bfloat16_dpas_rowmajor_m8x2_n8x1(global float* C, global ushort* A, } __attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) -kernel void bfloat16_dpas_rowmajor_m8x1_n8x2(global float* C, global ushort* A, global ushort* B, int K) +kernel void bfloat16_dpas_rowmajor_tiled_m8_n8_1x2(global float* C, global ushort* A, global ushort* B, int K) { #define MM 1 #define NN 2 @@ -574,7 +574,7 @@ kernel void bfloat16_dpas_rowmajor_m8x1_n8x2(global float* C, global ushort* A, } __attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) -kernel void bfloat16_dpas_rowmajor_m8x2_n8x2(global float* C, global ushort* A, global ushort* B, int K) +kernel void bfloat16_dpas_rowmajor_tiled_m8_n8_2x2(global float* C, global ushort* A, global ushort* B, int K) { #define MM 2 #define NN 2 @@ -619,6 +619,144 @@ 
kernel void bfloat16_dpas_rowmajor_m8x2_n8x2(global float* C, global ushort* A, #undef NN } +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +kernel void bfloat16_dpas_rowmajor_tiled_m8_n8_4x2(global float* C, global ushort* A, global ushort* B, int K) +{ + #define MM 4 + #define NN 2 + + const int tM = 8; + const int N = get_global_size(0) * NN; + const int m = get_group_id(1) * tM * MM; + const int n = get_group_id(0) * tN * NN; + + float8 sum[MM][NN]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = 0; + } + } + + for (int k = 0; k < K; k += tK) { + int8 aData[MM]; + for (int mm = 0; mm < MM; mm++) { + aData[mm] = __load_a_row_major_bf16_k16_m8_x8(A, m + mm * tM, k, K); + } + + int8 bData[NN]; + for (int nn = 0; nn < NN; nn++) { + bData[nn] = __load_b_row_major_bf16_k16(B, k, n + nn * tN, N); + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = mat_mul_x8(aData[mm], bData[nn], sum[mm][nn]); + } + } + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + __store_c_row_major_fp32_m8(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + } + } + + #undef MM + #undef NN +} + +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +kernel void bfloat16_dpas_rowmajor_tiled_m8_n8_2x4(global float* C, global ushort* A, global ushort* B, int K) +{ + #define MM 2 + #define NN 4 + + const int tM = 8; + const int N = get_global_size(0) * NN; + const int m = get_group_id(1) * tM * MM; + const int n = get_group_id(0) * tN * NN; + + float8 sum[MM][NN]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = 0; + } + } + + for (int k = 0; k < K; k += tK) { + int8 aData[MM]; + for (int mm = 0; mm < MM; mm++) { + aData[mm] = __load_a_row_major_bf16_k16_m8_x8(A, m + mm * tM, k, K); + } + + int8 bData[NN]; + for (int nn = 0; nn < NN; nn++) { + bData[nn] = 
__load_b_row_major_bf16_k16(B, k, n + nn * tN, N); + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = mat_mul_x8(aData[mm], bData[nn], sum[mm][nn]); + } + } + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + __store_c_row_major_fp32_m8(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + } + } + + #undef MM + #undef NN +} + +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +kernel void bfloat16_dpas_rowmajor_tiled_m8_n8_4x4(global float* C, global ushort* A, global ushort* B, int K) +{ + #define MM 4 + #define NN 4 + + const int tM = 8; + const int N = get_global_size(0) * NN; + const int m = get_group_id(1) * tM * MM; + const int n = get_group_id(0) * tN * NN; + + float8 sum[MM][NN]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = 0; + } + } + + for (int k = 0; k < K; k += tK) { + int8 aData[MM]; + for (int mm = 0; mm < MM; mm++) { + aData[mm] = __load_a_row_major_bf16_k16_m8_x8(A, m + mm * tM, k, K); + } + + int8 bData[NN]; + for (int nn = 0; nn < NN; nn++) { + bData[nn] = __load_b_row_major_bf16_k16(B, k, n + nn * tN, N); + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = mat_mul_x8(aData[mm], bData[nn], sum[mm][nn]); + } + } + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + __store_c_row_major_fp32_m8(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + } + } + + #undef MM + #undef NN +} + #undef tN #endif // HAS_SIMD8 @@ -778,7 +916,7 @@ kernel void bfloat16_dpas_vnni_m8_n8(global float* C, global ushort* A, global u } __attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) -kernel void bfloat16_dpas_vnni_m8x2_n8x1(global float* C, global ushort* A, global ushort* B, int K) +kernel void bfloat16_dpas_vnni_tiled_m8_n8_2x1(global float* C, global ushort* A, global ushort* B, int K) { #define MM 2 #define NN 
1 @@ -824,7 +962,7 @@ kernel void bfloat16_dpas_vnni_m8x2_n8x1(global float* C, global ushort* A, glob } __attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) -kernel void bfloat16_dpas_vnni_m8x1_n8x2(global float* C, global ushort* A, global ushort* B, int K) +kernel void bfloat16_dpas_vnni_tiled_m8_n8_1x2(global float* C, global ushort* A, global ushort* B, int K) { #define MM 1 #define NN 2 @@ -870,7 +1008,7 @@ kernel void bfloat16_dpas_vnni_m8x1_n8x2(global float* C, global ushort* A, glob } __attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) -kernel void bfloat16_dpas_vnni_m8x2_n8x2(global float* C, global ushort* A, global ushort* B, int K) +kernel void bfloat16_dpas_vnni_tiled_m8_n8_2x2(global float* C, global ushort* A, global ushort* B, int K) { #define MM 2 #define NN 2 @@ -915,6 +1053,144 @@ kernel void bfloat16_dpas_vnni_m8x2_n8x2(global float* C, global ushort* A, glob #undef NN } +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +kernel void bfloat16_dpas_vnni_tiled_m8_n8_4x2(global float* C, global ushort* A, global ushort* B, int K) +{ + #define MM 4 + #define NN 2 + + const int tM = 8; + const int N = get_global_size(0) * NN; + const int m = get_group_id(1) * tM * MM; + const int n = get_group_id(0) * tN * NN; + + float8 sum[MM][NN]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = 0; + } + } + + for (int k = 0; k < K; k += tK) { + int8 aData[MM]; + for (int mm = 0; mm < MM; mm++) { + aData[mm] = __load_a_row_major_bf16_k16_m8_x8(A, m + mm * tM, k, K); + } + + int8 bData[NN]; + for (int nn = 0; nn < NN; nn++) { + bData[nn] = __load_b_vnni_bf16_k16(B, k, n + nn * tN, N); + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = mat_mul_x8(aData[mm], bData[nn], sum[mm][nn]); + } + } + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < 
NN; nn++) { + __store_c_row_major_fp32_m8(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + } + } + + #undef MM + #undef NN +} + +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +kernel void bfloat16_dpas_vnni_tiled_m8_n8_2x4(global float* C, global ushort* A, global ushort* B, int K) +{ + #define MM 2 + #define NN 4 + + const int tM = 8; + const int N = get_global_size(0) * NN; + const int m = get_group_id(1) * tM * MM; + const int n = get_group_id(0) * tN * NN; + + float8 sum[MM][NN]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = 0; + } + } + + for (int k = 0; k < K; k += tK) { + int8 aData[MM]; + for (int mm = 0; mm < MM; mm++) { + aData[mm] = __load_a_row_major_bf16_k16_m8_x8(A, m + mm * tM, k, K); + } + + int8 bData[NN]; + for (int nn = 0; nn < NN; nn++) { + bData[nn] = __load_b_vnni_bf16_k16(B, k, n + nn * tN, N); + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = mat_mul_x8(aData[mm], bData[nn], sum[mm][nn]); + } + } + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + __store_c_row_major_fp32_m8(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + } + } + + #undef MM + #undef NN +} + +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +kernel void bfloat16_dpas_vnni_tiled_m8_n8_4x4(global float* C, global ushort* A, global ushort* B, int K) +{ + #define MM 4 + #define NN 4 + + const int tM = 8; + const int N = get_global_size(0) * NN; + const int m = get_group_id(1) * tM * MM; + const int n = get_group_id(0) * tN * NN; + + float8 sum[MM][NN]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = 0; + } + } + + for (int k = 0; k < K; k += tK) { + int8 aData[MM]; + for (int mm = 0; mm < MM; mm++) { + aData[mm] = __load_a_row_major_bf16_k16_m8_x8(A, m + mm * tM, k, K); + } + + int8 bData[NN]; + for (int nn = 0; nn < NN; nn++) { + 
bData[nn] = __load_b_vnni_bf16_k16(B, k, n + nn * tN, N); + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = mat_mul_x8(aData[mm], bData[nn], sum[mm][nn]); + } + } + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + __store_c_row_major_fp32_m8(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + } + } + + #undef MM + #undef NN +} + #undef tN #endif // HAS_SIMD8 From 48c3bc21e1a575fcb42775ab8c6a17b3d1407a09 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Fri, 12 Jan 2024 18:01:46 -0800 Subject: [PATCH 17/99] refactor device code into a helper file --- samples/99_matrixexperiments/CMakeLists.txt | 2 +- .../99_matrixexperiments/matrix_helpers.cl | 484 +++++++++++ .../99_matrixexperiments/matrix_kernels.cl | 802 ++++-------------- 3 files changed, 644 insertions(+), 644 deletions(-) create mode 100644 samples/99_matrixexperiments/matrix_helpers.cl diff --git a/samples/99_matrixexperiments/CMakeLists.txt b/samples/99_matrixexperiments/CMakeLists.txt index 6020ec83..456a8cb1 100644 --- a/samples/99_matrixexperiments/CMakeLists.txt +++ b/samples/99_matrixexperiments/CMakeLists.txt @@ -8,4 +8,4 @@ add_opencl_sample( TARGET matrixexperiments VERSION 120 SOURCES main.cpp - KERNELS matrix_kernels.cl) + KERNELS matrix_helpers.cl matrix_kernels.cl) diff --git a/samples/99_matrixexperiments/matrix_helpers.cl b/samples/99_matrixexperiments/matrix_helpers.cl new file mode 100644 index 00000000..ef68fa44 --- /dev/null +++ b/samples/99_matrixexperiments/matrix_helpers.cl @@ -0,0 +1,484 @@ +float bf16_to_fp32(ushort u) +{ +#if defined(cl_intel_bfloat16_conversions) + return intel_convert_as_bfloat16_float(u); +#else + return as_float(u << 16); +#endif +} + +#if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) + +// Emulated SIMD8 dpas: +__attribute__((overloadable)) +float emu_sub_group_bf16_bf16_matrix_mad_k16(int a, int8 b, float acc) +{ + float res = acc; + + res = 
fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 0)).x), bf16_to_fp32(as_ushort2(b.s0).x), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 0)).y), bf16_to_fp32(as_ushort2(b.s0).y), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 1)).x), bf16_to_fp32(as_ushort2(b.s1).x), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 1)).y), bf16_to_fp32(as_ushort2(b.s1).y), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 2)).x), bf16_to_fp32(as_ushort2(b.s2).x), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 2)).y), bf16_to_fp32(as_ushort2(b.s2).y), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 3)).x), bf16_to_fp32(as_ushort2(b.s3).x), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 3)).y), bf16_to_fp32(as_ushort2(b.s3).y), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 4)).x), bf16_to_fp32(as_ushort2(b.s4).x), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 4)).y), bf16_to_fp32(as_ushort2(b.s4).y), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 5)).x), bf16_to_fp32(as_ushort2(b.s5).x), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 5)).y), bf16_to_fp32(as_ushort2(b.s5).y), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 6)).x), bf16_to_fp32(as_ushort2(b.s6).x), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 6)).y), bf16_to_fp32(as_ushort2(b.s6).y), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 7)).x), bf16_to_fp32(as_ushort2(b.s7).x), res); + res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 7)).y), bf16_to_fp32(as_ushort2(b.s7).y), res); + + return res; +} + +__attribute__((overloadable)) +float2 emu_sub_group_bf16_bf16_matrix_mad_k16(int2 a, int8 b, float2 acc) +{ + float2 res; + + res.s0 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0); + res.s1 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1); + + return 
res; +} + +__attribute__((overloadable)) +float4 emu_sub_group_bf16_bf16_matrix_mad_k16(int4 a, int8 b, float4 acc) +{ + float4 res; + + res.s0 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0); + res.s1 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1); + res.s2 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s2, b, acc.s2); + res.s3 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s3, b, acc.s3); + + return res; +} + +__attribute__((overloadable)) +float8 emu_sub_group_bf16_bf16_matrix_mad_k16(int8 a, int8 b, float8 acc) +{ + float8 res; + + res.s0 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0); + res.s1 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1); + res.s2 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s2, b, acc.s2); + res.s3 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s3, b, acc.s3); + res.s4 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s4, b, acc.s4); + res.s5 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s5, b, acc.s5); + res.s6 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s6, b, acc.s6); + res.s7 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s7, b, acc.s7); + + return res; +} + +// Emulated SIMD16 dpas: +__attribute__((overloadable)) +float emu_sub_group_bf16_bf16_matrix_mad_k16(short a, int8 b, float acc) +{ + float res = acc; + + res = fma(bf16_to_fp32(sub_group_broadcast(a, 0)), bf16_to_fp32(as_ushort2(b.s0).x), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 1)), bf16_to_fp32(as_ushort2(b.s0).y), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 2)), bf16_to_fp32(as_ushort2(b.s1).x), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 3)), bf16_to_fp32(as_ushort2(b.s1).y), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 4)), bf16_to_fp32(as_ushort2(b.s2).x), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 5)), bf16_to_fp32(as_ushort2(b.s2).y), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 6)), bf16_to_fp32(as_ushort2(b.s3).x), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 7)), 
bf16_to_fp32(as_ushort2(b.s3).y), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 8)), bf16_to_fp32(as_ushort2(b.s4).x), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 9)), bf16_to_fp32(as_ushort2(b.s4).y), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 10)), bf16_to_fp32(as_ushort2(b.s5).x), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 11)), bf16_to_fp32(as_ushort2(b.s5).y), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 12)), bf16_to_fp32(as_ushort2(b.s6).x), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 13)), bf16_to_fp32(as_ushort2(b.s6).y), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 14)), bf16_to_fp32(as_ushort2(b.s7).x), res); + res = fma(bf16_to_fp32(sub_group_broadcast(a, 15)), bf16_to_fp32(as_ushort2(b.s7).y), res); + + return res; +} + +__attribute__((overloadable)) +float2 emu_sub_group_bf16_bf16_matrix_mad_k16(short2 a, int8 b, float2 acc) +{ + float2 res; + + res.s0 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0); + res.s1 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1); + + return res; +} + +__attribute__((overloadable)) +float4 emu_sub_group_bf16_bf16_matrix_mad_k16(short4 a, int8 b, float4 acc) +{ + float4 res; + + res.s0 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0); + res.s1 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1); + res.s2 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s2, b, acc.s2); + res.s3 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s3, b, acc.s3); + + return res; +} + +__attribute__((overloadable)) +float8 emu_sub_group_bf16_bf16_matrix_mad_k16(short8 a, int8 b, float8 acc) +{ + float8 res; + + res.s0 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0); + res.s1 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1); + res.s2 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s2, b, acc.s2); + res.s3 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s3, b, acc.s3); + res.s4 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s4, b, acc.s4); + res.s5 = 
emu_sub_group_bf16_bf16_matrix_mad_k16(a.s5, b, acc.s5); + res.s6 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s6, b, acc.s6); + res.s7 = emu_sub_group_bf16_bf16_matrix_mad_k16(a.s7, b, acc.s7); + + return res; +} + +// M rows x K columns +// This is the SIMD8 version, where each work-item loads two values. +int load_a_rowmajor_d16_m1_k16_sg8(global ushort* A, int rowStart, int colStart, int stride) +{ + int ret; + + global uint* A_ui = (global uint*)A; + int offset_ui = rowStart * stride / 2 + colStart / 2; + ret = intel_sub_group_block_read(A_ui + offset_ui); + + return ret; +} + +// M rows x K columns +// This is the SIMD8 version, where each work-item loads two values. +int2 load_a_rowmajor_d16_m2_k16_sg8(global ushort* A, int rowStart, int colStart, int stride) +{ + int2 ret; + + global uint* A_ui = (global uint*)A; + int offset_ui = rowStart * stride / 2 + colStart / 2; + + ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + + return ret; +} + +// M rows x K columns +// This is the SIMD8 version, where each work-item loads two values. +int4 load_a_rowmajor_d16_m4_k16_sg8(global ushort* A, int rowStart, int colStart, int stride) +{ + int4 ret; + + global uint* A_ui = (global uint*)A; + int offset_ui = rowStart * stride / 2 + colStart / 2; + + ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s2 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s3 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + + return ret; +} + +// M rows x K columns +// This is the SIMD8 version, where each work-item loads two values. 
+int8 load_a_rowmajor_d16_m8_k16_sg8(global ushort* A, int rowStart, int colStart, int stride) +{ + int8 ret; + + global uint* A_ui = (global uint*)A; + int offset_ui = rowStart * stride / 2 + colStart / 2; + + ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s2 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s3 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s4 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s5 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s6 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + ret.s7 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; + + return ret; +} + +// M rows x K columns +// This is the SIMD16 version, where each work-item loads one value. +short load_a_rowmajor_d16_m1_k16_sg16(global ushort* A, int rowStart, int colStart, int stride) +{ + ushort ret; + + int offset = rowStart * stride + colStart; + ret = intel_sub_group_block_read_us(A + offset); + + return as_short(ret); +} + +// M rows x K columns +// This is the SIMD16 version, where each work-item loads one value. +short2 load_a_rowmajor_d16_m2_k16_sg16(global ushort* A, int rowStart, int colStart, int stride) +{ + ushort2 ret; + + int offset = rowStart * stride + colStart; + ret.s0 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s1 = intel_sub_group_block_read_us(A + offset); offset += stride; + + return as_short2(ret); +} + +// M rows x K columns +// This is the SIMD16 version, where each work-item loads one value. 
+short4 load_a_rowmajor_d16_m4_k16_sg16(global ushort* A, int rowStart, int colStart, int stride) +{ + ushort4 ret; + + int offset = rowStart * stride + colStart; + ret.s0 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s1 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s2 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s3 = intel_sub_group_block_read_us(A + offset); offset += stride; + + return as_short4(ret); +} + +// M rows x K columns +// This is the SIMD16 version, where each work-item loads one value. +short8 load_a_rowmajor_d16_m8_k16_sg16(global ushort* A, int rowStart, int colStart, int stride) +{ + ushort8 ret; + + int offset = rowStart * stride + colStart; + ret.s0 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s1 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s2 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s3 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s4 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s5 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s6 = intel_sub_group_block_read_us(A + offset); offset += stride; + ret.s7 = intel_sub_group_block_read_us(A + offset); offset += stride; + + return as_short8(ret); +} + +// K rows x N columns: +// Each work-item loads K values and converts to VNNI. +// Stride is in units of elements. 
+int8 load_b_rowmajor_d16_k16_nx(global ushort* B, int rowStart, int colStart, int stride) +{ + int8 ret; + + int offset = rowStart * stride + colStart; + + ushort row0 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row1 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row2 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row3 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row4 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row5 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row6 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row7 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row8 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row9 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row10 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row11 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row12 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row13 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row14 = intel_sub_group_block_read_us(B + offset); offset += stride; + ushort row15 = intel_sub_group_block_read_us(B + offset); offset += stride; + + ret.s0 = as_int((ushort2)(row0, row1 )); + ret.s1 = as_int((ushort2)(row2, row3 )); + ret.s2 = as_int((ushort2)(row4, row5 )); + ret.s3 = as_int((ushort2)(row6, row7 )); + ret.s4 = as_int((ushort2)(row8, row9 )); + ret.s5 = as_int((ushort2)(row10, row11)); + ret.s6 = as_int((ushort2)(row12, row13)); + ret.s7 = as_int((ushort2)(row14, row15)); + + return ret; +} + +// K rows x N columns: +// Each work-item loads K values that have already been converted to VNNI. +// Stride is in units of elements. 
+int8 load_b_vnni_d16_k16_nx(global ushort* B, int rowStart, int colStart, int stride) +{ + int8 ret; + + global uint* B_ui = (global uint*)B; + int offset_ui = rowStart / 2 * stride + colStart; + + ret.s0 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s1 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s2 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s3 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s4 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s5 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s6 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + ret.s7 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; + + return ret; +} + +void store_c_rowmajor_fp32_m1_nx(global float* C, float v, int rowStart, int colStart, int stride) +{ + global uint* C_ui = (global uint*)C; + uint v_ui = as_uint(v); + + int offset = rowStart * stride + colStart; + + intel_sub_group_block_write(C_ui + offset, v_ui); offset += stride; +} + +void store_c_rowmajor_fp32_m2_nx(global float* C, float2 v, int rowStart, int colStart, int stride) +{ + global uint* C_ui = (global uint*)C; + uint2 v_ui = as_uint2(v); + + int offset = rowStart * stride + colStart; + + intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride; +} + +void store_c_rowmajor_fp32_m4_nx(global float* C, float4 v, int rowStart, int colStart, int stride) +{ + global uint* C_ui = (global uint*)C; + uint4 v_ui = as_uint4(v); + + int offset = rowStart * stride + colStart; + + intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s2); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s3); 
offset += stride; +} + +void store_c_rowmajor_fp32_m8_nx(global float* C, float8 v, int rowStart, int colStart, int stride) +{ + global uint* C_ui = (global uint*)C; + uint8 v_ui = as_uint8(v); + + int offset = rowStart * stride + colStart; + + intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s2); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s3); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s4); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s5); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s6); offset += stride; + intel_sub_group_block_write(C_ui + offset, v_ui.s7); offset += stride; +} + +#endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) + +#ifdef cl_intel_subgroup_extended_block_read + +// Note for 2D block reads: +// - the tile width and height are encoded into the function name. +// - base_address is the byte address. Must be 64B aligned. +// - width is the width of the entire matrix, in bytes. Must be >= 64B. Must be 4B aligned. +// - height is the height of the entire matrix, or equivalently the number of rows. +// - pitch is the number of bytes between rows of the entire matrix. Must be >= 64B. Must be a multiple of 8 bytes. +// - coord is the number of elements (x coord) and row (y coord) to read from. X coord must be a multiple of 4 for 1B data and 2 for 2B data. 
+ +// Built-in functions are: + +// #ifdef cl_intel_subgroup_extended_block_read +// ushort2 intel_subgroup_block_read_u8_m1k32v2(__global void *base_address, int width, int height, int pitch, int2 coord); +// ushort4 intel_subgroup_block_read_u8_m2k32v2(__global void *base_address, int width, int height, int pitch, int2 coord); +// ushort8 intel_subgroup_block_read_u8_m4k32v2(__global void *base_address, int width, int height, int pitch, int2 coord); +// ushort16 intel_subgroup_block_read_u8_m8k32v2(__global void *base_address, int width, int height, int pitch, int2 coord); +// ushort2 intel_subgroup_block_read_u16_m1k16v2(__global void *base_address, int width, int height, int pitch, int2 coord); +// ushort4 intel_subgroup_block_read_u16_m2k16v2(__global void *base_address, int width, int height, int pitch, int2 coord); +// ushort8 intel_subgroup_block_read_u16_m4k16v2(__global void *base_address, int width, int height, int pitch, int2 coord); +// ushort16 intel_subgroup_block_read_u16_m8k16v2(__global void *base_address, int width, int height, int pitch, int2 coord); +// uint8 intel_subgroup_block_read_transform_u8_k32(__global void *base_address, int width, int height, int pitch, int2 coord); +// uint8 intel_subgroup_block_read_transform_u16_k16(__global void *base_address, int width, int height, int pitch, int2 coord); +// uint8 intel_subgroup_block_read_transpose_u32_k8(__global void *base_address, int width, int height, int pitch, int2 coord); +// ulong4 intel_subgroup_block_read_transpose_u64_k4(__global void *base_address, int width, int height, int pitch, int2 coord); +// #endif //defined(cl_intel_subgroup_extended_block_read) + + +// For intrinsics, the pattern is: +// - prefix: __builtin_IB_subgroup_block_read_flat or __builtin_IB_subgroup_block_write_flat +// - operation (optional): _transpose or _transform +// - for no transpose or transform: +// - type / elements size: _u8 or _u16 or _u32 or _u64 +// - number of tile rows: _m32 or _m16 or _m8 or _m4 
or _m2 or _m1 +// - tile width: _k64 or _k32 or _k16 or _k8 +// - number of tiles: _v2 or _v1 +// - for transpose: +// - type / element size: _u64 or _u32 +// - number of tile rows: subgroup size (16) +// - tile width: _k4 (for _u64) or _k8 (for _u32) +// - number of tiles: 1 +// - for transform: +// - type / element size: _u16 or _u8 +// - number of tile rows: _k32 (for _u8) or _k16 (for _u16) +// - tile width: subgroup size (16) +// - number of tiles: 1 + +// Define additional "non-vector" block read and writes. These are supported by the hardware but are not in the headers: + +ushort __builtin_IB_subgroup_block_read_flat_u16_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort2 __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort4 __builtin_IB_subgroup_block_read_flat_u16_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort8 __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); + +uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); + +void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint data); +void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint2 data); +void __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint4 data); +void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint8 
data); + +ushort intel_subgroup_block_read_u16_m1k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + return __builtin_IB_subgroup_block_read_flat_u16_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); +} +ushort2 intel_subgroup_block_read_u16_m2k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + return __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); +} +ushort4 intel_subgroup_block_read_u16_m4k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + return __builtin_IB_subgroup_block_read_flat_u16_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); +} +ushort8 intel_subgroup_block_read_u16_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + return __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); +} + +uint8 intel_subgroup_block_read_u32_m8k16(const __global void* base_address, int width, int height, int pitch, int2 coord) +{ + return __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); +} + +void intel_subgroup_block_write_u32_m1k16v1(__global void* base_address, int width, int height, int pitch, int2 coord, uint data) +{ + __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); +} +void intel_subgroup_block_write_u32_m2k16v1(__global void* base_address, int width, int height, int pitch, int2 coord, uint2 data) +{ + __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); +} +void intel_subgroup_block_write_u32_m4k16v1(__global void* base_address, int width, int height, int pitch, int2 coord, uint4 data) +{ + 
__builtin_IB_subgroup_block_write_flat_u32_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); +} +void intel_subgroup_block_write_u32_m8k16v1(__global void* base_address, int width, int height, int pitch, int2 coord, uint8 data) +{ + __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); +} + +#endif // cl_intel_subgroup_extended_block_read diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl index 08a82f88..d2c1e7c4 100644 --- a/samples/99_matrixexperiments/matrix_kernels.cl +++ b/samples/99_matrixexperiments/matrix_kernels.cl @@ -1,29 +1,17 @@ -#define OVLD __attribute__((overloadable)) - -// For all bfloat16 kernels we have tK == 16: -#define tK 16 +#include "matrix_helpers.cl" #if EMULATE_tN8 -#define mat_mul_x8 my_sub_group_bf16_bf16_matrix_mad_k16 +#define mat_mul_x8 emu_sub_group_bf16_bf16_matrix_mad_k16 #else #define mat_mul_x8 intel_sub_group_bf16_bf16_matrix_mad_k16 #endif #if EMULATE_tN16 -#define mat_mul_x16 my_sub_group_bf16_bf16_matrix_mad_k16 +#define mat_mul_x16 emu_sub_group_bf16_bf16_matrix_mad_k16 #else #define mat_mul_x16 intel_sub_group_bf16_bf16_matrix_mad_k16 #endif -float bf16_to_fp32(ushort u) -{ -#if defined(cl_intel_bfloat16_conversions) - return intel_convert_as_bfloat16_float(u); -#else - return as_float(u << 16); -#endif -} - kernel void bfloat16_naive(global float* C, global ushort* A, global ushort* B, int K) { const int N = get_global_size(0); @@ -38,378 +26,18 @@ kernel void bfloat16_naive(global float* C, global ushort* A, global ushort* B, C[m * N + n] = sum; } -#if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size) - -// These are non-block read versions. -// They work on DG2 and PVC, and on other devices when emulated. 
- -// SIMD8 versions: -static float OVLD my_sub_group_bf16_bf16_matrix_mad_k16(int a, int8 b, float acc) -{ - float res = acc; - - res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 0)).x), bf16_to_fp32(as_ushort2(b.s0).x), res); - res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 0)).y), bf16_to_fp32(as_ushort2(b.s0).y), res); - res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 1)).x), bf16_to_fp32(as_ushort2(b.s1).x), res); - res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 1)).y), bf16_to_fp32(as_ushort2(b.s1).y), res); - res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 2)).x), bf16_to_fp32(as_ushort2(b.s2).x), res); - res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 2)).y), bf16_to_fp32(as_ushort2(b.s2).y), res); - res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 3)).x), bf16_to_fp32(as_ushort2(b.s3).x), res); - res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 3)).y), bf16_to_fp32(as_ushort2(b.s3).y), res); - res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 4)).x), bf16_to_fp32(as_ushort2(b.s4).x), res); - res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 4)).y), bf16_to_fp32(as_ushort2(b.s4).y), res); - res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 5)).x), bf16_to_fp32(as_ushort2(b.s5).x), res); - res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 5)).y), bf16_to_fp32(as_ushort2(b.s5).y), res); - res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 6)).x), bf16_to_fp32(as_ushort2(b.s6).x), res); - res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 6)).y), bf16_to_fp32(as_ushort2(b.s6).y), res); - res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 7)).x), bf16_to_fp32(as_ushort2(b.s7).x), res); - res = fma(bf16_to_fp32(as_ushort2(sub_group_broadcast(a, 7)).y), bf16_to_fp32(as_ushort2(b.s7).y), res); - - return res; -} - -static float2 OVLD my_sub_group_bf16_bf16_matrix_mad_k16(int2 a, int8 b, float2 acc) -{ - float2 res; - - res.s0 = 
my_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0); - res.s1 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1); - - return res; -} - -static float4 OVLD my_sub_group_bf16_bf16_matrix_mad_k16(int4 a, int8 b, float4 acc) -{ - float4 res; - - res.s0 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0); - res.s1 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1); - res.s2 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s2, b, acc.s2); - res.s3 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s3, b, acc.s3); - - return res; -} - -static float8 OVLD my_sub_group_bf16_bf16_matrix_mad_k16(int8 a, int8 b, float8 acc) -{ - float8 res; - - res.s0 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0); - res.s1 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1); - res.s2 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s2, b, acc.s2); - res.s3 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s3, b, acc.s3); - res.s4 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s4, b, acc.s4); - res.s5 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s5, b, acc.s5); - res.s6 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s6, b, acc.s6); - res.s7 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s7, b, acc.s7); - - return res; -} - -// SIMD16 versions: -static float OVLD my_sub_group_bf16_bf16_matrix_mad_k16(short a, int8 b, float acc) -{ - float res = acc; - - res = fma(bf16_to_fp32(sub_group_broadcast(a, 0)), bf16_to_fp32(as_ushort2(b.s0).x), res); - res = fma(bf16_to_fp32(sub_group_broadcast(a, 1)), bf16_to_fp32(as_ushort2(b.s0).y), res); - res = fma(bf16_to_fp32(sub_group_broadcast(a, 2)), bf16_to_fp32(as_ushort2(b.s1).x), res); - res = fma(bf16_to_fp32(sub_group_broadcast(a, 3)), bf16_to_fp32(as_ushort2(b.s1).y), res); - res = fma(bf16_to_fp32(sub_group_broadcast(a, 4)), bf16_to_fp32(as_ushort2(b.s2).x), res); - res = fma(bf16_to_fp32(sub_group_broadcast(a, 5)), bf16_to_fp32(as_ushort2(b.s2).y), res); - res = fma(bf16_to_fp32(sub_group_broadcast(a, 6)), bf16_to_fp32(as_ushort2(b.s3).x), res); - res = 
fma(bf16_to_fp32(sub_group_broadcast(a, 7)), bf16_to_fp32(as_ushort2(b.s3).y), res); - res = fma(bf16_to_fp32(sub_group_broadcast(a, 8)), bf16_to_fp32(as_ushort2(b.s4).x), res); - res = fma(bf16_to_fp32(sub_group_broadcast(a, 9)), bf16_to_fp32(as_ushort2(b.s4).y), res); - res = fma(bf16_to_fp32(sub_group_broadcast(a, 10)), bf16_to_fp32(as_ushort2(b.s5).x), res); - res = fma(bf16_to_fp32(sub_group_broadcast(a, 11)), bf16_to_fp32(as_ushort2(b.s5).y), res); - res = fma(bf16_to_fp32(sub_group_broadcast(a, 12)), bf16_to_fp32(as_ushort2(b.s6).x), res); - res = fma(bf16_to_fp32(sub_group_broadcast(a, 13)), bf16_to_fp32(as_ushort2(b.s6).y), res); - res = fma(bf16_to_fp32(sub_group_broadcast(a, 14)), bf16_to_fp32(as_ushort2(b.s7).x), res); - res = fma(bf16_to_fp32(sub_group_broadcast(a, 15)), bf16_to_fp32(as_ushort2(b.s7).y), res); - - return res; -} - -static float2 OVLD my_sub_group_bf16_bf16_matrix_mad_k16(short2 a, int8 b, float2 acc) -{ - float2 res; - - res.s0 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0); - res.s1 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1); - - return res; -} - -static float4 OVLD my_sub_group_bf16_bf16_matrix_mad_k16(short4 a, int8 b, float4 acc) -{ - float4 res; - - res.s0 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0); - res.s1 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1); - res.s2 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s2, b, acc.s2); - res.s3 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s3, b, acc.s3); - - return res; -} - -static float8 OVLD my_sub_group_bf16_bf16_matrix_mad_k16(short8 a, int8 b, float8 acc) -{ - float8 res; - - res.s0 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s0, b, acc.s0); - res.s1 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s1, b, acc.s1); - res.s2 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s2, b, acc.s2); - res.s3 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s3, b, acc.s3); - res.s4 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s4, b, acc.s4); - res.s5 = 
my_sub_group_bf16_bf16_matrix_mad_k16(a.s5, b, acc.s5); - res.s6 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s6, b, acc.s6); - res.s7 = my_sub_group_bf16_bf16_matrix_mad_k16(a.s7, b, acc.s7); - - return res; -} - -// M rows x K columns -// This is the SIMD8 version, where each work-item loads two values. -static int __load_a_row_major_bf16_k16_m1_x8(global ushort* A, int rowStart, int colStart, int stride) -{ - int ret; - - global uint* A_ui = (global uint*)A; - int offset_ui = rowStart * stride / 2 + colStart / 2; - ret = intel_sub_group_block_read(A_ui + offset_ui); - - return ret; -} - -// M rows x K columns -// This is the SIMD8 version, where each work-item loads two values. -static int2 __load_a_row_major_bf16_k16_m2_x8(global ushort* A, int rowStart, int colStart, int stride) -{ - int2 ret; - - global uint* A_ui = (global uint*)A; - int offset_ui = rowStart * stride / 2 + colStart / 2; - - ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; - ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; - - return ret; -} - -// M rows x K columns -// This is the SIMD8 version, where each work-item loads two values. -static int4 __load_a_row_major_bf16_k16_m4_x8(global ushort* A, int rowStart, int colStart, int stride) -{ - int4 ret; - - global uint* A_ui = (global uint*)A; - int offset_ui = rowStart * stride / 2 + colStart / 2; - - ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; - ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; - ret.s2 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; - ret.s3 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; - - return ret; -} - -// M rows x K columns -// This is the SIMD8 version, where each work-item loads two values. 
-static int8 __load_a_row_major_bf16_k16_m8_x8(global ushort* A, int rowStart, int colStart, int stride) -{ - int8 ret; - - global uint* A_ui = (global uint*)A; - int offset_ui = rowStart * stride / 2 + colStart / 2; - - ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; - ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; - ret.s2 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; - ret.s3 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; - ret.s4 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; - ret.s5 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; - ret.s6 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; - ret.s7 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; - - return ret; -} - -// M rows x K columns -// This is the SIMD16 version, where each work-item loads one values. -static short __load_a_row_major_bf16_k16_m1_x16(global ushort* A, int rowStart, int colStart, int stride) -{ - ushort ret; - - int offset = rowStart * stride + colStart; - ret = intel_sub_group_block_read_us(A + offset); - - return as_short(ret); -} - -// M rows x K columns -// This is the SIMD16 version, where each work-item loads one values. -static short2 __load_a_row_major_bf16_k16_m2_x16(global ushort* A, int rowStart, int colStart, int stride) -{ - ushort2 ret; - - int offset = rowStart * stride + colStart; - ret.s0 = intel_sub_group_block_read_us(A + offset); offset += stride; - ret.s1 = intel_sub_group_block_read_us(A + offset); offset += stride; - - return as_short2(ret); -} - -// M rows x K columns -// This is the SIMD16 version, where each work-item loads one values. 
-static short4 __load_a_row_major_bf16_k16_m4_x16(global ushort* A, int rowStart, int colStart, int stride) -{ - ushort4 ret; - - int offset = rowStart * stride + colStart; - ret.s0 = intel_sub_group_block_read_us(A + offset); offset += stride; - ret.s1 = intel_sub_group_block_read_us(A + offset); offset += stride; - ret.s2 = intel_sub_group_block_read_us(A + offset); offset += stride; - ret.s3 = intel_sub_group_block_read_us(A + offset); offset += stride; - - return as_short4(ret); -} - -// M rows x K columns -// This is the SIMD16 version, where each work-item loads one values. -static short8 __load_a_row_major_bf16_k16_m8_x16(global ushort* A, int rowStart, int colStart, int stride) -{ - ushort8 ret; - - int offset = rowStart * stride + colStart; - ret.s0 = intel_sub_group_block_read_us(A + offset); offset += stride; - ret.s1 = intel_sub_group_block_read_us(A + offset); offset += stride; - ret.s2 = intel_sub_group_block_read_us(A + offset); offset += stride; - ret.s3 = intel_sub_group_block_read_us(A + offset); offset += stride; - ret.s4 = intel_sub_group_block_read_us(A + offset); offset += stride; - ret.s5 = intel_sub_group_block_read_us(A + offset); offset += stride; - ret.s6 = intel_sub_group_block_read_us(A + offset); offset += stride; - ret.s7 = intel_sub_group_block_read_us(A + offset); offset += stride; - - return as_short8(ret); -} - -// K rows x N columns: -// Each work-item loads K values and converts to VNNI. -// Stride is in units of elements. 
-static int8 __load_b_row_major_bf16_k16(global ushort* B, int rowStart, int colStart, int stride) -{ - int8 ret; - - int offset = rowStart * stride + colStart; - - ushort row0 = intel_sub_group_block_read_us(B + offset); offset += stride; - ushort row1 = intel_sub_group_block_read_us(B + offset); offset += stride; - ushort row2 = intel_sub_group_block_read_us(B + offset); offset += stride; - ushort row3 = intel_sub_group_block_read_us(B + offset); offset += stride; - ushort row4 = intel_sub_group_block_read_us(B + offset); offset += stride; - ushort row5 = intel_sub_group_block_read_us(B + offset); offset += stride; - ushort row6 = intel_sub_group_block_read_us(B + offset); offset += stride; - ushort row7 = intel_sub_group_block_read_us(B + offset); offset += stride; - ushort row8 = intel_sub_group_block_read_us(B + offset); offset += stride; - ushort row9 = intel_sub_group_block_read_us(B + offset); offset += stride; - ushort row10 = intel_sub_group_block_read_us(B + offset); offset += stride; - ushort row11 = intel_sub_group_block_read_us(B + offset); offset += stride; - ushort row12 = intel_sub_group_block_read_us(B + offset); offset += stride; - ushort row13 = intel_sub_group_block_read_us(B + offset); offset += stride; - ushort row14 = intel_sub_group_block_read_us(B + offset); offset += stride; - ushort row15 = intel_sub_group_block_read_us(B + offset); offset += stride; - - ret.s0 = as_int((ushort2)(row0, row1 )); - ret.s1 = as_int((ushort2)(row2, row3 )); - ret.s2 = as_int((ushort2)(row4, row5 )); - ret.s3 = as_int((ushort2)(row6, row7 )); - ret.s4 = as_int((ushort2)(row8, row9 )); - ret.s5 = as_int((ushort2)(row10, row11)); - ret.s6 = as_int((ushort2)(row12, row13)); - ret.s7 = as_int((ushort2)(row14, row15)); - - return ret; -} - -// K rows x N columns: -// Each work-item loads K values that has already been converted to VNNI. -// Stride is in units of elements. 
-static int8 __load_b_vnni_bf16_k16(global ushort* B, int rowStart, int colStart, int stride) -{ - int8 ret; - - global uint* B_ui = (global uint*)B; - int offset_ui = rowStart / 2 * stride + colStart; - - ret.s0 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; - ret.s1 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; - ret.s2 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; - ret.s3 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; - ret.s4 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; - ret.s5 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; - ret.s6 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; - ret.s7 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; - - return ret; -} - -static void __store_c_row_major_fp32_m1(global float* C, float v, int rowStart, int colStart, int stride) -{ - global uint* C_ui = (global uint*)C; - uint v_ui = as_uint(v); - - int offset = rowStart * stride + colStart; - - intel_sub_group_block_write(C_ui + offset, v_ui); offset += stride; -} - -static void __store_c_row_major_fp32_m2(global float* C, float2 v, int rowStart, int colStart, int stride) -{ - global uint* C_ui = (global uint*)C; - uint2 v_ui = as_uint2(v); - - int offset = rowStart * stride + colStart; - - intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride; - intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride; -} - -static void __store_c_row_major_fp32_m4(global float* C, float4 v, int rowStart, int colStart, int stride) -{ - global uint* C_ui = (global uint*)C; - uint4 v_ui = as_uint4(v); - - int offset = rowStart * stride + colStart; - - intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride; - intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride; - intel_sub_group_block_write(C_ui + offset, v_ui.s2); offset += stride; - 
intel_sub_group_block_write(C_ui + offset, v_ui.s3); offset += stride; -} +// For all bfloat16 kernels tK == 16: +#define tK 16 -static void __store_c_row_major_fp32_m8(global float* C, float8 v, int rowStart, int colStart, int stride) -{ - global uint* C_ui = (global uint*)C; - uint8 v_ui = as_uint8(v); - - int offset = rowStart * stride + colStart; - - intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride; - intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride; - intel_sub_group_block_write(C_ui + offset, v_ui.s2); offset += stride; - intel_sub_group_block_write(C_ui + offset, v_ui.s3); offset += stride; - intel_sub_group_block_write(C_ui + offset, v_ui.s4); offset += stride; - intel_sub_group_block_write(C_ui + offset, v_ui.s5); offset += stride; - intel_sub_group_block_write(C_ui + offset, v_ui.s6); offset += stride; - intel_sub_group_block_write(C_ui + offset, v_ui.s7); offset += stride; -} +#if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size) #if HAS_SIMD8 -// For SIMD8 kernels, tN == 8: +// For all SIMD8 kernels tN == 8: #define tN 8 +// rowmajor kernels: + __attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void bfloat16_dpas_rowmajor_m1_n8(global float* C, global ushort* A, global ushort* B, int K) { @@ -420,12 +48,12 @@ kernel void bfloat16_dpas_rowmajor_m1_n8(global float* C, global ushort* A, glob float sum = 0; for (int k = 0; k < K; k += tK) { - int aData = __load_a_row_major_bf16_k16_m1_x8(A, m, k, K); - int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); + int aData = load_a_rowmajor_d16_m1_k16_sg8(A, m, k, K); + int8 bData = load_b_rowmajor_d16_k16_nx(B, k, n, N); sum = mat_mul_x8(aData, bData, sum); } - __store_c_row_major_fp32_m1(C, sum, m, n, N); + store_c_rowmajor_fp32_m1_nx(C, sum, m, n, N); } __attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) @@ -438,12 
+66,12 @@ kernel void bfloat16_dpas_rowmajor_m2_n8(global float* C, global ushort* A, glob float2 sum = 0; for (int k = 0; k < K; k += tK) { - int2 aData = __load_a_row_major_bf16_k16_m2_x8(A, m, k, K); - int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); + int2 aData = load_a_rowmajor_d16_m2_k16_sg8(A, m, k, K); + int8 bData = load_b_rowmajor_d16_k16_nx(B, k, n, N); sum = mat_mul_x8(aData, bData, sum); } - __store_c_row_major_fp32_m2(C, sum, m, n, N); + store_c_rowmajor_fp32_m2_nx(C, sum, m, n, N); } __attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) @@ -456,12 +84,12 @@ kernel void bfloat16_dpas_rowmajor_m4_n8(global float* C, global ushort* A, glob float4 sum = 0; for (int k = 0; k < K; k += tK) { - int4 aData = __load_a_row_major_bf16_k16_m4_x8(A, m, k, K); - int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); + int4 aData = load_a_rowmajor_d16_m4_k16_sg8(A, m, k, K); + int8 bData = load_b_rowmajor_d16_k16_nx(B, k, n, N); sum = mat_mul_x8(aData, bData, sum); } - __store_c_row_major_fp32_m4(C, sum, m, n, N); + store_c_rowmajor_fp32_m4_nx(C, sum, m, n, N); } __attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) @@ -474,12 +102,12 @@ kernel void bfloat16_dpas_rowmajor_m8_n8(global float* C, global ushort* A, glob float8 sum = 0; for (int k = 0; k < K; k += tK) { - int8 aData = __load_a_row_major_bf16_k16_m8_x8(A, m, k, K); - int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); + int8 aData = load_a_rowmajor_d16_m8_k16_sg8(A, m, k, K); + int8 bData = load_b_rowmajor_d16_k16_nx(B, k, n, N); sum = mat_mul_x8(aData, bData, sum); } - __store_c_row_major_fp32_m8(C, sum, m, n, N); + store_c_rowmajor_fp32_m8_nx(C, sum, m, n, N); } __attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) @@ -503,12 +131,12 @@ kernel void bfloat16_dpas_rowmajor_tiled_m8_n8_2x1(global float* C, global ushor for (int k = 0; k < K; k += tK) { int8 aData[MM]; for (int 
mm = 0; mm < MM; mm++) { - aData[mm] = __load_a_row_major_bf16_k16_m8_x8(A, m + mm * tM, k, K); + aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K); } int8 bData[NN]; for (int nn = 0; nn < NN; nn++) { - bData[nn] = __load_b_row_major_bf16_k16(B, k, n + nn * tN, N); + bData[nn] = load_b_rowmajor_d16_k16_nx(B, k, n + nn * tN, N); } for (int mm = 0; mm < MM; mm++) { @@ -520,7 +148,7 @@ kernel void bfloat16_dpas_rowmajor_tiled_m8_n8_2x1(global float* C, global ushor for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - __store_c_row_major_fp32_m8(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); } } @@ -549,12 +177,12 @@ kernel void bfloat16_dpas_rowmajor_tiled_m8_n8_1x2(global float* C, global ushor for (int k = 0; k < K; k += tK) { int8 aData[MM]; for (int mm = 0; mm < MM; mm++) { - aData[mm] = __load_a_row_major_bf16_k16_m8_x8(A, m + mm * tM, k, K); + aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K); } int8 bData[NN]; for (int nn = 0; nn < NN; nn++) { - bData[nn] = __load_b_row_major_bf16_k16(B, k, n + nn * tN, N); + bData[nn] = load_b_rowmajor_d16_k16_nx(B, k, n + nn * tN, N); } for (int mm = 0; mm < MM; mm++) { @@ -566,7 +194,7 @@ kernel void bfloat16_dpas_rowmajor_tiled_m8_n8_1x2(global float* C, global ushor for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - __store_c_row_major_fp32_m8(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); } } #undef MM @@ -594,12 +222,12 @@ kernel void bfloat16_dpas_rowmajor_tiled_m8_n8_2x2(global float* C, global ushor for (int k = 0; k < K; k += tK) { int8 aData[MM]; for (int mm = 0; mm < MM; mm++) { - aData[mm] = __load_a_row_major_bf16_k16_m8_x8(A, m + mm * tM, k, K); + aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K); } int8 bData[NN]; for (int nn = 0; nn < NN; nn++) { - bData[nn] = 
__load_b_row_major_bf16_k16(B, k, n + nn * tN, N); + bData[nn] = load_b_rowmajor_d16_k16_nx(B, k, n + nn * tN, N); } for (int mm = 0; mm < MM; mm++) { @@ -611,7 +239,7 @@ kernel void bfloat16_dpas_rowmajor_tiled_m8_n8_2x2(global float* C, global ushor for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - __store_c_row_major_fp32_m8(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); } } @@ -640,12 +268,12 @@ kernel void bfloat16_dpas_rowmajor_tiled_m8_n8_4x2(global float* C, global ushor for (int k = 0; k < K; k += tK) { int8 aData[MM]; for (int mm = 0; mm < MM; mm++) { - aData[mm] = __load_a_row_major_bf16_k16_m8_x8(A, m + mm * tM, k, K); + aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K); } int8 bData[NN]; for (int nn = 0; nn < NN; nn++) { - bData[nn] = __load_b_row_major_bf16_k16(B, k, n + nn * tN, N); + bData[nn] = load_b_rowmajor_d16_k16_nx(B, k, n + nn * tN, N); } for (int mm = 0; mm < MM; mm++) { @@ -657,7 +285,7 @@ kernel void bfloat16_dpas_rowmajor_tiled_m8_n8_4x2(global float* C, global ushor for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - __store_c_row_major_fp32_m8(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); } } @@ -686,12 +314,12 @@ kernel void bfloat16_dpas_rowmajor_tiled_m8_n8_2x4(global float* C, global ushor for (int k = 0; k < K; k += tK) { int8 aData[MM]; for (int mm = 0; mm < MM; mm++) { - aData[mm] = __load_a_row_major_bf16_k16_m8_x8(A, m + mm * tM, k, K); + aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K); } int8 bData[NN]; for (int nn = 0; nn < NN; nn++) { - bData[nn] = __load_b_row_major_bf16_k16(B, k, n + nn * tN, N); + bData[nn] = load_b_rowmajor_d16_k16_nx(B, k, n + nn * tN, N); } for (int mm = 0; mm < MM; mm++) { @@ -703,7 +331,7 @@ kernel void bfloat16_dpas_rowmajor_tiled_m8_n8_2x4(global float* C, global ushor for (int mm 
= 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - __store_c_row_major_fp32_m8(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); } } @@ -732,12 +360,12 @@ kernel void bfloat16_dpas_rowmajor_tiled_m8_n8_4x4(global float* C, global ushor for (int k = 0; k < K; k += tK) { int8 aData[MM]; for (int mm = 0; mm < MM; mm++) { - aData[mm] = __load_a_row_major_bf16_k16_m8_x8(A, m + mm * tM, k, K); + aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K); } int8 bData[NN]; for (int nn = 0; nn < NN; nn++) { - bData[nn] = __load_b_row_major_bf16_k16(B, k, n + nn * tN, N); + bData[nn] = load_b_rowmajor_d16_k16_nx(B, k, n + nn * tN, N); } for (int mm = 0; mm < MM; mm++) { @@ -749,7 +377,7 @@ kernel void bfloat16_dpas_rowmajor_tiled_m8_n8_4x4(global float* C, global ushor for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - __store_c_row_major_fp32_m8(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); } } @@ -757,91 +385,7 @@ kernel void bfloat16_dpas_rowmajor_tiled_m8_n8_4x4(global float* C, global ushor #undef NN } -#undef tN - -#endif // HAS_SIMD8 - -// For SIMD16 kernels, tN == 16: -#define tN 16 - -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) -kernel void bfloat16_dpas_rowmajor_m1_n16(global float* C, global ushort* A, global ushort* B, int K) -{ - const int tM = 1; - const int N = get_global_size(0); - const int m = get_group_id(1) * tM; - const int n = get_group_id(0) * get_local_size(0); - - float sum = 0; - for (int k = 0; k < K; k += tK) { - short aData = __load_a_row_major_bf16_k16_m1_x16(A, m, k, K); - int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); - sum = mat_mul_x16(aData, bData, sum); - } - - __store_c_row_major_fp32_m1(C, sum, m, n, N); -} - -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) 
-kernel void bfloat16_dpas_rowmajor_m2_n16(global float* C, global ushort* A, global ushort* B, int K) -{ - const int tM = 2; - const int N = get_global_size(0); - const int m = get_group_id(1) * tK; - const int n = get_group_id(0) * get_local_size(0); - - float2 sum = 0; - for (int k = 0; k < K; k += tK) { - short2 aData = __load_a_row_major_bf16_k16_m2_x16(A, m, k, K); - int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); - sum = mat_mul_x16(aData, bData, sum); - } - - __store_c_row_major_fp32_m2(C, sum, m, n, N); -} - -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) -kernel void bfloat16_dpas_rowmajor_m4_n16(global float* C, global ushort* A, global ushort* B, int K) -{ - const int tM = 4; - const int N = get_global_size(0); - const int m = get_group_id(1) * tM; - const int n = get_group_id(0) * get_local_size(0); - - float4 sum = 0; - for (int k = 0; k < K; k += tK) { - short4 aData = __load_a_row_major_bf16_k16_m4_x16(A, m, k, K); - int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); - sum = mat_mul_x16(aData, bData, sum); - } - - __store_c_row_major_fp32_m4(C, sum, m, n, N); -} - -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) -kernel void bfloat16_dpas_rowmajor_m8_n16(global float* C, global ushort* A, global ushort* B, int K) -{ - const int tM = 8; - const int N = get_global_size(0); - const int m = get_group_id(1) * tM; - const int n = get_group_id(0) * get_local_size(0); - - float8 sum = 0; - for (int k = 0; k < K; k += tK) { - short8 aData = __load_a_row_major_bf16_k16_m8_x16(A, m, k, K); - int8 bData = __load_b_row_major_bf16_k16(B, k, n, N); - sum = mat_mul_x16(aData, bData, sum); - } - - __store_c_row_major_fp32_m8(C, sum, m, n, N); -} - -#undef tN - -#if HAS_SIMD8 - -// For SIMD8 kernels, tN == 8: -#define tN 8 +// vnni kernels: __attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void 
bfloat16_dpas_vnni_m1_n8(global float* C, global ushort* A, global ushort* B, int K) @@ -853,12 +397,12 @@ kernel void bfloat16_dpas_vnni_m1_n8(global float* C, global ushort* A, global u float sum = 0; for (int k = 0; k < K; k += tK) { - int aData = __load_a_row_major_bf16_k16_m1_x8(A, m, k, K); - int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); + int aData = load_a_rowmajor_d16_m1_k16_sg8(A, m, k, K); + int8 bData = load_b_vnni_d16_k16_nx(B, k, n, N); sum = mat_mul_x8(aData, bData, sum); } - __store_c_row_major_fp32_m1(C, sum, m, n, N); + store_c_rowmajor_fp32_m1_nx(C, sum, m, n, N); } __attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) @@ -871,12 +415,12 @@ kernel void bfloat16_dpas_vnni_m2_n8(global float* C, global ushort* A, global u float2 sum = 0; for (int k = 0; k < K; k += tK) { - int2 aData = __load_a_row_major_bf16_k16_m2_x8(A, m, k, K); - int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); + int2 aData = load_a_rowmajor_d16_m2_k16_sg8(A, m, k, K); + int8 bData = load_b_vnni_d16_k16_nx(B, k, n, N); sum = mat_mul_x8(aData, bData, sum); } - __store_c_row_major_fp32_m2(C, sum, m, n, N); + store_c_rowmajor_fp32_m2_nx(C, sum, m, n, N); } __attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) @@ -889,12 +433,12 @@ kernel void bfloat16_dpas_vnni_m4_n8(global float* C, global ushort* A, global u float4 sum = 0; for (int k = 0; k < K; k += tK) { - int4 aData = __load_a_row_major_bf16_k16_m4_x8(A, m, k, K); - int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); + int4 aData = load_a_rowmajor_d16_m4_k16_sg8(A, m, k, K); + int8 bData = load_b_vnni_d16_k16_nx(B, k, n, N); sum = mat_mul_x8(aData, bData, sum); } - __store_c_row_major_fp32_m4(C, sum, m, n, N); + store_c_rowmajor_fp32_m4_nx(C, sum, m, n, N); } __attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) @@ -907,12 +451,12 @@ kernel void bfloat16_dpas_vnni_m8_n8(global float* C, global ushort* A, 
global u float8 sum = 0; for (int k = 0; k < K; k += tK) { - int8 aData = __load_a_row_major_bf16_k16_m8_x8(A, m, k, K); - int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); + int8 aData = load_a_rowmajor_d16_m8_k16_sg8(A, m, k, K); + int8 bData = load_b_vnni_d16_k16_nx(B, k, n, N); sum = mat_mul_x8(aData, bData, sum); } - __store_c_row_major_fp32_m8(C, sum, m, n, N); + store_c_rowmajor_fp32_m8_nx(C, sum, m, n, N); } __attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) @@ -936,12 +480,12 @@ kernel void bfloat16_dpas_vnni_tiled_m8_n8_2x1(global float* C, global ushort* A for (int k = 0; k < K; k += tK) { int8 aData[MM]; for (int mm = 0; mm < MM; mm++) { - aData[mm] = __load_a_row_major_bf16_k16_m8_x8(A, m + mm * tM, k, K); + aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K); } int8 bData[NN]; for (int nn = 0; nn < NN; nn++) { - bData[nn] = __load_b_vnni_bf16_k16(B, k, n + nn * tN, N); + bData[nn] = load_b_vnni_d16_k16_nx(B, k, n + nn * tN, N); } for (int mm = 0; mm < MM; mm++) { @@ -953,7 +497,7 @@ kernel void bfloat16_dpas_vnni_tiled_m8_n8_2x1(global float* C, global ushort* A for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - __store_c_row_major_fp32_m8(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); } } @@ -982,12 +526,12 @@ kernel void bfloat16_dpas_vnni_tiled_m8_n8_1x2(global float* C, global ushort* A for (int k = 0; k < K; k += tK) { int8 aData[MM]; for (int mm = 0; mm < MM; mm++) { - aData[mm] = __load_a_row_major_bf16_k16_m8_x8(A, m + mm * tM, k, K); + aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K); } int8 bData[NN]; for (int nn = 0; nn < NN; nn++) { - bData[nn] = __load_b_vnni_bf16_k16(B, k, n + nn * tN, N); + bData[nn] = load_b_vnni_d16_k16_nx(B, k, n + nn * tN, N); } for (int mm = 0; mm < MM; mm++) { @@ -999,7 +543,7 @@ kernel void bfloat16_dpas_vnni_tiled_m8_n8_1x2(global float* C, global 
ushort* A for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - __store_c_row_major_fp32_m8(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); } } @@ -1028,12 +572,12 @@ kernel void bfloat16_dpas_vnni_tiled_m8_n8_2x2(global float* C, global ushort* A for (int k = 0; k < K; k += tK) { int8 aData[MM]; for (int mm = 0; mm < MM; mm++) { - aData[mm] = __load_a_row_major_bf16_k16_m8_x8(A, m + mm * tM, k, K); + aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K); } int8 bData[NN]; for (int nn = 0; nn < NN; nn++) { - bData[nn] = __load_b_vnni_bf16_k16(B, k, n + nn * tN, N); + bData[nn] = load_b_vnni_d16_k16_nx(B, k, n + nn * tN, N); } for (int mm = 0; mm < MM; mm++) { @@ -1045,7 +589,7 @@ kernel void bfloat16_dpas_vnni_tiled_m8_n8_2x2(global float* C, global ushort* A for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - __store_c_row_major_fp32_m8(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); } } @@ -1074,12 +618,12 @@ kernel void bfloat16_dpas_vnni_tiled_m8_n8_4x2(global float* C, global ushort* A for (int k = 0; k < K; k += tK) { int8 aData[MM]; for (int mm = 0; mm < MM; mm++) { - aData[mm] = __load_a_row_major_bf16_k16_m8_x8(A, m + mm * tM, k, K); + aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K); } int8 bData[NN]; for (int nn = 0; nn < NN; nn++) { - bData[nn] = __load_b_vnni_bf16_k16(B, k, n + nn * tN, N); + bData[nn] = load_b_vnni_d16_k16_nx(B, k, n + nn * tN, N); } for (int mm = 0; mm < MM; mm++) { @@ -1091,7 +635,7 @@ kernel void bfloat16_dpas_vnni_tiled_m8_n8_4x2(global float* C, global ushort* A for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - __store_c_row_major_fp32_m8(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); } } @@ -1120,12 +664,12 @@ kernel void 
bfloat16_dpas_vnni_tiled_m8_n8_2x4(global float* C, global ushort* A for (int k = 0; k < K; k += tK) { int8 aData[MM]; for (int mm = 0; mm < MM; mm++) { - aData[mm] = __load_a_row_major_bf16_k16_m8_x8(A, m + mm * tM, k, K); + aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K); } int8 bData[NN]; for (int nn = 0; nn < NN; nn++) { - bData[nn] = __load_b_vnni_bf16_k16(B, k, n + nn * tN, N); + bData[nn] = load_b_vnni_d16_k16_nx(B, k, n + nn * tN, N); } for (int mm = 0; mm < MM; mm++) { @@ -1137,7 +681,7 @@ kernel void bfloat16_dpas_vnni_tiled_m8_n8_2x4(global float* C, global ushort* A for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - __store_c_row_major_fp32_m8(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); } } @@ -1166,12 +710,12 @@ kernel void bfloat16_dpas_vnni_tiled_m8_n8_4x4(global float* C, global ushort* A for (int k = 0; k < K; k += tK) { int8 aData[MM]; for (int mm = 0; mm < MM; mm++) { - aData[mm] = __load_a_row_major_bf16_k16_m8_x8(A, m + mm * tM, k, K); + aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K); } int8 bData[NN]; for (int nn = 0; nn < NN; nn++) { - bData[nn] = __load_b_vnni_bf16_k16(B, k, n + nn * tN, N); + bData[nn] = load_b_vnni_d16_k16_nx(B, k, n + nn * tN, N); } for (int mm = 0; mm < MM; mm++) { @@ -1183,7 +727,7 @@ kernel void bfloat16_dpas_vnni_tiled_m8_n8_4x4(global float* C, global ushort* A for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - __store_c_row_major_fp32_m8(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); } } @@ -1191,190 +735,163 @@ kernel void bfloat16_dpas_vnni_tiled_m8_n8_4x4(global float* C, global ushort* A #undef NN } -#undef tN +#undef tN // for SIMD8 kernels #endif // HAS_SIMD8 -// For SIMD16 kernels, tN == 16: +// For all SIMD16 kernels tN == 16: #define tN 16 +// rowmajor krenels: + 
__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) -kernel void bfloat16_dpas_vnni_m1_n16(global float* C, global ushort* A, global ushort* B, int K) +kernel void bfloat16_dpas_rowmajor_m1_n16(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 1; const int N = get_global_size(0); const int m = get_group_id(1) * tM; - const int n = get_group_id(0) * tN; + const int n = get_group_id(0) * get_local_size(0); float sum = 0; for (int k = 0; k < K; k += tK) { - short aData = __load_a_row_major_bf16_k16_m1_x16(A, m, k, K); - int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); + short aData = load_a_rowmajor_d16_m1_k16_sg16(A, m, k, K); + int8 bData = load_b_rowmajor_d16_k16_nx(B, k, n, N); sum = mat_mul_x16(aData, bData, sum); } - __store_c_row_major_fp32_m1(C, sum, m, n, N); + store_c_rowmajor_fp32_m1_nx(C, sum, m, n, N); } __attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) -kernel void bfloat16_dpas_vnni_m2_n16(global float* C, global ushort* A, global ushort* B, int K) +kernel void bfloat16_dpas_rowmajor_m2_n16(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 2; const int N = get_global_size(0); - const int m = get_group_id(1) * tM; - const int n = get_group_id(0) * tN; + const int m = get_group_id(1) * tM; /* scale by tM, as in the m1/m4/m8 siblings: the host launches M/tM groups in dimension 1, so tK (the K-tile depth) would skip rows */ + const int n = get_group_id(0) * get_local_size(0); float2 sum = 0; for (int k = 0; k < K; k += tK) { - short2 aData = __load_a_row_major_bf16_k16_m2_x16(A, m, k, K); - int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); + short2 aData = load_a_rowmajor_d16_m2_k16_sg16(A, m, k, K); + int8 bData = load_b_rowmajor_d16_k16_nx(B, k, n, N); sum = mat_mul_x16(aData, bData, sum); } - __store_c_row_major_fp32_m2(C, sum, m, n, N); + store_c_rowmajor_fp32_m2_nx(C, sum, m, n, N); } __attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) -kernel void bfloat16_dpas_vnni_m4_n16(global float* C, global
ushort* A, global ushort* B, int K) +kernel void bfloat16_dpas_rowmajor_m4_n16(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 4; const int N = get_global_size(0); const int m = get_group_id(1) * tM; - const int n = get_group_id(0) * tN; + const int n = get_group_id(0) * get_local_size(0); float4 sum = 0; for (int k = 0; k < K; k += tK) { - short4 aData = __load_a_row_major_bf16_k16_m4_x16(A, m, k, K); - int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); + short4 aData = load_a_rowmajor_d16_m4_k16_sg16(A, m, k, K); + int8 bData = load_b_rowmajor_d16_k16_nx(B, k, n, N); sum = mat_mul_x16(aData, bData, sum); } - __store_c_row_major_fp32_m4(C, sum, m, n, N); + store_c_rowmajor_fp32_m4_nx(C, sum, m, n, N); } __attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) -kernel void bfloat16_dpas_vnni_m8_n16(global float* C, global ushort* A, global ushort* B, int K) +kernel void bfloat16_dpas_rowmajor_m8_n16(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 8; const int N = get_global_size(0); const int m = get_group_id(1) * tM; - const int n = get_group_id(0) * tN; + const int n = get_group_id(0) * get_local_size(0); float8 sum = 0; for (int k = 0; k < K; k += tK) { - short8 aData = __load_a_row_major_bf16_k16_m8_x16(A, m, k, K); - int8 bData = __load_b_vnni_bf16_k16(B, k, n, N); + short8 aData = load_a_rowmajor_d16_m8_k16_sg16(A, m, k, K); + int8 bData = load_b_rowmajor_d16_k16_nx(B, k, n, N); sum = mat_mul_x16(aData, bData, sum); } - __store_c_row_major_fp32_m8(C, sum, m, n, N); + store_c_rowmajor_fp32_m8_nx(C, sum, m, n, N); } -#undef tN +// vnni kernels: -#ifdef cl_intel_subgroup_extended_block_read +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +kernel void bfloat16_dpas_vnni_m1_n16(global float* C, global ushort* A, global ushort* B, int K) +{ + const int tM = 1; + const int N = get_global_size(0); + const int m = 
get_group_id(1) * tM; + const int n = get_group_id(0) * tN; -// All of the block read kernels are SIMD16 kernels, tN == 16: -#define tN 16 + float sum = 0; + for (int k = 0; k < K; k += tK) { + short aData = load_a_rowmajor_d16_m1_k16_sg16(A, m, k, K); + int8 bData = load_b_vnni_d16_k16_nx(B, k, n, N); + sum = mat_mul_x16(aData, bData, sum); + } -// Note for 2D block reads: -// - the tile width and height is encoded into the function name. -// - base_address is the byte address. Must be 64B aligned. -// - width is the width of the entire matrix, in bytes. Must be >= 64B. Must be 4B aligned. -// - height is the height of the entire matrix, or equivalently the number of rows. -// - pitch is the number of bytes between rows of the entire matrix. Must be >= 64B. Must be a multiple of 8 bytes. -// - coord is the number of elements (x coord) and row (y coord) to read from. X coord must be multiple 4 for for 1B data and 2 for 2B data. - -// Built-in functions are: - -// #ifdef cl_intel_subgroup_extended_block_read -// ushort2 intel_subgroup_block_read_u8_m1k32v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// ushort4 intel_subgroup_block_read_u8_m2k32v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// ushort8 intel_subgroup_block_read_u8_m4k32v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// ushort16 intel_subgroup_block_read_u8_m8k32v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// ushort2 intel_subgroup_block_read_u16_m1k16v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// ushort4 intel_subgroup_block_read_u16_m2k16v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// ushort8 intel_subgroup_block_read_u16_m4k16v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// ushort16 intel_subgroup_block_read_u16_m8k16v2(__global void *base_address, int width, int 
height, int pitch, int2 coord); -// uint8 intel_subgroup_block_read_transform_u8_k32(__global void *base_address, int width, int height, int pitch, int2 coord); -// uint8 intel_subgroup_block_read_transform_u16_k16(__global void *base_address, int width, int height, int pitch, int2 coord); -// uint8 intel_subgroup_block_read_transpose_u32_k8(__global void *base_address, int width, int height, int pitch, int2 coord); -// ulong4 intel_subgroup_block_read_transpose_u64_k4(__global void *base_address, int width, int height, int pitch, int2 coord); -// #endif //defined(cl_intel_subgroup_extended_block_read) - - -// For intrinsics, the pattern is: -// - prefix: __builtin_IB_subgroup_block_read_flat or __builtin_IB_subgroup_block_write_flat -// - operation (optional): _transpose or _transform -// - for no transpose or transform: -// - type / elements size: _u8 or _u16 or _u32 or _u64 -// - number of tile rows: _m32 or _m16 or _m8 or _m4 or _m2 or _m1 -// - tile width: _k64 or _k32 or _k16 or _k8 -// - number of tiles: _v2 or _v1 -// - for transpose: -// - type / element size: _u64 or _u32 -// - number of tile rows: subgroup size (16) -// - tile width: _k4 (for _u64) or _k8 (for _u32) -// - number of tiles: 1 -// - for transform: -// - type / element size: _u16 or _u8 -// - number of tile rows: _k32 (for _u8) or _k16 (for _u16) -// - tile width: subgroup size (16) -// - number of tiles: 1 - -// Define additional "non-vector" block read and writes. 
These are supported by the hardware but are not in the headers: - -ushort __builtin_IB_subgroup_block_read_flat_u16_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -ushort2 __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -ushort4 __builtin_IB_subgroup_block_read_flat_u16_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -ushort8 __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); - -uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); - -void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint data); -void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint2 data); -void __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint4 data); -void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint8 data); - -ushort intel_subgroup_block_read_u16_m1k16(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ - return __builtin_IB_subgroup_block_read_flat_u16_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); -} -ushort2 intel_subgroup_block_read_u16_m2k16(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ - return __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); -} 
-ushort4 intel_subgroup_block_read_u16_m4k16(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ - return __builtin_IB_subgroup_block_read_flat_u16_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); -} -ushort8 intel_subgroup_block_read_u16_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ - return __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); + store_c_rowmajor_fp32_m1_nx(C, sum, m, n, N); } -uint8 intel_subgroup_block_read_u32_m8k16(const __global void* base_address, int width, int height, int pitch, int2 coord) +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +kernel void bfloat16_dpas_vnni_m2_n16(global float* C, global ushort* A, global ushort* B, int K) { - return __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); -} + const int tM = 2; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; -void intel_subgroup_block_write_u32_m1k16v1(__global void* base_address, int width, int height, int pitch, int2 coord, uint data) -{ - __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); -} -void intel_subgroup_block_write_u32_m2k16v1(__global void* base_address, int width, int height, int pitch, int2 coord, uint2 data) -{ - __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); + float2 sum = 0; + for (int k = 0; k < K; k += tK) { + short2 aData = load_a_rowmajor_d16_m2_k16_sg16(A, m, k, K); + int8 bData = load_b_vnni_d16_k16_nx(B, k, n, N); + sum = mat_mul_x16(aData, bData, sum); + } + + store_c_rowmajor_fp32_m2_nx(C, sum, m, n, N); } -void intel_subgroup_block_write_u32_m4k16v1(__global void* base_address, int 
width, int height, int pitch, int2 coord, uint4 data) + +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +kernel void bfloat16_dpas_vnni_m4_n16(global float* C, global ushort* A, global ushort* B, int K) { - __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); + const int tM = 4; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float4 sum = 0; + for (int k = 0; k < K; k += tK) { + short4 aData = load_a_rowmajor_d16_m4_k16_sg16(A, m, k, K); + int8 bData = load_b_vnni_d16_k16_nx(B, k, n, N); + sum = mat_mul_x16(aData, bData, sum); + } + + store_c_rowmajor_fp32_m4_nx(C, sum, m, n, N); } -void intel_subgroup_block_write_u32_m8k16v1(__global void* base_address, int width, int height, int pitch, int2 coord, uint8 data) + +__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +kernel void bfloat16_dpas_vnni_m8_n16(global float* C, global ushort* A, global ushort* B, int K) { - __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); + const int tM = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + float8 sum = 0; + for (int k = 0; k < K; k += tK) { + short8 aData = load_a_rowmajor_d16_m8_k16_sg16(A, m, k, K); + int8 bData = load_b_vnni_d16_k16_nx(B, k, n, N); + sum = mat_mul_x16(aData, bData, sum); + } + + store_c_rowmajor_fp32_m8_nx(C, sum, m, n, N); } +#ifdef cl_intel_subgroup_extended_block_read + __attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) kernel void bfloat16_dpas_blockread_rowmajor_m1_n16(global float* C, global ushort* A, global ushort* B, int K) { @@ -1527,11 +1044,10 @@ kernel void bfloat16_dpas_blockread_vnni_m8_n16(global float* C, global ushort* 
intel_subgroup_block_write_u32_m8k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); } -#undef tN +#undef tN // for SIMD16 kernels #endif // cl_intel_subgroup_extended_block_read #endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size) -#undef OVLD #undef tK From feb106481446c37b528fc5cf5acd50ee93be2b15 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Fri, 12 Jan 2024 18:37:55 -0800 Subject: [PATCH 18/99] switch to timing using event profiling --- samples/99_matrixexperiments/main.cpp | 75 +++++++++++++++++++-------- 1 file changed, 53 insertions(+), 22 deletions(-) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index bf1bfb53..c71180d8 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -24,6 +24,7 @@ bool identityData = false; bool fixedData = false; bool validate = false; bool emulate = false; +bool wallclock = false; int testIterations = 16; float threshold = 0.01f; @@ -149,6 +150,13 @@ void check_results( } } +static float hw_time(cl::Event& event) +{ + auto ns = event.getProfilingInfo() - + event.getProfilingInfo(); + return ns / 1e9; +} + static void go_naive( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, @@ -167,12 +175,15 @@ static void go_naive( float best = 999.0f; for (int test = 0; test < testIterations; test++) { + cl::Event event; auto start = test_clock::now(); - queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M}); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M}, cl::NullRange, nullptr, &event); queue.finish(); auto end = test_clock::now(); - std::chrono::duration elapsed_seconds = end - start; - best = std::min(best, elapsed_seconds.count()); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? 
sw_time.count() : hw_time(event); + best = std::min(best, elapsed); } auto gops = 2.0 * M * N * K / best / 1e9; printf("Best in %f seconds (%f gops)\n", best, gops); @@ -207,12 +218,15 @@ static void go_dpas_rowmajor( float best = 999.0f; for (int test = 0; test < testIterations; test++) { + cl::Event event; auto start = test_clock::now(); - queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M/tM}); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M/tM}, cl::NullRange, nullptr, &event); queue.finish(); auto end = test_clock::now(); - std::chrono::duration elapsed_seconds = end - start; - best = std::min(best, elapsed_seconds.count()); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); } auto gops = 2.0 * M * N * K / best / 1e9; printf("Best in %f seconds (%f gops)\n", best, gops); @@ -252,12 +266,15 @@ static void go_dpas_rowmajor_tiled( float best = 999.0f; for (int test = 0; test < testIterations; test++) { + cl::Event event; auto start = test_clock::now(); - queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N/NN, M/tM/MM}); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, nullptr, &event); queue.finish(); auto end = test_clock::now(); - std::chrono::duration elapsed_seconds = end - start; - best = std::min(best, elapsed_seconds.count()); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? 
sw_time.count() : hw_time(event); + best = std::min(best, elapsed); } auto gops = 2.0 * M * N * K / best / 1e9; printf("Best in %f seconds (%f gops)\n", best, gops); @@ -297,12 +314,15 @@ static void go_dpas_vnni( float best = 999.0f; for (int test = 0; test < testIterations; test++) { + cl::Event event; auto start = test_clock::now(); - queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M/tM}); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M/tM}, cl::NullRange, nullptr, &event); queue.finish(); auto end = test_clock::now(); - std::chrono::duration elapsed_seconds = end - start; - best = std::min(best, elapsed_seconds.count()); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); } auto gops = 2.0 * M * N * K / best / 1e9; printf("Best in %f seconds (%f gops)\n", best, gops); @@ -344,12 +364,15 @@ static void go_dpas_vnni_tiled( float best = 999.0f; for (int test = 0; test < testIterations; test++) { + cl::Event event; auto start = test_clock::now(); - queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N/NN, M/tM/MM}); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, nullptr, &event); queue.finish(); auto end = test_clock::now(); - std::chrono::duration elapsed_seconds = end - start; - best = std::min(best, elapsed_seconds.count()); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? 
sw_time.count() : hw_time(event); + best = std::min(best, elapsed); } auto gops = 2.0 * M * N * K / best / 1e9; printf("Best in %f seconds (%f gops)\n", best, gops); @@ -387,12 +410,15 @@ static void go_dpas_blockread_rowmajor( float best = 999.0f; for (int test = 0; test < testIterations; test++) { + cl::Event event; auto start = test_clock::now(); - queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M/tM}); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M/tM}, cl::NullRange, nullptr, &event); queue.finish(); auto end = test_clock::now(); - std::chrono::duration elapsed_seconds = end - start; - best = std::min(best, elapsed_seconds.count()); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); } auto gops = 2.0 * M * N * K / best / 1e9; printf("Best in %f seconds (%f gops)\n", best, gops); @@ -430,12 +456,15 @@ static void go_dpas_blockread_vnni( float best = 999.0f; for (int test = 0; test < testIterations; test++) { + cl::Event event; auto start = test_clock::now(); - queue.enqueueNDRangeKernel(kernel, cl::NullRange, cl::NDRange{N, M/tM}); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M/tM}, cl::NullRange, nullptr, &event); queue.finish(); auto end = test_clock::now(); - std::chrono::duration elapsed_seconds = end - start; - best = std::min(best, elapsed_seconds.count()); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? 
sw_time.count() : hw_time(event); + best = std::min(best, elapsed); } auto gops = 2.0 * M * N * K / best / 1e9; printf("Best in %f seconds (%f gops)\n", best, gops); @@ -473,6 +502,7 @@ int main(int argc, char** argv) op.add("", "identity", "Use Identity Data", &identityData); op.add("", "fixed", "Use Fixed Data", &fixedData); op.add("", "emulate", "Unconditionally Emulate dpas", &emulate); + op.add("", "wallclock", "Measure Wallclock Time", &wallclock); op.add>("", "threshold", "Local Error Threshold", threshold, &threshold); bool printUsage = false; try { @@ -525,11 +555,12 @@ int main(int argc, char** argv) printf("\tTest Iterations: %d\n", testIterations); printf("\tValidating data?: %s\n", validate ? "true" : "false"); printf("\tFixed data?: %s\n", fixedData ? "true" : "false"); + printf("\tWallclock time?: %s\n", wallclock ? "true" : "false"); printf("\tEmulate dpas for tN=8?: %s\n", emulate_tN8 ? "true" : "false"); printf("\tEmulate dpas for tN=16?: %s\n", emulate_tN16 ? "true" : "false"); cl::Context context{device}; - cl::CommandQueue queue{context, device}; + cl::CommandQueue queue{context, device, CL_QUEUE_PROFILING_ENABLE}; printf("Reading program source from file: %s\n", fileName.c_str() ); std::string kernelString = readStringFromFile(fileName.c_str()); From 4e89026b01b922e0f7164f35aa2f11c13f2c6d6d Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Mon, 15 Jan 2024 12:10:30 -0800 Subject: [PATCH 19/99] more refactorization and simplification Now have tiled implementations for SIMD16 as well. 
--- samples/99_matrixexperiments/CMakeLists.txt | 2 +- samples/99_matrixexperiments/main.cpp | 118 +-- .../matrix_kernel_tiled.cl | 182 +++++ .../99_matrixexperiments/matrix_kernels.cl | 727 +++--------------- 4 files changed, 369 insertions(+), 660 deletions(-) create mode 100644 samples/99_matrixexperiments/matrix_kernel_tiled.cl diff --git a/samples/99_matrixexperiments/CMakeLists.txt b/samples/99_matrixexperiments/CMakeLists.txt index 456a8cb1..86599fbf 100644 --- a/samples/99_matrixexperiments/CMakeLists.txt +++ b/samples/99_matrixexperiments/CMakeLists.txt @@ -8,4 +8,4 @@ add_opencl_sample( TARGET matrixexperiments VERSION 120 SOURCES main.cpp - KERNELS matrix_helpers.cl matrix_kernels.cl) + KERNELS matrix_helpers.cl matrix_kernels.cl matrix_kernel_tiled.cl) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index c71180d8..c8e7e00c 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -154,7 +154,7 @@ static float hw_time(cl::Event& event) { auto ns = event.getProfilingInfo() - event.getProfilingInfo(); - return ns / 1e9; + return ns / 1e9f; } static void go_naive( @@ -166,34 +166,38 @@ static void go_naive( printf("%80s: ", makeTestName(__FUNCTION__, M, N, K).c_str()); fflush(stdout); cl::Kernel kernel{program, "bfloat16_naive"}; - kernel.setArg(0, C); - kernel.setArg(1, A); - kernel.setArg(2, B); - kernel.setArg(3, static_cast(K)); - - queue.enqueueFillBuffer(C, 0, 0, C_ref.size()); - - float best = 999.0f; - for (int test = 0; test < testIterations; test++) { - cl::Event event; - auto start = test_clock::now(); - queue.enqueueNDRangeKernel(kernel, cl::NullRange, - cl::NDRange{N, M}, cl::NullRange, nullptr, &event); - queue.finish(); - auto end = test_clock::now(); - std::chrono::duration sw_time = end - start; - auto elapsed = wallclock ? 
sw_time.count() : hw_time(event); - best = std::min(best, elapsed); - } - auto gops = 2.0 * M * N * K / best / 1e9; - printf("Best in %f seconds (%f gops)\n", best, gops); + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); - if (validate) { - printf("Checking results... "); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); - check_results(M, N, C_check, C_ref); - printf(" done!\n"); + /* Clear all of C before timing: the fill size is a byte count, so scale the element count by the element size (matches the read-back below). */ + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M}, cl::NullRange, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... 
"); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } } } @@ -210,7 +214,9 @@ static void go_dpas_rowmajor( kernelName += "_m" + std::to_string(tM); kernelName += "_n" + std::to_string(tN); cl::Kernel kernel{program, kernelName.c_str()}; - if (kernel()) { + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { kernel.setArg(0, C); kernel.setArg(1, A); kernel.setArg(2, B); @@ -238,8 +244,6 @@ static void go_dpas_rowmajor( check_results(M, N, C_check, C_ref); printf(" done!\n"); } - } else { - printf("unsupported.\n"); } } @@ -258,7 +262,13 @@ static void go_dpas_rowmajor_tiled( kernelName += "_" + std::to_string(MM); kernelName += "x" + std::to_string(NN); cl::Kernel kernel{program, kernelName.c_str()}; - if (kernel()) { + if (tM * MM > M) { + printf("M is too small.\n"); + } else if (tN * NN > N) { + printf("N is too small.\n"); + } else if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { kernel.setArg(0, C); kernel.setArg(1, A); kernel.setArg(2, B); @@ -286,8 +296,6 @@ static void go_dpas_rowmajor_tiled( check_results(M, N, C_check, C_ref); printf(" done!\n"); } - } else { - printf("unsupported.\n"); } } @@ -304,7 +312,9 @@ static void go_dpas_vnni( kernelName += "_m" + std::to_string(tM); kernelName += "_n" + std::to_string(tN); cl::Kernel kernel{program, kernelName.c_str()}; - if (kernel()) { + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { kernel.setArg(0, C); kernel.setArg(1, A); kernel.setArg(2, B); @@ -334,8 +344,6 @@ static void go_dpas_vnni( check_results(M, N, C_check, C_ref); printf(" done!\n"); } - } else { - printf("unsupported.\n"); } } @@ -354,7 +362,13 @@ static void go_dpas_vnni_tiled( kernelName += "_" + std::to_string(MM); kernelName += "x" + std::to_string(NN); cl::Kernel kernel{program, kernelName.c_str()}; - if (kernel()) { + if (kernel() 
== nullptr) { + printf("unsupported.\n"); + } else if (tM * MM > M) { + printf("M is too small.\n"); + } else if (tN * NN > N) { + printf("N is too small.\n"); + } else { kernel.setArg(0, C); kernel.setArg(1, A); kernel.setArg(2, B); @@ -384,8 +398,6 @@ static void go_dpas_vnni_tiled( check_results(M, N, C_check, C_ref); printf(" done!\n"); } - } else { - printf("unsupported.\n"); } } @@ -402,7 +414,9 @@ static void go_dpas_blockread_rowmajor( kernelName += "_m" + std::to_string(tM); kernelName += "_n" + std::to_string(tN); cl::Kernel kernel{program, kernelName.c_str()}; - if (kernel()) { + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { kernel.setArg(0, C); kernel.setArg(1, A); kernel.setArg(2, B); @@ -430,8 +444,6 @@ static void go_dpas_blockread_rowmajor( check_results(M, N, C_check, C_ref); printf(" done!\n"); } - } else { - printf("unsupported.\n"); } } @@ -448,7 +460,9 @@ static void go_dpas_blockread_vnni( kernelName += "_m" + std::to_string(tM); kernelName += "_n" + std::to_string(tN); cl::Kernel kernel{program, kernelName.c_str()}; - if (kernel()) { + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { kernel.setArg(0, C); kernel.setArg(1, A); kernel.setArg(2, B); @@ -476,8 +490,6 @@ static void go_dpas_blockread_vnni( check_results(M, N, C_check, C_ref); printf(" done!\n"); } - } else { - printf("unsupported.\n"); } } @@ -637,11 +649,25 @@ int main(int argc, char** argv) go_dpas_rowmajor<4, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); go_dpas_rowmajor<8, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_rowmajor_tiled<8, 16, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_rowmajor_tiled<8, 16, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_rowmajor_tiled<8, 16, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_rowmajor_tiled<8, 16, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_rowmajor_tiled<8, 16, 
16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_rowmajor_tiled<8, 16, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_vnni<1, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); go_dpas_vnni<2, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); go_dpas_vnni<4, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); go_dpas_vnni<8, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_vnni_tiled<8, 16, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_vnni_tiled<8, 16, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_vnni_tiled<8, 16, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_vnni_tiled<8, 16, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_vnni_tiled<8, 16, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_vnni_tiled<8, 16, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_blockread_rowmajor<1, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); go_dpas_blockread_rowmajor<2, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); go_dpas_blockread_rowmajor<4, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl new file mode 100644 index 00000000..9bec204a --- /dev/null +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -0,0 +1,182 @@ +#if !defined(tK) +#error "tK is undefined! This should be defined as the K dimension of the matrix tiles, which is dependent on the elemement type, likely 16 or 32." +#endif + +#if !defined(MM) +#error "MM is undefined! This should be defined as the number of matrix tiles in the M dimension." +#endif + +#if !defined(NN) +#error "NN is undefined! This should be defined as the number of matrix tiles in the N dimension." 
+#endif + +#define MM_KERNEL_NAMEX(PREFIX, tM, tN, MM, NN) PREFIX ## _m ## tM ## _n ## tN ## _ ## MM ## x ## NN +#define MM_KERNEL_NAME(PREFIX, tM, tN, MM, NN) MM_KERNEL_NAMEX(PREFIX, tM, tN, MM, NN) + +#if HAS_SIMD8 + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) +{ + const int tM = 8; + const int tN = 8; + const int N = get_global_size(0) * NN; + const int m = get_group_id(1) * tM * MM; + const int n = get_group_id(0) * tN * NN; + + float8 sum[MM][NN]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = 0; + } + } + + for (int k = 0; k < K; k += tK) { + int8 aData[MM]; + for (int mm = 0; mm < MM; mm++) { + aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K); + } + + int8 bData[NN]; + for (int nn = 0; nn < NN; nn++) { + bData[nn] = load_b_rowmajor_d16_k16_nx(B, k, n + nn * tN, N); + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = mat_mul_sg8(aData[mm], bData[nn], sum[mm][nn]); + } + } + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + } + } +} + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) +{ + const int tM = 8; + const int tN = 8; + const int N = get_global_size(0) * NN; + const int m = get_group_id(1) * tM * MM; + const int n = get_group_id(0) * tN * NN; + + float8 sum[MM][NN]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = 0; + } + } + + for (int k = 0; k < K; k += tK) { + int8 aData[MM]; + for (int mm = 0; mm < MM; mm++) { + aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + 
mm * tM, k, K); + } + + int8 bData[NN]; + for (int nn = 0; nn < NN; nn++) { + bData[nn] = load_b_vnni_d16_k16_nx(B, k, n + nn * tN, N); + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = mat_mul_sg8(aData[mm], bData[nn], sum[mm][nn]); + } + } + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + } + } +} + +#endif // HAS_SIMD8 + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) +{ + const int tM = 8; + const int tN = 16; + const int N = get_global_size(0) * NN; + const int m = get_group_id(1) * tM * MM; + const int n = get_group_id(0) * tN * NN; + + float8 sum[MM][NN]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = 0; + } + } + + for (int k = 0; k < K; k += tK) { + short8 aData[MM]; + for (int mm = 0; mm < MM; mm++) { + aData[mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k, K); + } + + int8 bData[NN]; + for (int nn = 0; nn < NN; nn++) { + bData[nn] = load_b_rowmajor_d16_k16_nx(B, k, n + nn * tN, N); + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]); + } + } + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + } + } +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) +{ + const int tM = 8; + const int tN = 16; + const int N = get_global_size(0) * NN; + const int m = get_group_id(1) * tM * MM; + const int n = get_group_id(0) * tN * NN; 
+ + float8 sum[MM][NN]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = 0; + } + } + + for (int k = 0; k < K; k += tK) { + short8 aData[MM]; + for (int mm = 0; mm < MM; mm++) { + aData[mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k, K); + } + + int8 bData[NN]; + for (int nn = 0; nn < NN; nn++) { + bData[nn] = load_b_vnni_d16_k16_nx(B, k, n + nn * tN, N); + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]); + } + } + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + } + } +} diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl index d2c1e7c4..03c2dcd8 100644 --- a/samples/99_matrixexperiments/matrix_kernels.cl +++ b/samples/99_matrixexperiments/matrix_kernels.cl @@ -1,15 +1,15 @@ #include "matrix_helpers.cl" #if EMULATE_tN8 -#define mat_mul_x8 emu_sub_group_bf16_bf16_matrix_mad_k16 +#define mat_mul_sg8 emu_sub_group_bf16_bf16_matrix_mad_k16 #else -#define mat_mul_x8 intel_sub_group_bf16_bf16_matrix_mad_k16 +#define mat_mul_sg8 intel_sub_group_bf16_bf16_matrix_mad_k16 #endif #if EMULATE_tN16 -#define mat_mul_x16 emu_sub_group_bf16_bf16_matrix_mad_k16 +#define mat_mul_sg16 emu_sub_group_bf16_bf16_matrix_mad_k16 #else -#define mat_mul_x16 intel_sub_group_bf16_bf16_matrix_mad_k16 +#define mat_mul_sg16 intel_sub_group_bf16_bf16_matrix_mad_k16 #endif kernel void bfloat16_naive(global float* C, global ushort* A, global ushort* B, int K) @@ -33,15 +33,13 @@ kernel void bfloat16_naive(global float* C, global ushort* A, global ushort* B, #if HAS_SIMD8 -// For all SIMD8 kernels tN == 8: -#define tN 8 - // rowmajor kernels: -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +__attribute__((intel_reqd_sub_group_size(8))) 
__attribute__((reqd_work_group_size(8, 1, 1))) kernel void bfloat16_dpas_rowmajor_m1_n8(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 1; + const int tN = 8; const int N = get_global_size(0); const int m = get_group_id(1) * tM; const int n = get_group_id(0) * tN; @@ -50,16 +48,17 @@ kernel void bfloat16_dpas_rowmajor_m1_n8(global float* C, global ushort* A, glob for (int k = 0; k < K; k += tK) { int aData = load_a_rowmajor_d16_m1_k16_sg8(A, m, k, K); int8 bData = load_b_rowmajor_d16_k16_nx(B, k, n, N); - sum = mat_mul_x8(aData, bData, sum); + sum = mat_mul_sg8(aData, bData, sum); } store_c_rowmajor_fp32_m1_nx(C, sum, m, n, N); } -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) kernel void bfloat16_dpas_rowmajor_m2_n8(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 2; + const int tN = 8; const int N = get_global_size(0); const int m = get_group_id(1) * tM; const int n = get_group_id(0) * tN; @@ -68,16 +67,17 @@ kernel void bfloat16_dpas_rowmajor_m2_n8(global float* C, global ushort* A, glob for (int k = 0; k < K; k += tK) { int2 aData = load_a_rowmajor_d16_m2_k16_sg8(A, m, k, K); int8 bData = load_b_rowmajor_d16_k16_nx(B, k, n, N); - sum = mat_mul_x8(aData, bData, sum); + sum = mat_mul_sg8(aData, bData, sum); } store_c_rowmajor_fp32_m2_nx(C, sum, m, n, N); } -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) kernel void bfloat16_dpas_rowmajor_m4_n8(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 4; + const int tN = 8; const int N = get_global_size(0); const int m = get_group_id(1) * tM; const int n = get_group_id(0) * tN; @@ -86,16 +86,17 @@ kernel void bfloat16_dpas_rowmajor_m4_n8(global float* C, global 
ushort* A, glob for (int k = 0; k < K; k += tK) { int4 aData = load_a_rowmajor_d16_m4_k16_sg8(A, m, k, K); int8 bData = load_b_rowmajor_d16_k16_nx(B, k, n, N); - sum = mat_mul_x8(aData, bData, sum); + sum = mat_mul_sg8(aData, bData, sum); } store_c_rowmajor_fp32_m4_nx(C, sum, m, n, N); } -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) kernel void bfloat16_dpas_rowmajor_m8_n8(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 8; + const int tN = 8; const int N = get_global_size(0); const int m = get_group_id(1) * tM; const int n = get_group_id(0) * tN; @@ -104,293 +105,19 @@ kernel void bfloat16_dpas_rowmajor_m8_n8(global float* C, global ushort* A, glob for (int k = 0; k < K; k += tK) { int8 aData = load_a_rowmajor_d16_m8_k16_sg8(A, m, k, K); int8 bData = load_b_rowmajor_d16_k16_nx(B, k, n, N); - sum = mat_mul_x8(aData, bData, sum); + sum = mat_mul_sg8(aData, bData, sum); } store_c_rowmajor_fp32_m8_nx(C, sum, m, n, N); } -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) -kernel void bfloat16_dpas_rowmajor_tiled_m8_n8_2x1(global float* C, global ushort* A, global ushort* B, int K) -{ - #define MM 2 - #define NN 1 - - const int tM = 8; - const int N = get_global_size(0) * NN; - const int m = get_group_id(1) * tM * MM; - const int n = get_group_id(0) * tN * NN; - - float8 sum[MM][NN]; - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = 0; - } - } - - for (int k = 0; k < K; k += tK) { - int8 aData[MM]; - for (int mm = 0; mm < MM; mm++) { - aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K); - } - - int8 bData[NN]; - for (int nn = 0; nn < NN; nn++) { - bData[nn] = load_b_rowmajor_d16_k16_nx(B, k, n + nn * tN, N); - } - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = 
mat_mul_x8(aData[mm], bData[nn], sum[mm][nn]); - } - } - } - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); - } - } - - #undef MM - #undef NN -} - -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) -kernel void bfloat16_dpas_rowmajor_tiled_m8_n8_1x2(global float* C, global ushort* A, global ushort* B, int K) -{ - #define MM 1 - #define NN 2 - - const int tM = 8; - const int N = get_global_size(0) * NN; - const int m = get_group_id(1) * tM * MM; - const int n = get_group_id(0) * tN * NN; - - float8 sum[MM][NN]; - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = 0; - } - } - - for (int k = 0; k < K; k += tK) { - int8 aData[MM]; - for (int mm = 0; mm < MM; mm++) { - aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K); - } - - int8 bData[NN]; - for (int nn = 0; nn < NN; nn++) { - bData[nn] = load_b_rowmajor_d16_k16_nx(B, k, n + nn * tN, N); - } - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = mat_mul_x8(aData[mm], bData[nn], sum[mm][nn]); - } - } - } - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); - } - } - #undef MM - #undef NN -} - -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) -kernel void bfloat16_dpas_rowmajor_tiled_m8_n8_2x2(global float* C, global ushort* A, global ushort* B, int K) -{ - #define MM 2 - #define NN 2 - - const int tM = 8; - const int N = get_global_size(0) * NN; - const int m = get_group_id(1) * tM * MM; - const int n = get_group_id(0) * tN * NN; - - float8 sum[MM][NN]; - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = 0; - } - } - - for (int k = 0; k < K; k += tK) { - int8 aData[MM]; - for (int mm = 0; mm < MM; mm++) { - 
aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K); - } - - int8 bData[NN]; - for (int nn = 0; nn < NN; nn++) { - bData[nn] = load_b_rowmajor_d16_k16_nx(B, k, n + nn * tN, N); - } - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = mat_mul_x8(aData[mm], bData[nn], sum[mm][nn]); - } - } - } - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); - } - } - - #undef MM - #undef NN -} - -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) -kernel void bfloat16_dpas_rowmajor_tiled_m8_n8_4x2(global float* C, global ushort* A, global ushort* B, int K) -{ - #define MM 4 - #define NN 2 - - const int tM = 8; - const int N = get_global_size(0) * NN; - const int m = get_group_id(1) * tM * MM; - const int n = get_group_id(0) * tN * NN; - - float8 sum[MM][NN]; - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = 0; - } - } - - for (int k = 0; k < K; k += tK) { - int8 aData[MM]; - for (int mm = 0; mm < MM; mm++) { - aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K); - } - - int8 bData[NN]; - for (int nn = 0; nn < NN; nn++) { - bData[nn] = load_b_rowmajor_d16_k16_nx(B, k, n + nn * tN, N); - } - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = mat_mul_x8(aData[mm], bData[nn], sum[mm][nn]); - } - } - } - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); - } - } - - #undef MM - #undef NN -} - -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) -kernel void bfloat16_dpas_rowmajor_tiled_m8_n8_2x4(global float* C, global ushort* A, global ushort* B, int K) -{ - #define MM 2 - #define NN 4 - - const int tM = 8; - const int N = get_global_size(0) * NN; - const int m = 
get_group_id(1) * tM * MM; - const int n = get_group_id(0) * tN * NN; - - float8 sum[MM][NN]; - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = 0; - } - } - - for (int k = 0; k < K; k += tK) { - int8 aData[MM]; - for (int mm = 0; mm < MM; mm++) { - aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K); - } - - int8 bData[NN]; - for (int nn = 0; nn < NN; nn++) { - bData[nn] = load_b_rowmajor_d16_k16_nx(B, k, n + nn * tN, N); - } - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = mat_mul_x8(aData[mm], bData[nn], sum[mm][nn]); - } - } - } - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); - } - } - - #undef MM - #undef NN -} - -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) -kernel void bfloat16_dpas_rowmajor_tiled_m8_n8_4x4(global float* C, global ushort* A, global ushort* B, int K) -{ - #define MM 4 - #define NN 4 - - const int tM = 8; - const int N = get_global_size(0) * NN; - const int m = get_group_id(1) * tM * MM; - const int n = get_group_id(0) * tN * NN; - - float8 sum[MM][NN]; - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = 0; - } - } - - for (int k = 0; k < K; k += tK) { - int8 aData[MM]; - for (int mm = 0; mm < MM; mm++) { - aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K); - } - - int8 bData[NN]; - for (int nn = 0; nn < NN; nn++) { - bData[nn] = load_b_rowmajor_d16_k16_nx(B, k, n + nn * tN, N); - } - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = mat_mul_x8(aData[mm], bData[nn], sum[mm][nn]); - } - } - } - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); - } - } - - #undef MM - #undef NN -} - // vnni kernels: 
-__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) kernel void bfloat16_dpas_vnni_m1_n8(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 1; + const int tN = 8; const int N = get_global_size(0); const int m = get_group_id(1) * tM; const int n = get_group_id(0) * tN; @@ -399,16 +126,17 @@ kernel void bfloat16_dpas_vnni_m1_n8(global float* C, global ushort* A, global u for (int k = 0; k < K; k += tK) { int aData = load_a_rowmajor_d16_m1_k16_sg8(A, m, k, K); int8 bData = load_b_vnni_d16_k16_nx(B, k, n, N); - sum = mat_mul_x8(aData, bData, sum); + sum = mat_mul_sg8(aData, bData, sum); } store_c_rowmajor_fp32_m1_nx(C, sum, m, n, N); } -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) kernel void bfloat16_dpas_vnni_m2_n8(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 2; + const int tN = 8; const int N = get_global_size(0); const int m = get_group_id(1) * tM; const int n = get_group_id(0) * tN; @@ -417,16 +145,17 @@ kernel void bfloat16_dpas_vnni_m2_n8(global float* C, global ushort* A, global u for (int k = 0; k < K; k += tK) { int2 aData = load_a_rowmajor_d16_m2_k16_sg8(A, m, k, K); int8 bData = load_b_vnni_d16_k16_nx(B, k, n, N); - sum = mat_mul_x8(aData, bData, sum); + sum = mat_mul_sg8(aData, bData, sum); } store_c_rowmajor_fp32_m2_nx(C, sum, m, n, N); } -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) kernel void bfloat16_dpas_vnni_m4_n8(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 4; + const int tN = 8; const int N = get_global_size(0); const int m = get_group_id(1) 
* tM; const int n = get_group_id(0) * tN; @@ -435,16 +164,17 @@ kernel void bfloat16_dpas_vnni_m4_n8(global float* C, global ushort* A, global u for (int k = 0; k < K; k += tK) { int4 aData = load_a_rowmajor_d16_m4_k16_sg8(A, m, k, K); int8 bData = load_b_vnni_d16_k16_nx(B, k, n, N); - sum = mat_mul_x8(aData, bData, sum); + sum = mat_mul_sg8(aData, bData, sum); } store_c_rowmajor_fp32_m4_nx(C, sum, m, n, N); } -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) kernel void bfloat16_dpas_vnni_m8_n8(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 8; + const int tN = 8; const int N = get_global_size(0); const int m = get_group_id(1) * tM; const int n = get_group_id(0) * tN; @@ -453,301 +183,21 @@ kernel void bfloat16_dpas_vnni_m8_n8(global float* C, global ushort* A, global u for (int k = 0; k < K; k += tK) { int8 aData = load_a_rowmajor_d16_m8_k16_sg8(A, m, k, K); int8 bData = load_b_vnni_d16_k16_nx(B, k, n, N); - sum = mat_mul_x8(aData, bData, sum); + sum = mat_mul_sg8(aData, bData, sum); } store_c_rowmajor_fp32_m8_nx(C, sum, m, n, N); } -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) -kernel void bfloat16_dpas_vnni_tiled_m8_n8_2x1(global float* C, global ushort* A, global ushort* B, int K) -{ - #define MM 2 - #define NN 1 - - const int tM = 8; - const int N = get_global_size(0) * NN; - const int m = get_group_id(1) * tM * MM; - const int n = get_group_id(0) * tN * NN; - - float8 sum[MM][NN]; - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = 0; - } - } - - for (int k = 0; k < K; k += tK) { - int8 aData[MM]; - for (int mm = 0; mm < MM; mm++) { - aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K); - } - - int8 bData[NN]; - for (int nn = 0; nn < NN; nn++) { - bData[nn] = load_b_vnni_d16_k16_nx(B, k, n + nn 
* tN, N); - } - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = mat_mul_x8(aData[mm], bData[nn], sum[mm][nn]); - } - } - } - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); - } - } - - #undef MM - #undef NN -} - -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) -kernel void bfloat16_dpas_vnni_tiled_m8_n8_1x2(global float* C, global ushort* A, global ushort* B, int K) -{ - #define MM 1 - #define NN 2 - - const int tM = 8; - const int N = get_global_size(0) * NN; - const int m = get_group_id(1) * tM * MM; - const int n = get_group_id(0) * tN * NN; - - float8 sum[MM][NN]; - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = 0; - } - } - - for (int k = 0; k < K; k += tK) { - int8 aData[MM]; - for (int mm = 0; mm < MM; mm++) { - aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K); - } - - int8 bData[NN]; - for (int nn = 0; nn < NN; nn++) { - bData[nn] = load_b_vnni_d16_k16_nx(B, k, n + nn * tN, N); - } - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = mat_mul_x8(aData[mm], bData[nn], sum[mm][nn]); - } - } - } - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); - } - } - - #undef MM - #undef NN -} - -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) -kernel void bfloat16_dpas_vnni_tiled_m8_n8_2x2(global float* C, global ushort* A, global ushort* B, int K) -{ - #define MM 2 - #define NN 2 - - const int tM = 8; - const int N = get_global_size(0) * NN; - const int m = get_group_id(1) * tM * MM; - const int n = get_group_id(0) * tN * NN; - - float8 sum[MM][NN]; - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = 0; - } - } - - for 
(int k = 0; k < K; k += tK) { - int8 aData[MM]; - for (int mm = 0; mm < MM; mm++) { - aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K); - } - - int8 bData[NN]; - for (int nn = 0; nn < NN; nn++) { - bData[nn] = load_b_vnni_d16_k16_nx(B, k, n + nn * tN, N); - } - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = mat_mul_x8(aData[mm], bData[nn], sum[mm][nn]); - } - } - } - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); - } - } - - #undef MM - #undef NN -} - -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) -kernel void bfloat16_dpas_vnni_tiled_m8_n8_4x2(global float* C, global ushort* A, global ushort* B, int K) -{ - #define MM 4 - #define NN 2 - - const int tM = 8; - const int N = get_global_size(0) * NN; - const int m = get_group_id(1) * tM * MM; - const int n = get_group_id(0) * tN * NN; - - float8 sum[MM][NN]; - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = 0; - } - } - - for (int k = 0; k < K; k += tK) { - int8 aData[MM]; - for (int mm = 0; mm < MM; mm++) { - aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K); - } - - int8 bData[NN]; - for (int nn = 0; nn < NN; nn++) { - bData[nn] = load_b_vnni_d16_k16_nx(B, k, n + nn * tN, N); - } - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = mat_mul_x8(aData[mm], bData[nn], sum[mm][nn]); - } - } - } - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); - } - } - - #undef MM - #undef NN -} - -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) -kernel void bfloat16_dpas_vnni_tiled_m8_n8_2x4(global float* C, global ushort* A, global ushort* B, int K) -{ - #define MM 2 - #define NN 4 - - const int tM 
= 8; - const int N = get_global_size(0) * NN; - const int m = get_group_id(1) * tM * MM; - const int n = get_group_id(0) * tN * NN; - - float8 sum[MM][NN]; - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = 0; - } - } - - for (int k = 0; k < K; k += tK) { - int8 aData[MM]; - for (int mm = 0; mm < MM; mm++) { - aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K); - } - - int8 bData[NN]; - for (int nn = 0; nn < NN; nn++) { - bData[nn] = load_b_vnni_d16_k16_nx(B, k, n + nn * tN, N); - } - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = mat_mul_x8(aData[mm], bData[nn], sum[mm][nn]); - } - } - } - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); - } - } - - #undef MM - #undef NN -} - -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) -kernel void bfloat16_dpas_vnni_tiled_m8_n8_4x4(global float* C, global ushort* A, global ushort* B, int K) -{ - #define MM 4 - #define NN 4 - - const int tM = 8; - const int N = get_global_size(0) * NN; - const int m = get_group_id(1) * tM * MM; - const int n = get_group_id(0) * tN * NN; - - float8 sum[MM][NN]; - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = 0; - } - } - - for (int k = 0; k < K; k += tK) { - int8 aData[MM]; - for (int mm = 0; mm < MM; mm++) { - aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K); - } - - int8 bData[NN]; - for (int nn = 0; nn < NN; nn++) { - bData[nn] = load_b_vnni_d16_k16_nx(B, k, n + nn * tN, N); - } - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = mat_mul_x8(aData[mm], bData[nn], sum[mm][nn]); - } - } - } - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); - } - } - - #undef MM - #undef 
NN -} - -#undef tN // for SIMD8 kernels - #endif // HAS_SIMD8 -// For all SIMD16 kernels tN == 16: -#define tN 16 - // rowmajor krenels: -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_rowmajor_m1_n16(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 1; + const int tN = 16; const int N = get_global_size(0); const int m = get_group_id(1) * tM; const int n = get_group_id(0) * get_local_size(0); @@ -756,16 +206,17 @@ kernel void bfloat16_dpas_rowmajor_m1_n16(global float* C, global ushort* A, glo for (int k = 0; k < K; k += tK) { short aData = load_a_rowmajor_d16_m1_k16_sg16(A, m, k, K); int8 bData = load_b_rowmajor_d16_k16_nx(B, k, n, N); - sum = mat_mul_x16(aData, bData, sum); + sum = mat_mul_sg16(aData, bData, sum); } store_c_rowmajor_fp32_m1_nx(C, sum, m, n, N); } -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_rowmajor_m2_n16(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 2; + const int tN = 16; const int N = get_global_size(0); const int m = get_group_id(1) * tK; const int n = get_group_id(0) * get_local_size(0); @@ -774,16 +225,17 @@ kernel void bfloat16_dpas_rowmajor_m2_n16(global float* C, global ushort* A, glo for (int k = 0; k < K; k += tK) { short2 aData = load_a_rowmajor_d16_m2_k16_sg16(A, m, k, K); int8 bData = load_b_rowmajor_d16_k16_nx(B, k, n, N); - sum = mat_mul_x16(aData, bData, sum); + sum = mat_mul_sg16(aData, bData, sum); } store_c_rowmajor_fp32_m2_nx(C, sum, m, n, N); } -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 
1))) kernel void bfloat16_dpas_rowmajor_m4_n16(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 4; + const int tN = 16; const int N = get_global_size(0); const int m = get_group_id(1) * tM; const int n = get_group_id(0) * get_local_size(0); @@ -792,16 +244,17 @@ kernel void bfloat16_dpas_rowmajor_m4_n16(global float* C, global ushort* A, glo for (int k = 0; k < K; k += tK) { short4 aData = load_a_rowmajor_d16_m4_k16_sg16(A, m, k, K); int8 bData = load_b_rowmajor_d16_k16_nx(B, k, n, N); - sum = mat_mul_x16(aData, bData, sum); + sum = mat_mul_sg16(aData, bData, sum); } store_c_rowmajor_fp32_m4_nx(C, sum, m, n, N); } -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_rowmajor_m8_n16(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 8; + const int tN = 16; const int N = get_global_size(0); const int m = get_group_id(1) * tM; const int n = get_group_id(0) * get_local_size(0); @@ -810,7 +263,7 @@ kernel void bfloat16_dpas_rowmajor_m8_n16(global float* C, global ushort* A, glo for (int k = 0; k < K; k += tK) { short8 aData = load_a_rowmajor_d16_m8_k16_sg16(A, m, k, K); int8 bData = load_b_rowmajor_d16_k16_nx(B, k, n, N); - sum = mat_mul_x16(aData, bData, sum); + sum = mat_mul_sg16(aData, bData, sum); } store_c_rowmajor_fp32_m8_nx(C, sum, m, n, N); @@ -818,10 +271,11 @@ kernel void bfloat16_dpas_rowmajor_m8_n16(global float* C, global ushort* A, glo // vnni kernels: -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_vnni_m1_n16(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 1; + const int tN = 16; const int N = get_global_size(0); const int m = get_group_id(1) * 
tM; const int n = get_group_id(0) * tN; @@ -830,16 +284,17 @@ kernel void bfloat16_dpas_vnni_m1_n16(global float* C, global ushort* A, global for (int k = 0; k < K; k += tK) { short aData = load_a_rowmajor_d16_m1_k16_sg16(A, m, k, K); int8 bData = load_b_vnni_d16_k16_nx(B, k, n, N); - sum = mat_mul_x16(aData, bData, sum); + sum = mat_mul_sg16(aData, bData, sum); } store_c_rowmajor_fp32_m1_nx(C, sum, m, n, N); } -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_vnni_m2_n16(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 2; + const int tN = 16; const int N = get_global_size(0); const int m = get_group_id(1) * tM; const int n = get_group_id(0) * tN; @@ -848,16 +303,17 @@ kernel void bfloat16_dpas_vnni_m2_n16(global float* C, global ushort* A, global for (int k = 0; k < K; k += tK) { short2 aData = load_a_rowmajor_d16_m2_k16_sg16(A, m, k, K); int8 bData = load_b_vnni_d16_k16_nx(B, k, n, N); - sum = mat_mul_x16(aData, bData, sum); + sum = mat_mul_sg16(aData, bData, sum); } store_c_rowmajor_fp32_m2_nx(C, sum, m, n, N); } -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_vnni_m4_n16(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 4; + const int tN = 16; const int N = get_global_size(0); const int m = get_group_id(1) * tM; const int n = get_group_id(0) * tN; @@ -866,16 +322,17 @@ kernel void bfloat16_dpas_vnni_m4_n16(global float* C, global ushort* A, global for (int k = 0; k < K; k += tK) { short4 aData = load_a_rowmajor_d16_m4_k16_sg16(A, m, k, K); int8 bData = load_b_vnni_d16_k16_nx(B, k, n, N); - sum = mat_mul_x16(aData, bData, sum); + sum = mat_mul_sg16(aData, bData, sum); } 
store_c_rowmajor_fp32_m4_nx(C, sum, m, n, N); } -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_vnni_m8_n16(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 8; + const int tN = 16; const int N = get_global_size(0); const int m = get_group_id(1) * tM; const int n = get_group_id(0) * tN; @@ -884,7 +341,7 @@ kernel void bfloat16_dpas_vnni_m8_n16(global float* C, global ushort* A, global for (int k = 0; k < K; k += tK) { short8 aData = load_a_rowmajor_d16_m8_k16_sg16(A, m, k, K); int8 bData = load_b_vnni_d16_k16_nx(B, k, n, N); - sum = mat_mul_x16(aData, bData, sum); + sum = mat_mul_sg16(aData, bData, sum); } store_c_rowmajor_fp32_m8_nx(C, sum, m, n, N); @@ -892,10 +349,11 @@ kernel void bfloat16_dpas_vnni_m8_n16(global float* C, global ushort* A, global #ifdef cl_intel_subgroup_extended_block_read -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_blockread_rowmajor_m1_n16(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 1; + const int tN = 16; const int M = get_global_size(1); const int N = get_global_size(0); const int m = get_group_id(1) * tM; @@ -905,16 +363,17 @@ kernel void bfloat16_dpas_blockread_rowmajor_m1_n16(global float* C, global usho for (int k = 0; k < K; k += tK) { short aData = as_short(intel_subgroup_block_read_u16_m1k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); - sum = mat_mul_x16(aData, bData, sum); + sum = mat_mul_sg16(aData, bData, sum); } intel_subgroup_block_write_u32_m1k16v1(C, N * sizeof(float), M, N * 
sizeof(float), (int2)(n, m), as_uint(sum)); } -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_blockread_rowmajor_m2_n16(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 2; + const int tN = 16; const int M = get_global_size(1) * tM; const int N = get_global_size(0); const int m = get_group_id(1) * tM; @@ -924,16 +383,17 @@ kernel void bfloat16_dpas_blockread_rowmajor_m2_n16(global float* C, global usho for (int k = 0; k < K; k += tK) { short2 aData = as_short2(intel_subgroup_block_read_u16_m2k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); - sum = mat_mul_x16(aData, bData, sum); + sum = mat_mul_sg16(aData, bData, sum); } intel_subgroup_block_write_u32_m2k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum)); } -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_blockread_rowmajor_m4_n16(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 4; + const int tN = 16; const int M = get_global_size(1) * tM; const int N = get_global_size(0); const int m = get_group_id(1) * tM; @@ -943,16 +403,17 @@ kernel void bfloat16_dpas_blockread_rowmajor_m4_n16(global float* C, global usho for (int k = 0; k < K; k += tK) { short4 aData = as_short4(intel_subgroup_block_read_u16_m4k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); - sum = mat_mul_x16(aData, bData, sum); + sum = mat_mul_sg16(aData, 
bData, sum); } intel_subgroup_block_write_u32_m4k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum)); } -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_blockread_rowmajor_m8_n16(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 8; + const int tN = 16; const int M = get_global_size(1) * tM; const int N = get_global_size(0); const int m = get_group_id(1) * tM; @@ -962,16 +423,17 @@ kernel void bfloat16_dpas_blockread_rowmajor_m8_n16(global float* C, global usho for (int k = 0; k < K; k += tK) { short8 aData = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); - sum = mat_mul_x16(aData, bData, sum); + sum = mat_mul_sg16(aData, bData, sum); } intel_subgroup_block_write_u32_m8k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); } -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_blockread_vnni_m1_n16(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 1; + const int tN = 16; const int M = get_global_size(1) * tM; const int N = get_global_size(0); const int m = get_group_id(1) * tM; @@ -981,16 +443,17 @@ kernel void bfloat16_dpas_blockread_vnni_m1_n16(global float* C, global ushort* for (int k = 0; k < K; k += tK) { short aData = as_short(intel_subgroup_block_read_u16_m1k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); - 
sum = mat_mul_x16(aData, bData, sum); + sum = mat_mul_sg16(aData, bData, sum); } intel_subgroup_block_write_u32_m1k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum)); } -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_blockread_vnni_m2_n16(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 2; + const int tN = 16; const int M = get_global_size(1) * tM; const int N = get_global_size(0); const int m = get_group_id(1) * tM; @@ -1000,16 +463,17 @@ kernel void bfloat16_dpas_blockread_vnni_m2_n16(global float* C, global ushort* for (int k = 0; k < K; k += tK) { short2 aData = as_short2(intel_subgroup_block_read_u16_m2k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); - sum = mat_mul_x16(aData, bData, sum); + sum = mat_mul_sg16(aData, bData, sum); } intel_subgroup_block_write_u32_m2k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum)); } -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_blockread_vnni_m4_n16(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 4; + const int tN = 16; const int M = get_global_size(1) * tM; const int N = get_global_size(0); const int m = get_group_id(1) * tM; @@ -1019,16 +483,17 @@ kernel void bfloat16_dpas_blockread_vnni_m4_n16(global float* C, global ushort* for (int k = 0; k < K; k += tK) { short4 aData = as_short4(intel_subgroup_block_read_u16_m4k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * 
sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); - sum = mat_mul_x16(aData, bData, sum); + sum = mat_mul_sg16(aData, bData, sum); } intel_subgroup_block_write_u32_m4k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum)); } -__attribute__((intel_reqd_sub_group_size(tN))) __attribute__((reqd_work_group_size(tN, 1, 1))) +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_blockread_vnni_m8_n16(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 8; + const int tN = 16; const int M = get_global_size(1) * tM; const int N = get_global_size(0); const int m = get_group_id(1) * tM; @@ -1038,16 +503,52 @@ kernel void bfloat16_dpas_blockread_vnni_m8_n16(global float* C, global ushort* for (int k = 0; k < K; k += tK) { short8 aData = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); - sum = mat_mul_x16(aData, bData, sum); + sum = mat_mul_sg16(aData, bData, sum); } intel_subgroup_block_write_u32_m8k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); } -#undef tN // for SIMD16 kernels - #endif // cl_intel_subgroup_extended_block_read +// Tiled matrix multiplication kernels, generated from a template: + +#define MM 2 +#define NN 1 +#include "matrix_kernel_tiled.cl" +#undef MM +#undef NN + +#define MM 1 +#define NN 2 +#include "matrix_kernel_tiled.cl" +#undef MM +#undef NN + +#define MM 2 +#define NN 2 +#include "matrix_kernel_tiled.cl" +#undef MM +#undef NN + +#define MM 4 +#define NN 2 +#include "matrix_kernel_tiled.cl" +#undef MM +#undef NN + +#define MM 2 +#define NN 4 +#include "matrix_kernel_tiled.cl" +#undef MM +#undef NN + +#define MM 4 +#define NN 4 +#include "matrix_kernel_tiled.cl" +#undef MM +#undef NN + #endif // defined(cl_intel_subgroups) && 
defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size) #undef tK From c7edcd666265cc26d669389c211c1775b853ed29 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Mon, 15 Jan 2024 13:44:06 -0800 Subject: [PATCH 20/99] add tiled block read kernels for PVC --- .../matrix_kernel_tiled.cl | 86 +++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 9bec204a..cd04b299 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -180,3 +180,89 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float } } } + +#ifdef cl_intel_subgroup_extended_block_read + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) +{ + const int tM = 8; + const int tN = 16; + const int N = get_global_size(0) * NN; + const int m = get_group_id(1) * tM * MM; + const int n = get_group_id(0) * tN * NN; + + float8 sum[MM][NN]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = 0; + } + } + + for (int k = 0; k < K; k += tK) { + short8 aData[MM]; + for (int mm = 0; mm < MM; mm++) { + aData[mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); + } + + int8 bData[NN]; + for (int nn = 0; nn < NN; nn++) { + bData[nn] = load_b_rowmajor_d16_k16_nx(B, k, n + nn * tN, N); + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]); + } + } + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + intel_subgroup_block_write_u32_m8k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m 
+ mm * tM), as_uint8(sum[mm][nn])); + } + } +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) +{ + const int tM = 8; + const int tN = 16; + const int N = get_global_size(0) * NN; + const int m = get_group_id(1) * tM * MM; + const int n = get_group_id(0) * tN * NN; + + float8 sum[MM][NN]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = 0; + } + } + + for (int k = 0; k < K; k += tK) { + short8 aData[MM]; + for (int mm = 0; mm < MM; mm++) { + aData[mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); + } + + int8 bData[NN]; + for (int nn = 0; nn < NN; nn++) { + bData[nn] = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k))); + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]); + } + } + } + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + intel_subgroup_block_write_u32_m8k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[mm][nn])); + } + } +} + +#endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size) From a43376945037923f06e4e1e6df33929da636abb5 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Mon, 15 Jan 2024 14:03:26 -0800 Subject: [PATCH 21/99] fix block read tiled kernels and execute them --- samples/99_matrixexperiments/main.cpp | 136 +++++++++++++++++- .../matrix_kernel_tiled.cl | 12 +- 2 files changed, 140 insertions(+), 8 deletions(-) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index c8e7e00c..42b3cbef 100644 --- 
a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -222,6 +222,8 @@ static void go_dpas_rowmajor( kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); + queue.enqueueFillBuffer(C, 0, 0, C_ref.size()); + float best = 999.0f; for (int test = 0; test < testIterations; test++) { cl::Event event; @@ -262,18 +264,20 @@ static void go_dpas_rowmajor_tiled( kernelName += "_" + std::to_string(MM); kernelName += "x" + std::to_string(NN); cl::Kernel kernel{program, kernelName.c_str()}; - if (tM * MM > M) { + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else if (tM * MM > M) { printf("M is too small.\n"); } else if (tN * NN > N) { printf("N is too small.\n"); - } else if (kernel() == nullptr) { - printf("unsupported.\n"); } else { kernel.setArg(0, C); kernel.setArg(1, A); kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); + queue.enqueueFillBuffer(C, 0, 0, C_ref.size()); + float best = 999.0f; for (int test = 0; test < testIterations; test++) { cl::Event event; @@ -422,6 +426,8 @@ static void go_dpas_blockread_rowmajor( kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); + queue.enqueueFillBuffer(C, 0, 0, C_ref.size()); + float best = 999.0f; for (int test = 0; test < testIterations; test++) { cl::Event event; @@ -447,6 +453,60 @@ static void go_dpas_blockread_rowmajor( } } +template +static void go_dpas_blockread_rowmajor_tiled( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, MM, NN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "bfloat16_dpas_blockread_rowmajor_tiled"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + kernelName += "_" + std::to_string(MM); + kernelName += "x" + std::to_string(NN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { 
+ printf("unsupported.\n"); + } else if (tM * MM > M) { + printf("M is too small.\n"); + } else if (tN * NN > N) { + printf("N is too small.\n"); + } else { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + queue.enqueueFillBuffer(C, 0, 0, C_ref.size()); + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + template static void go_dpas_blockread_vnni( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, @@ -468,6 +528,8 @@ static void go_dpas_blockread_vnni( kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); + queue.enqueueFillBuffer(C, 0, 0, C_ref.size()); + float best = 999.0f; for (int test = 0; test < testIterations; test++) { cl::Event event; @@ -493,6 +555,60 @@ static void go_dpas_blockread_vnni( } } +template +static void go_dpas_blockread_vnni_tiled( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, MM, NN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "bfloat16_dpas_blockread_vnni_tiled"; + kernelName += "_m" + 
std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + kernelName += "_" + std::to_string(MM); + kernelName += "x" + std::to_string(NN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else if (tM * MM > M) { + printf("M is too small.\n"); + } else if (tN * NN > N) { + printf("N is too small.\n"); + } else { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + queue.enqueueFillBuffer(C, 0, 0, C_ref.size()); + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... 
"); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + int main(int argc, char** argv) { int platformIndex = 0; @@ -673,11 +789,25 @@ int main(int argc, char** argv) go_dpas_blockread_rowmajor<4, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); go_dpas_blockread_rowmajor<8, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_blockread_rowmajor_tiled<8, 16, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_blockread_rowmajor_tiled<8, 16, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_blockread_rowmajor_tiled<8, 16, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_blockread_rowmajor_tiled<8, 16, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_blockread_rowmajor_tiled<8, 16, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_blockread_rowmajor_tiled<8, 16, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + go_dpas_blockread_vnni<1, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); go_dpas_blockread_vnni<2, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); go_dpas_blockread_vnni<4, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); go_dpas_blockread_vnni<8, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_blockread_vnni_tiled<8, 16, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_blockread_vnni_tiled<8, 16, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_blockread_vnni_tiled<8, 16, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_blockread_vnni_tiled<8, 16, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + go_dpas_blockread_vnni_tiled<8, 16, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, 
C_ref); + go_dpas_blockread_vnni_tiled<8, 16, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + printf("Done.\n"); return 0; diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index cd04b299..b363b142 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -184,10 +184,11 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float #ifdef cl_intel_subgroup_extended_block_read __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) -kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) +kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 8; const int tN = 16; + const int M = get_global_size(1) * tM; const int N = get_global_size(0) * NN; const int m = get_group_id(1) * tM * MM; const int n = get_group_id(0) * tN * NN; @@ -207,7 +208,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f int8 bData[NN]; for (int nn = 0; nn < NN; nn++) { - bData[nn] = load_b_rowmajor_d16_k16_nx(B, k, n + nn * tN, N); + bData[nn] = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k))); } for (int mm = 0; mm < MM; mm++) { @@ -225,10 +226,11 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) -kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) +kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) { const int 
tM = 8; const int tN = 16; + const int M = get_global_size(1) * tM; const int N = get_global_size(0) * NN; const int m = get_group_id(1) * tM * MM; const int n = get_group_id(0) * tN * NN; @@ -248,7 +250,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float int8 bData[NN]; for (int nn = 0; nn < NN; nn++) { - bData[nn] = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k))); + bData[nn] = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, k / 2))); } for (int mm = 0; mm < MM; mm++) { @@ -265,4 +267,4 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float } } -#endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size) +#endif // cl_intel_subgroup_extended_block_read From d4eb405b631ca43aca6a42c3036766359ecb784c Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Tue, 16 Jan 2024 21:43:27 -0800 Subject: [PATCH 22/99] fix typo affecting one of the SIMD16 kernels --- samples/99_matrixexperiments/matrix_kernels.cl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl index 03c2dcd8..a310db63 100644 --- a/samples/99_matrixexperiments/matrix_kernels.cl +++ b/samples/99_matrixexperiments/matrix_kernels.cl @@ -218,7 +218,7 @@ kernel void bfloat16_dpas_rowmajor_m2_n16(global float* C, global ushort* A, glo const int tM = 2; const int tN = 16; const int N = get_global_size(0); - const int m = get_group_id(1) * tK; + const int m = get_group_id(1) * tM; const int n = get_group_id(0) * get_local_size(0); float2 sum = 0; From b6be2d4611e39cf18d31a177cf4e9f0e43ee59b7 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Wed, 17 Jan 2024 10:30:40 -0800 Subject: [PATCH 23/99] fix a few more bugs and improve validation testing --- 
samples/99_matrixexperiments/main.cpp | 18 +++++++++--------- samples/99_matrixexperiments/matrix_helpers.cl | 8 ++++---- .../matrix_kernel_tiled.cl | 8 ++++---- samples/99_matrixexperiments/matrix_kernels.cl | 16 ++++++++-------- 4 files changed, 25 insertions(+), 25 deletions(-) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index 42b3cbef..044739e3 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -174,7 +174,7 @@ static void go_naive( kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); - queue.enqueueFillBuffer(C, 0, 0, C_ref.size()); + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); float best = 999.0f; for (int test = 0; test < testIterations; test++) { @@ -222,7 +222,7 @@ static void go_dpas_rowmajor( kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); - queue.enqueueFillBuffer(C, 0, 0, C_ref.size()); + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); float best = 999.0f; for (int test = 0; test < testIterations; test++) { @@ -276,7 +276,7 @@ static void go_dpas_rowmajor_tiled( kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); - queue.enqueueFillBuffer(C, 0, 0, C_ref.size()); + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); float best = 999.0f; for (int test = 0; test < testIterations; test++) { @@ -324,7 +324,7 @@ static void go_dpas_vnni( kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); - queue.enqueueFillBuffer(C, 0, 0, C_ref.size()); + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); float best = 999.0f; for (int test = 0; test < testIterations; test++) { @@ -378,7 +378,7 @@ static void go_dpas_vnni_tiled( kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); - queue.enqueueFillBuffer(C, 0, 0, C_ref.size()); + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); float best = 999.0f; for (int test = 0; test < testIterations; test++) { @@ -426,7 +426,7 @@ 
static void go_dpas_blockread_rowmajor( kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); - queue.enqueueFillBuffer(C, 0, 0, C_ref.size()); + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); float best = 999.0f; for (int test = 0; test < testIterations; test++) { @@ -480,7 +480,7 @@ static void go_dpas_blockread_rowmajor_tiled( kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); - queue.enqueueFillBuffer(C, 0, 0, C_ref.size()); + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); float best = 999.0f; for (int test = 0; test < testIterations; test++) { @@ -528,7 +528,7 @@ static void go_dpas_blockread_vnni( kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); - queue.enqueueFillBuffer(C, 0, 0, C_ref.size()); + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); float best = 999.0f; for (int test = 0; test < testIterations; test++) { @@ -582,7 +582,7 @@ static void go_dpas_blockread_vnni_tiled( kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); - queue.enqueueFillBuffer(C, 0, 0, C_ref.size()); + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); float best = 999.0f; for (int test = 0; test < testIterations; test++) { diff --git a/samples/99_matrixexperiments/matrix_helpers.cl b/samples/99_matrixexperiments/matrix_helpers.cl index ef68fa44..04a63c53 100644 --- a/samples/99_matrixexperiments/matrix_helpers.cl +++ b/samples/99_matrixexperiments/matrix_helpers.cl @@ -464,19 +464,19 @@ uint8 intel_subgroup_block_read_u32_m8k16(const __global void* base_address, int return __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } -void intel_subgroup_block_write_u32_m1k16v1(__global void* base_address, int width, int height, int pitch, int2 coord, uint data) +void intel_subgroup_block_write_u32_m1k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint data) { 
__builtin_IB_subgroup_block_write_flat_u32_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); } -void intel_subgroup_block_write_u32_m2k16v1(__global void* base_address, int width, int height, int pitch, int2 coord, uint2 data) +void intel_subgroup_block_write_u32_m2k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint2 data) { __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); } -void intel_subgroup_block_write_u32_m4k16v1(__global void* base_address, int width, int height, int pitch, int2 coord, uint4 data) +void intel_subgroup_block_write_u32_m4k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint4 data) { __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); } -void intel_subgroup_block_write_u32_m8k16v1(__global void* base_address, int width, int height, int pitch, int2 coord, uint8 data) +void intel_subgroup_block_write_u32_m8k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint8 data) { __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); } diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index b363b142..a01e0f2e 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -188,7 +188,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN { const int tM = 8; const int tN = 16; - const int M = get_global_size(1) * tM; + const int M = get_global_size(1) * tM * MM; const int N = get_global_size(0) * NN; const int m = get_group_id(1) * tM * MM; const int n = get_group_id(0) * tN * NN; @@ -220,7 +220,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, 
MM, NN for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - intel_subgroup_block_write_u32_m8k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[mm][nn])); + intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[mm][nn])); } } } @@ -230,7 +230,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl { const int tM = 8; const int tN = 16; - const int M = get_global_size(1) * tM; + const int M = get_global_size(1) * tM * MM; const int N = get_global_size(0) * NN; const int m = get_group_id(1) * tM * MM; const int n = get_group_id(0) * tN * NN; @@ -262,7 +262,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - intel_subgroup_block_write_u32_m8k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[mm][nn])); + intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[mm][nn])); } } } diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl index a310db63..efc8f4fa 100644 --- a/samples/99_matrixexperiments/matrix_kernels.cl +++ b/samples/99_matrixexperiments/matrix_kernels.cl @@ -366,7 +366,7 @@ kernel void bfloat16_dpas_blockread_rowmajor_m1_n16(global float* C, global usho sum = mat_mul_sg16(aData, bData, sum); } - intel_subgroup_block_write_u32_m1k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum)); + intel_subgroup_block_write_u32_m1k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum)); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -386,7 +386,7 @@ kernel void bfloat16_dpas_blockread_rowmajor_m2_n16(global float* C, global usho sum = 
mat_mul_sg16(aData, bData, sum); } - intel_subgroup_block_write_u32_m2k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum)); + intel_subgroup_block_write_u32_m2k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum)); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -406,7 +406,7 @@ kernel void bfloat16_dpas_blockread_rowmajor_m4_n16(global float* C, global usho sum = mat_mul_sg16(aData, bData, sum); } - intel_subgroup_block_write_u32_m4k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum)); + intel_subgroup_block_write_u32_m4k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum)); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -426,7 +426,7 @@ kernel void bfloat16_dpas_blockread_rowmajor_m8_n16(global float* C, global usho sum = mat_mul_sg16(aData, bData, sum); } - intel_subgroup_block_write_u32_m8k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); + intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -446,7 +446,7 @@ kernel void bfloat16_dpas_blockread_vnni_m1_n16(global float* C, global ushort* sum = mat_mul_sg16(aData, bData, sum); } - intel_subgroup_block_write_u32_m1k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum)); + intel_subgroup_block_write_u32_m1k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum)); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -466,7 +466,7 @@ kernel void bfloat16_dpas_blockread_vnni_m2_n16(global float* C, global ushort* sum = mat_mul_sg16(aData, bData, sum); } - intel_subgroup_block_write_u32_m2k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), 
as_uint2(sum)); + intel_subgroup_block_write_u32_m2k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum)); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -486,7 +486,7 @@ kernel void bfloat16_dpas_blockread_vnni_m4_n16(global float* C, global ushort* sum = mat_mul_sg16(aData, bData, sum); } - intel_subgroup_block_write_u32_m4k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum)); + intel_subgroup_block_write_u32_m4k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum)); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -506,7 +506,7 @@ kernel void bfloat16_dpas_blockread_vnni_m8_n16(global float* C, global ushort* sum = mat_mul_sg16(aData, bData, sum); } - intel_subgroup_block_write_u32_m8k16v1(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); + intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); } #endif // cl_intel_subgroup_extended_block_read From d76df7e82f8a2cee9878340891b4d0886a4ec9f9 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Wed, 17 Jan 2024 10:58:36 -0800 Subject: [PATCH 24/99] add support for a larger A matrix block read --- .../99_matrixexperiments/matrix_helpers.cl | 21 +++++++++------- .../matrix_kernel_tiled.cl | 24 +++++++++++++++---- 2 files changed, 33 insertions(+), 12 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_helpers.cl b/samples/99_matrixexperiments/matrix_helpers.cl index 04a63c53..5e24f682 100644 --- a/samples/99_matrixexperiments/matrix_helpers.cl +++ b/samples/99_matrixexperiments/matrix_helpers.cl @@ -430,10 +430,11 @@ void store_c_rowmajor_fp32_m8_nx(global float* C, float8 v, int rowStart, int co // Define additional "non-vector" block read and writes. 
These are supported by the hardware but are not in the headers: -ushort __builtin_IB_subgroup_block_read_flat_u16_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -ushort2 __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -ushort4 __builtin_IB_subgroup_block_read_flat_u16_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -ushort8 __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort __builtin_IB_subgroup_block_read_flat_u16_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort2 __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort4 __builtin_IB_subgroup_block_read_flat_u16_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort8 __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort16 __builtin_IB_subgroup_block_read_flat_u16_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); @@ -442,22 +443,26 @@ void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int wid void __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint4 data); void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 
coord, uint8 data); -ushort intel_subgroup_block_read_u16_m1k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +ushort intel_subgroup_block_read_u16_m1k16(const __global void *base_address, int width, int height, int pitch, int2 coord) { return __builtin_IB_subgroup_block_read_flat_u16_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } -ushort2 intel_subgroup_block_read_u16_m2k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +ushort2 intel_subgroup_block_read_u16_m2k16(const __global void *base_address, int width, int height, int pitch, int2 coord) { return __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } -ushort4 intel_subgroup_block_read_u16_m4k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +ushort4 intel_subgroup_block_read_u16_m4k16(const __global void *base_address, int width, int height, int pitch, int2 coord) { return __builtin_IB_subgroup_block_read_flat_u16_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } -ushort8 intel_subgroup_block_read_u16_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +ushort8 intel_subgroup_block_read_u16_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord) { return __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } +ushort16 intel_subgroup_block_read_u16_m16k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + return __builtin_IB_subgroup_block_read_flat_u16_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); +} uint8 intel_subgroup_block_read_u32_m8k16(const __global void* base_address, int width, int height, int pitch, int2 coord) { diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl 
b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index a01e0f2e..a3b56402 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -202,8 +202,16 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN for (int k = 0; k < K; k += tK) { short8 aData[MM]; - for (int mm = 0; mm < MM; mm++) { - aData[mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); + if (MM % 2 == 0) { + for (int mm = 0; mm < MM; mm += 2) { + short16 aTemp = as_short16(intel_subgroup_block_read_u16_m16k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); + aData[mm + 0] = aTemp.lo; + aData[mm + 1] = aTemp.hi; + } + } else { + for (int mm = 0; mm < MM; mm++) { + aData[mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); + } } int8 bData[NN]; @@ -244,8 +252,16 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl for (int k = 0; k < K; k += tK) { short8 aData[MM]; - for (int mm = 0; mm < MM; mm++) { - aData[mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); + if (MM % 2 == 0) { + for (int mm = 0; mm < MM; mm += 2) { + short16 aTemp = as_short16(intel_subgroup_block_read_u16_m16k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); + aData[mm + 0] = aTemp.lo; + aData[mm + 1] = aTemp.hi; + } + } else { + for (int mm = 0; mm < MM; mm++) { + aData[mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); + } } int8 bData[NN]; From 756d2e9decaf63edff9ed23044d5ea27d60878ba Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Wed, 17 Jan 2024 14:35:03 -0800 Subject: [PATCH 25/99] switch the tiled dpas order We want to prioritize reuse of the A matrix to make best 
use of read suppression buffers. --- .../matrix_kernel_tiled.cl | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index a3b56402..87fa6588 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -49,8 +49,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl } } - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); } } @@ -83,8 +83,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* bData[nn] = load_b_vnni_d16_k16_nx(B, k, n + nn * tN, N); } - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { sum[mm][nn] = mat_mul_sg8(aData[mm], bData[nn], sum[mm][nn]); } } @@ -126,8 +126,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f bData[nn] = load_b_rowmajor_d16_k16_nx(B, k, n + nn * tN, N); } - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]); } } @@ -167,8 +167,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float bData[nn] = load_b_vnni_d16_k16_nx(B, k, n + nn * tN, N); } - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]); } } @@ -219,8 +219,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN bData[nn] = 
as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k))); } - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]); } } @@ -269,8 +269,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl bData[nn] = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, k / 2))); } - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]); } } From 4caea7b465aaa856eb9b52c891d81b096e79c76c Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Thu, 18 Jan 2024 18:52:02 -0800 Subject: [PATCH 26/99] temporarily disable the large a matrix block load This is not working (silently failing) with some recent drivers, so disable it for now. Ideally we will be able to reenable it shortly. 
--- .../matrix_kernel_tiled.cl | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 87fa6588..58aba827 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -202,17 +202,17 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN for (int k = 0; k < K; k += tK) { short8 aData[MM]; - if (MM % 2 == 0) { - for (int mm = 0; mm < MM; mm += 2) { - short16 aTemp = as_short16(intel_subgroup_block_read_u16_m16k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); - aData[mm + 0] = aTemp.lo; - aData[mm + 1] = aTemp.hi; - } - } else { + //if (MM % 2 == 0) { + // for (int mm = 0; mm < MM; mm += 2) { + // short16 aTemp = as_short16(intel_subgroup_block_read_u16_m16k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); + // aData[mm + 0] = aTemp.lo; + // aData[mm + 1] = aTemp.hi; + // } + //} else { for (int mm = 0; mm < MM; mm++) { aData[mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); } - } + //} int8 bData[NN]; for (int nn = 0; nn < NN; nn++) { @@ -252,17 +252,17 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl for (int k = 0; k < K; k += tK) { short8 aData[MM]; - if (MM % 2 == 0) { - for (int mm = 0; mm < MM; mm += 2) { - short16 aTemp = as_short16(intel_subgroup_block_read_u16_m16k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); - aData[mm + 0] = aTemp.lo; - aData[mm + 1] = aTemp.hi; - } - } else { + //if (MM % 2 == 0) { + // for (int mm = 0; mm < MM; mm += 2) { + // short16 aTemp = as_short16(intel_subgroup_block_read_u16_m16k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); + // aData[mm + 0] = aTemp.lo; + // aData[mm + 1] = 
aTemp.hi; + // } + //} else { for (int mm = 0; mm < MM; mm++) { aData[mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); } - } + //} int8 bData[NN]; for (int nn = 0; nn < NN; nn++) { From 0fb3d6671f01f497cbbda946b190e72b6f7a84f8 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Thu, 18 Jan 2024 19:07:15 -0800 Subject: [PATCH 27/99] add support for launching more than one subgroup per work group This should enable better cache reuse across subgroups. --- .../99_matrixexperiments/matrix_helpers.cl | 7 +++++ .../matrix_kernel_tiled.cl | 29 +++++++++++-------- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_helpers.cl b/samples/99_matrixexperiments/matrix_helpers.cl index 5e24f682..ce36217b 100644 --- a/samples/99_matrixexperiments/matrix_helpers.cl +++ b/samples/99_matrixexperiments/matrix_helpers.cl @@ -9,6 +9,13 @@ float bf16_to_fp32(ushort u) #if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) +inline int compute_m(const int num_sgs, const int tM, const int MM) +{ + const int m_start = get_group_id(1) * num_sgs; + const int m_index = num_sgs > 1 ? m_start + get_sub_group_id() : m_start; + return m_index * tM * MM; +} + // Emulated SIMD8 dpas: __attribute__((overloadable)) float emu_sub_group_bf16_bf16_matrix_mad_k16(int a, int8 b, float acc) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 58aba827..30462224 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -13,15 +13,20 @@ #define MM_KERNEL_NAMEX(PREFIX, tM, tN, MM, NN) PREFIX ## _m ## tM ## _n ## tN ## _ ## MM ## x ## NN #define MM_KERNEL_NAME(PREFIX, tM, tN, MM, NN) MM_KERNEL_NAMEX(PREFIX, tM, tN, MM, NN) +#if !defined(SGS_PER_WG) +// Launch four subgroups per work-group, to maximize cache reuse. 
+#define SGS_PER_WG 4 +#endif + #if HAS_SIMD8 -__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, SGS_PER_WG, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 8; const int tN = 8; const int N = get_global_size(0) * NN; - const int m = get_group_id(1) * tM * MM; + const int m = compute_m(SGS_PER_WG, tM, MM); const int n = get_group_id(0) * tN * NN; float8 sum[MM][NN]; @@ -56,13 +61,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl } } -__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, SGS_PER_WG, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 8; const int tN = 8; const int N = get_global_size(0) * NN; - const int m = get_group_id(1) * tM * MM; + const int m = compute_m(SGS_PER_WG, tM, MM); const int n = get_group_id(0) * tN * NN; float8 sum[MM][NN]; @@ -99,13 +104,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* #endif // HAS_SIMD8 -__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 8; const int tN = 16; const int N = get_global_size(0) * NN; - const int m = get_group_id(1) * tM * MM; + const int m = compute_m(SGS_PER_WG, tM, MM); const int n = get_group_id(0) * tN * NN; float8 sum[MM][NN]; @@ -140,13 +145,13 @@ kernel void 
MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f } } -__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 8; const int tN = 16; const int N = get_global_size(0) * NN; - const int m = get_group_id(1) * tM * MM; + const int m = compute_m(SGS_PER_WG, tM, MM); const int n = get_group_id(0) * tN * NN; float8 sum[MM][NN]; @@ -183,14 +188,14 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float #ifdef cl_intel_subgroup_extended_block_read -__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 8; const int tN = 16; const int M = get_global_size(1) * tM * MM; const int N = get_global_size(0) * NN; - const int m = get_group_id(1) * tM * MM; + const int m = compute_m(SGS_PER_WG, tM, MM); const int n = get_group_id(0) * tN * NN; float8 sum[MM][NN]; @@ -233,14 +238,14 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN } } -__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 8; const int tN = 16; const int M = get_global_size(1) * tM * MM; const int N = get_global_size(0) * NN; - const int m = get_group_id(1) * tM 
* MM; + const int m = compute_m(SGS_PER_WG, tM, MM); const int n = get_group_id(0) * tN * NN; float8 sum[MM][NN]; From d09b982cdc0a9dbf2ed34cf25e92a41c02e770fa Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Thu, 18 Jan 2024 21:15:46 -0800 Subject: [PATCH 28/99] add support for split barriers This may also be helpful to keep subgroups running approximately together, which could also improve cache utilization. --- .../matrix_kernel_tiled.cl | 53 +++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 30462224..cb3c6ca3 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -10,6 +10,17 @@ #error "NN is undefined! This should be defined as the number of matrix tiles in the N dimension." #endif +#if !defined(cl_intel_split_work_group_barrier) || defined(NO_SPLIT_BARRIERS) +#if !defined(cl_intel_split_work_group_barrier) +#warning "Unexpected: cl_intel_split_work_group_barrier is not supported?" 
+#endif +#define split_barrier_arrive() +#define split_barrier_wait() +#else +#define split_barrier_arrive() intel_work_group_barrier_arrive(0) +#define split_barrier_wait() intel_work_group_barrier_wait(0) +#endif + #define MM_KERNEL_NAMEX(PREFIX, tM, tN, MM, NN) PREFIX ## _m ## tM ## _n ## tN ## _ ## MM ## x ## NN #define MM_KERNEL_NAME(PREFIX, tM, tN, MM, NN) MM_KERNEL_NAMEX(PREFIX, tM, tN, MM, NN) @@ -36,6 +47,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl } } + split_barrier_arrive(); + for (int k = 0; k < K; k += tK) { int8 aData[MM]; for (int mm = 0; mm < MM; mm++) { @@ -52,8 +65,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl sum[mm][nn] = mat_mul_sg8(aData[mm], bData[nn], sum[mm][nn]); } } + + split_barrier_wait(); + split_barrier_arrive(); } + split_barrier_wait(); + for (int nn = 0; nn < NN; nn++) { for (int mm = 0; mm < MM; mm++) { store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); @@ -77,6 +95,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* } } + split_barrier_arrive(); + for (int k = 0; k < K; k += tK) { int8 aData[MM]; for (int mm = 0; mm < MM; mm++) { @@ -93,8 +113,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* sum[mm][nn] = mat_mul_sg8(aData[mm], bData[nn], sum[mm][nn]); } } + + split_barrier_wait(); + split_barrier_arrive(); } + split_barrier_wait(); + for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); @@ -120,6 +145,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f } } + split_barrier_arrive(); + for (int k = 0; k < K; k += tK) { short8 aData[MM]; for (int mm = 0; mm < MM; mm++) { @@ -136,8 +163,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]); } } 
+ + split_barrier_wait(); + split_barrier_arrive(); } + split_barrier_wait(); + for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); @@ -161,6 +193,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float } } + split_barrier_arrive(); + for (int k = 0; k < K; k += tK) { short8 aData[MM]; for (int mm = 0; mm < MM; mm++) { @@ -177,8 +211,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]); } } + + split_barrier_wait(); + split_barrier_arrive(); } + split_barrier_wait(); + for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); @@ -205,6 +244,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN } } + split_barrier_arrive(); + for (int k = 0; k < K; k += tK) { short8 aData[MM]; //if (MM % 2 == 0) { @@ -229,8 +270,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]); } } + + split_barrier_wait(); + split_barrier_arrive(); } + split_barrier_wait(); + for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[mm][nn])); @@ -255,6 +301,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl } } + split_barrier_arrive(); + for (int k = 0; k < K; k += tK) { short8 aData[MM]; //if (MM % 2 == 0) { @@ -279,8 +327,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]); } } + + split_barrier_wait(); + split_barrier_arrive(); } + split_barrier_wait(); + for (int mm = 0; mm < MM; mm++) { for (int nn = 0; 
nn < NN; nn++) { intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[mm][nn])); From 031e076143111aa632eca6dd0b6082502b9c6535 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Thu, 18 Jan 2024 22:18:43 -0800 Subject: [PATCH 29/99] add support for larger K values for some tiled kernels --- .../matrix_kernel_tiled.cl | 111 +++++++++++------- 1 file changed, 71 insertions(+), 40 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index cb3c6ca3..05c2ace8 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -10,6 +10,10 @@ #error "NN is undefined! This should be defined as the number of matrix tiles in the N dimension." #endif +#if !defined(KK) +#define KK 1 +#endif + #if !defined(cl_intel_split_work_group_barrier) || defined(NO_SPLIT_BARRIERS) #if !defined(cl_intel_split_work_group_barrier) #warning "Unexpected: cl_intel_split_work_group_barrier is not supported?" 
@@ -49,20 +53,26 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl split_barrier_arrive(); - for (int k = 0; k < K; k += tK) { - int8 aData[MM]; - for (int mm = 0; mm < MM; mm++) { - aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K); + for (int k = 0; k < K; k += tK * KK) { + int8 aData[KK][MM]; + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K); + } } - int8 bData[NN]; - for (int nn = 0; nn < NN; nn++) { - bData[nn] = load_b_rowmajor_d16_k16_nx(B, k, n + nn * tN, N); + int8 bData[KK][NN]; + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[kk][nn] = load_b_rowmajor_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); + } } - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = mat_mul_sg8(aData[mm], bData[nn], sum[mm][nn]); + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = mat_mul_sg8(aData[kk][mm], bData[kk][nn], sum[mm][nn]); + } } } @@ -97,20 +107,26 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* split_barrier_arrive(); - for (int k = 0; k < K; k += tK) { - int8 aData[MM]; - for (int mm = 0; mm < MM; mm++) { - aData[mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k, K); + for (int k = 0; k < K; k += tK * KK) { + int8 aData[KK][MM]; + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K); + } } - int8 bData[NN]; - for (int nn = 0; nn < NN; nn++) { - bData[nn] = load_b_vnni_d16_k16_nx(B, k, n + nn * tN, N); + int8 bData[KK][NN]; + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[kk][nn] = load_b_vnni_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); + } } - for (int nn = 0; nn < NN; nn++) { - for (int mm = 0; mm < MM; mm++) { - 
sum[mm][nn] = mat_mul_sg8(aData[mm], bData[nn], sum[mm][nn]); + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { + sum[mm][nn] = mat_mul_sg8(aData[kk][mm], bData[kk][nn], sum[mm][nn]); + } } } @@ -147,20 +163,26 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f split_barrier_arrive(); - for (int k = 0; k < K; k += tK) { - short8 aData[MM]; - for (int mm = 0; mm < MM; mm++) { - aData[mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k, K); + for (int k = 0; k < K; k += tK * KK) { + short8 aData[KK][MM]; + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k + kk * tK, K); + } } - int8 bData[NN]; - for (int nn = 0; nn < NN; nn++) { - bData[nn] = load_b_rowmajor_d16_k16_nx(B, k, n + nn * tN, N); + int8 bData[KK][NN]; + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[kk][nn] = load_b_rowmajor_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); + } } - for (int nn = 0; nn < NN; nn++) { - for (int mm = 0; mm < MM; mm++) { - sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]); + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { + sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[kk][nn], sum[mm][nn]); + } } } @@ -195,20 +217,26 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float split_barrier_arrive(); - for (int k = 0; k < K; k += tK) { - short8 aData[MM]; - for (int mm = 0; mm < MM; mm++) { - aData[mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k, K); + for (int k = 0; k < K; k += tK * KK) { + short8 aData[KK][MM]; + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k + kk * tK, K); + } } - int8 bData[NN]; - for (int nn = 0; nn < NN; nn++) { - bData[nn] = load_b_vnni_d16_k16_nx(B, k, n 
+ nn * tN, N); + int8 bData[KK][NN]; + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[kk][nn] = load_b_vnni_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); + } } - for (int nn = 0; nn < NN; nn++) { - for (int mm = 0; mm < MM; mm++) { - sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]); + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { + sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[kk][nn], sum[mm][nn]); + } } } @@ -342,3 +370,6 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl } #endif // cl_intel_subgroup_extended_block_read + +#undef KK +#undef SGS_PER_WG From 16b7cda99eb30331e1b70287b7173bff61e28c2e Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Thu, 18 Jan 2024 22:19:09 -0800 Subject: [PATCH 30/99] rename tester host functions to match kernel names more closely Also, remove tK from all host function output, since it is only used internally within the kernels. 
--- samples/99_matrixexperiments/main.cpp | 204 +++++++++++++------------- 1 file changed, 102 insertions(+), 102 deletions(-) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index 044739e3..a3bacda6 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -40,25 +40,25 @@ std::string makeTestName( std::string makeTestName( const std::string &func, - int tM, int tN, int tK, + int tM, int tN, size_t M, size_t N, size_t K) { std::ostringstream ret; ret << func; - ret << ""; + ret << ""; ret << " (M=" << M << ", N=" << N << ", K=" << K << ")"; return ret.str(); } std::string makeTestName( const std::string &func, - int tM, int tN, int tK, + int tM, int tN, int MM, int NN, size_t M, size_t N, size_t K) { std::ostringstream ret; ret << func; - ret << ""; + ret << ""; ret << " (M=" << M << ", N=" << N << ", K=" << K << ")"; return ret.str(); } @@ -157,7 +157,7 @@ static float hw_time(cl::Event& event) return ns / 1e9f; } -static void go_naive( +static void bfloat16_naive( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, @@ -201,14 +201,14 @@ static void go_naive( } } -template -static void go_dpas_rowmajor( +template +static void bfloat16_dpas_rowmajor( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, const std::vector& C_ref) { - printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, M, N, K).c_str()); fflush(stdout); + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); std::string kernelName = "bfloat16_dpas_rowmajor"; kernelName += "_m" + std::to_string(tM); @@ -249,14 +249,14 @@ static void go_dpas_rowmajor( } } -template -static void go_dpas_rowmajor_tiled( +template +static void bfloat16_dpas_rowmajor_tiled( cl::Context& context, cl::Program& program, 
cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, const std::vector& C_ref) { - printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, MM, NN, M, N, K).c_str()); fflush(stdout); + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout); std::string kernelName = "bfloat16_dpas_rowmajor_tiled"; kernelName += "_m" + std::to_string(tM); @@ -303,14 +303,14 @@ static void go_dpas_rowmajor_tiled( } } -template -static void go_dpas_vnni( +template +static void bfloat16_dpas_vnni( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, const std::vector& C_ref) { - printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, M, N, K).c_str()); fflush(stdout); + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); std::string kernelName = "bfloat16_dpas_vnni"; kernelName += "_m" + std::to_string(tM); @@ -351,14 +351,14 @@ static void go_dpas_vnni( } } -template -static void go_dpas_vnni_tiled( +template +static void bfloat16_dpas_vnni_tiled( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, const std::vector& C_ref) { - printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, MM, NN, M, N, K).c_str()); fflush(stdout); + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout); std::string kernelName = "bfloat16_dpas_vnni_tiled"; kernelName += "_m" + std::to_string(tM); @@ -405,14 +405,14 @@ static void go_dpas_vnni_tiled( } } -template -static void go_dpas_blockread_rowmajor( +template +static void bfloat16_dpas_blockread_rowmajor( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, const std::vector& C_ref) { - printf("%80s: ", makeTestName(__FUNCTION__, 
tM, tN, tK, M, N, K).c_str()); fflush(stdout); + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); std::string kernelName = "bfloat16_dpas_blockread_rowmajor"; kernelName += "_m" + std::to_string(tM); @@ -453,14 +453,14 @@ static void go_dpas_blockread_rowmajor( } } -template -static void go_dpas_blockread_rowmajor_tiled( +template +static void bfloat16_dpas_blockread_rowmajor_tiled( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, const std::vector& C_ref) { - printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, MM, NN, M, N, K).c_str()); fflush(stdout); + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout); std::string kernelName = "bfloat16_dpas_blockread_rowmajor_tiled"; kernelName += "_m" + std::to_string(tM); @@ -507,14 +507,14 @@ static void go_dpas_blockread_rowmajor_tiled( } } -template -static void go_dpas_blockread_vnni( +template +static void bfloat16_dpas_blockread_vnni( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, const std::vector& C_ref) { - printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, M, N, K).c_str()); fflush(stdout); + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); std::string kernelName = "bfloat16_dpas_blockread_vnni"; kernelName += "_m" + std::to_string(tM); @@ -555,14 +555,14 @@ static void go_dpas_blockread_vnni( } } -template -static void go_dpas_blockread_vnni_tiled( +template +static void bfloat16_dpas_blockread_vnni_tiled( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, size_t M, size_t N, size_t K, const std::vector& C_ref) { - printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, tK, MM, NN, M, N, K).c_str()); fflush(stdout); + printf("%80s: ", 
makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout); std::string kernelName = "bfloat16_dpas_blockread_vnni_tiled"; kernelName += "_m" + std::to_string(tM); @@ -734,79 +734,79 @@ int main(int argc, char** argv) printf("Running tests...\n"); - go_naive(context, program, queue, C, A, B, M, N, K, C_ref); - - go_dpas_rowmajor<1, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor<2, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor<4, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor<8, 8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - - go_dpas_rowmajor_tiled<8, 8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor_tiled<8, 8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor_tiled<8, 8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor_tiled<8, 8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor_tiled<8, 8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor_tiled<8, 8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); - - go_dpas_vnni<1, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni<2, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni<4, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni<8, 8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - - go_dpas_vnni_tiled<8, 8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni_tiled<8, 8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni_tiled<8, 8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni_tiled<8, 8, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni_tiled<8, 8, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - 
go_dpas_vnni_tiled<8, 8, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - - go_dpas_rowmajor<1, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor<2, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor<4, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor<8, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - - go_dpas_rowmajor_tiled<8, 16, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor_tiled<8, 16, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor_tiled<8, 16, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor_tiled<8, 16, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor_tiled<8, 16, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_rowmajor_tiled<8, 16, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); - - go_dpas_vnni<1, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni<2, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni<4, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni<8, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - - go_dpas_vnni_tiled<8, 16, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni_tiled<8, 16, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni_tiled<8, 16, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni_tiled<8, 16, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni_tiled<8, 16, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_vnni_tiled<8, 16, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - - go_dpas_blockread_rowmajor<1, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_blockread_rowmajor<2, 16, 
16>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_blockread_rowmajor<4, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_blockread_rowmajor<8, 16, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - - go_dpas_blockread_rowmajor_tiled<8, 16, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_blockread_rowmajor_tiled<8, 16, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_blockread_rowmajor_tiled<8, 16, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_blockread_rowmajor_tiled<8, 16, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_blockread_rowmajor_tiled<8, 16, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); - go_dpas_blockread_rowmajor_tiled<8, 16, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); - - go_dpas_blockread_vnni<1, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_blockread_vnni<2, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_blockread_vnni<4, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_blockread_vnni<8, 16, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - - go_dpas_blockread_vnni_tiled<8, 16, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_blockread_vnni_tiled<8, 16, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_blockread_vnni_tiled<8, 16, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_blockread_vnni_tiled<8, 16, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_blockread_vnni_tiled<8, 16, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - go_dpas_blockread_vnni_tiled<8, 16, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_naive(context, program, queue, C, A, B, M, N, K, C_ref); + + bfloat16_dpas_rowmajor<1, 8>(context, program, queue, C, A, B, M, N, K, C_ref); + 
bfloat16_dpas_rowmajor<2, 8>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor<4, 8>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor<8, 8>(context, program, queue, C, A, B, M, N, K, C_ref); + + bfloat16_dpas_rowmajor_tiled<8, 8, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 8, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 8, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 8, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 8, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 8, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + + bfloat16_dpas_vnni<1, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni<2, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni<4, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni<8, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + + bfloat16_dpas_vnni_tiled<8, 8, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 8, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 8, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 8, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 8, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 8, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + + bfloat16_dpas_rowmajor<1, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor<2, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor<4, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor<8, 16>(context, 
program, queue, C, A, B, M, N, K, C_ref); + + bfloat16_dpas_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + + bfloat16_dpas_vnni<1, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni<2, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni<4, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni<8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + + bfloat16_dpas_vnni_tiled<8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + + bfloat16_dpas_blockread_rowmajor<1, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor<2, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor<4, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor<8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, 
C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + + bfloat16_dpas_blockread_vnni<1, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni<2, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni<4, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni<8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + + bfloat16_dpas_blockread_vnni_tiled<8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni_tiled<8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni_tiled<8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni_tiled<8, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni_tiled<8, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni_tiled<8, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); printf("Done.\n"); From 83185fd35efd41b945b6c6b892e295981b0ef8e9 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Fri, 19 Jan 2024 10:04:02 -0800 Subject: [PATCH 31/99] add support for more K tiles for the blockread kernels --- .../matrix_kernel_tiled.cl | 64 +++++++++---------- 1 file changed, 30 insertions(+), 34 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 
05c2ace8..7874d7b2 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -274,28 +274,26 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN split_barrier_arrive(); - for (int k = 0; k < K; k += tK) { - short8 aData[MM]; - //if (MM % 2 == 0) { - // for (int mm = 0; mm < MM; mm += 2) { - // short16 aTemp = as_short16(intel_subgroup_block_read_u16_m16k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); - // aData[mm + 0] = aTemp.lo; - // aData[mm + 1] = aTemp.hi; - // } - //} else { + for (int k = 0; k < K; k += tK * KK) { + short8 aData[KK][MM]; + for (int kk = 0; kk < KK; kk++) { for (int mm = 0; mm < MM; mm++) { - aData[mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); + aData[kk][mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); } - //} + } - int8 bData[NN]; - for (int nn = 0; nn < NN; nn++) { - bData[nn] = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k))); + int8 bData[KK][NN]; + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[kk][nn] = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK))); + } } - for (int nn = 0; nn < NN; nn++) { - for (int mm = 0; mm < MM; mm++) { - sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]); + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { + sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[kk][nn], sum[mm][nn]); + } } } @@ -331,28 +329,26 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl split_barrier_arrive(); - for (int k = 0; k < K; k += tK) { - short8 aData[MM]; - //if 
(MM % 2 == 0) { - // for (int mm = 0; mm < MM; mm += 2) { - // short16 aTemp = as_short16(intel_subgroup_block_read_u16_m16k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); - // aData[mm + 0] = aTemp.lo; - // aData[mm + 1] = aTemp.hi; - // } - //} else { + for (int k = 0; k < K; k += tK * KK) { + short8 aData[KK][MM]; + for (int kk = 0; kk < KK; kk++) { for (int mm = 0; mm < MM; mm++) { - aData[mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m + mm * tM))); + aData[kk][mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); } - //} + } - int8 bData[NN]; - for (int nn = 0; nn < NN; nn++) { - bData[nn] = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, k / 2))); + int8 bData[KK][NN]; + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[kk][nn] = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2))); + } } - for (int nn = 0; nn < NN; nn++) { - for (int mm = 0; mm < MM; mm++) { - sum[mm][nn] = mat_mul_sg16(aData[mm], bData[nn], sum[mm][nn]); + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { + sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[kk][nn], sum[mm][nn]); + } } } From a24a5b06b7ebc1aa81880b1d82adfbb60212da01 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Mon, 22 Jan 2024 14:09:47 -0800 Subject: [PATCH 32/99] start to add support for loading two K tiles at once --- .../99_matrixexperiments/matrix_helpers.cl | 50 +++++++++++++++++-- 1 file changed, 46 insertions(+), 4 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_helpers.cl b/samples/99_matrixexperiments/matrix_helpers.cl index ce36217b..6398e12b 100644 --- a/samples/99_matrixexperiments/matrix_helpers.cl +++ 
b/samples/99_matrixexperiments/matrix_helpers.cl @@ -216,8 +216,30 @@ int8 load_a_rowmajor_d16_m8_k16_sg8(global ushort* A, int rowStart, int colStart return ret; } +// M rows x K columns x V tiles (in the K dimension) +// This is the SIMD8 version, where each work-item loads two values. +// The first tile is returned the first components of the return value, the the next tile, etc. +int16 load_a_rowmajor_d16_m8_k16v2_sg8(global ushort* A, int rowStart, int colStart, int stride) +{ + int16 ret; + + global uint* A_ui = (global uint*)A; + int offset_ui = rowStart * stride / 2 + colStart / 2; + + ret.s08 = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s19 = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s2a = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s3b = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s4c = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s5d = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s6e = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + ret.s7f = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; + + return ret; +} + // M rows x K columns -// This is the SIMD16 version, where each work-item loads one values. +// This is the SIMD16 version, where each work-item loads one value. short load_a_rowmajor_d16_m1_k16_sg16(global ushort* A, int rowStart, int colStart, int stride) { ushort ret; @@ -229,7 +251,7 @@ short load_a_rowmajor_d16_m1_k16_sg16(global ushort* A, int rowStart, int colSta } // M rows x K columns -// This is the SIMD16 version, where each work-item loads one values. +// This is the SIMD16 version, where each work-item loads one value. 
short2 load_a_rowmajor_d16_m2_k16_sg16(global ushort* A, int rowStart, int colStart, int stride) { ushort2 ret; @@ -242,7 +264,7 @@ short2 load_a_rowmajor_d16_m2_k16_sg16(global ushort* A, int rowStart, int colSt } // M rows x K columns -// This is the SIMD16 version, where each work-item loads one values. +// This is the SIMD16 version, where each work-item loads one value. short4 load_a_rowmajor_d16_m4_k16_sg16(global ushort* A, int rowStart, int colStart, int stride) { ushort4 ret; @@ -257,7 +279,7 @@ short4 load_a_rowmajor_d16_m4_k16_sg16(global ushort* A, int rowStart, int colSt } // M rows x K columns -// This is the SIMD16 version, where each work-item loads one values. +// This is the SIMD16 version, where each work-item loads one value. short8 load_a_rowmajor_d16_m8_k16_sg16(global ushort* A, int rowStart, int colStart, int stride) { ushort8 ret; @@ -275,6 +297,26 @@ short8 load_a_rowmajor_d16_m8_k16_sg16(global ushort* A, int rowStart, int colSt return as_short8(ret); } +// M rows x K columns x V tiles (in the K dimension) +// This is the SIMD16 version, where each work-item loads one value. +// The first tile is returned the first components of the return value, the the next tile, etc. 
+short16 load_a_rowmajor_d16_m8_k16v2_sg16(global ushort* A, int rowStart, int colStart, int stride) +{ + short16 ret; + + int offset = rowStart * stride + colStart; + ret.s08 = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; + ret.s19 = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; + ret.s2a = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; + ret.s3b = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; + ret.s4c = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; + ret.s5d = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; + ret.s6e = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; + ret.s7f = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; + + return ret; +} + // K rows x N columns: // Each work-item loads K values and converts to VNNI. // Stride is in units of elements. From 38d03c00e2c9a6ae20ba763a8697e7ce84a2d162 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Mon, 22 Jan 2024 14:13:49 -0800 Subject: [PATCH 33/99] fix type for emulated v2 block reads --- samples/99_matrixexperiments/matrix_helpers.cl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_helpers.cl b/samples/99_matrixexperiments/matrix_helpers.cl index 6398e12b..0719fdba 100644 --- a/samples/99_matrixexperiments/matrix_helpers.cl +++ b/samples/99_matrixexperiments/matrix_helpers.cl @@ -221,7 +221,7 @@ int8 load_a_rowmajor_d16_m8_k16_sg8(global ushort* A, int rowStart, int colStart // The first tile is returned the first components of the return value, the the next tile, etc. 
int16 load_a_rowmajor_d16_m8_k16v2_sg8(global ushort* A, int rowStart, int colStart, int stride) { - int16 ret; + uint16 ret; global uint* A_ui = (global uint*)A; int offset_ui = rowStart * stride / 2 + colStart / 2; @@ -235,7 +235,7 @@ int16 load_a_rowmajor_d16_m8_k16v2_sg8(global ushort* A, int rowStart, int colSt ret.s6e = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; ret.s7f = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; - return ret; + return as_int16(ret); } // M rows x K columns @@ -302,7 +302,7 @@ short8 load_a_rowmajor_d16_m8_k16_sg16(global ushort* A, int rowStart, int colSt // The first tile is returned the first components of the return value, the the next tile, etc. short16 load_a_rowmajor_d16_m8_k16v2_sg16(global ushort* A, int rowStart, int colStart, int stride) { - short16 ret; + ushort16 ret; int offset = rowStart * stride + colStart; ret.s08 = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; @@ -314,7 +314,7 @@ short16 load_a_rowmajor_d16_m8_k16v2_sg16(global ushort* A, int rowStart, int co ret.s6e = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; ret.s7f = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; - return ret; + return as_short16(ret); } // K rows x N columns: From ab84bbe442b7b721263dcd714bfa3ac56c9885c8 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Tue, 23 Jan 2024 15:14:29 -0800 Subject: [PATCH 34/99] performance improvements and bugfixes for DG2 - Fixed an error in the stride computation. - Added changes to improve stateless-to-stateful compilation. - Added a wider A matrix block read to load two K tiles at a time. 
--- samples/99_matrixexperiments/main.cpp | 6 ++ .../99_matrixexperiments/matrix_helpers.cl | 50 ++++++------- .../matrix_kernel_tiled.cl | 72 ++++++++++++++----- .../99_matrixexperiments/matrix_kernels.cl | 6 ++ 4 files changed, 94 insertions(+), 40 deletions(-) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index a3bacda6..a38cf27f 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -741,6 +741,7 @@ int main(int argc, char** argv) bfloat16_dpas_rowmajor<4, 8>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_rowmajor<8, 8>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 8, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_rowmajor_tiled<8, 8, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_rowmajor_tiled<8, 8, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_rowmajor_tiled<8, 8, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); @@ -753,6 +754,7 @@ int main(int argc, char** argv) bfloat16_dpas_vnni<4, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_vnni<8, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 8, 1, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_vnni_tiled<8, 8, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_vnni_tiled<8, 8, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_vnni_tiled<8, 8, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); @@ -765,6 +767,7 @@ int main(int argc, char** argv) bfloat16_dpas_rowmajor<4, 16>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_rowmajor<8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_rowmajor_tiled<8, 16, 2, 1>(context, 
program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); @@ -777,6 +780,7 @@ int main(int argc, char** argv) bfloat16_dpas_vnni<4, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_vnni<8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 16, 1, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_vnni_tiled<8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_vnni_tiled<8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_vnni_tiled<8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); @@ -789,6 +793,7 @@ int main(int argc, char** argv) bfloat16_dpas_blockread_rowmajor<4, 16>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_blockread_rowmajor<8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); @@ -801,6 +806,7 @@ int main(int argc, char** argv) bfloat16_dpas_blockread_vnni<4, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_blockread_vnni<8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni_tiled<8, 16, 1, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_blockread_vnni_tiled<8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_blockread_vnni_tiled<8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, 
K, C_ref); bfloat16_dpas_blockread_vnni_tiled<8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); diff --git a/samples/99_matrixexperiments/matrix_helpers.cl b/samples/99_matrixexperiments/matrix_helpers.cl index 0719fdba..b77ede36 100644 --- a/samples/99_matrixexperiments/matrix_helpers.cl +++ b/samples/99_matrixexperiments/matrix_helpers.cl @@ -9,6 +9,8 @@ float bf16_to_fp32(ushort u) #if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) +typedef global ushort* global_aligned_ushort_ptr __attribute__((align_value(4))); + inline int compute_m(const int num_sgs, const int tM, const int MM) { const int m_start = get_group_id(1) * num_sgs; @@ -157,7 +159,7 @@ int load_a_rowmajor_d16_m1_k16_sg8(global ushort* A, int rowStart, int colStart int ret; global uint* A_ui = (global uint*)A; - int offset_ui = rowStart * stride / 2 + colStart / 2; + uint offset_ui = rowStart * stride / 2 + colStart / 2; ret = intel_sub_group_block_read(A_ui + offset_ui); return ret; @@ -170,7 +172,7 @@ int2 load_a_rowmajor_d16_m2_k16_sg8(global ushort* A, int rowStart, int colStart int2 ret; global uint* A_ui = (global uint*)A; - int offset_ui = rowStart * stride / 2 + colStart / 2; + uint offset_ui = rowStart * stride / 2 + colStart / 2; ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; @@ -185,7 +187,7 @@ int4 load_a_rowmajor_d16_m4_k16_sg8(global ushort* A, int rowStart, int colStart int4 ret; global uint* A_ui = (global uint*)A; - int offset_ui = rowStart * stride / 2 + colStart / 2; + uint offset_ui = rowStart * stride / 2 + colStart / 2; ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; @@ -202,7 +204,7 @@ int8 load_a_rowmajor_d16_m8_k16_sg8(global ushort* A, int rowStart, int colStart int8 ret; global uint* A_ui = (global uint*)A; - int 
offset_ui = rowStart * stride / 2 + colStart / 2; + uint offset_ui = rowStart * stride / 2 + colStart / 2; ret.s0 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; ret.s1 = intel_sub_group_block_read(A_ui + offset_ui); offset_ui += stride / 2; @@ -224,7 +226,7 @@ int16 load_a_rowmajor_d16_m8_k16v2_sg8(global ushort* A, int rowStart, int colSt uint16 ret; global uint* A_ui = (global uint*)A; - int offset_ui = rowStart * stride / 2 + colStart / 2; + uint offset_ui = rowStart * stride / 2 + colStart / 2; ret.s08 = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; ret.s19 = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2; @@ -244,7 +246,7 @@ short load_a_rowmajor_d16_m1_k16_sg16(global ushort* A, int rowStart, int colSta { ushort ret; - int offset = rowStart * stride + colStart; + uint offset = rowStart * stride + colStart; ret = intel_sub_group_block_read_us(A + offset); return as_short(ret); @@ -256,7 +258,7 @@ short2 load_a_rowmajor_d16_m2_k16_sg16(global ushort* A, int rowStart, int colSt { ushort2 ret; - int offset = rowStart * stride + colStart; + uint offset = rowStart * stride + colStart; ret.s0 = intel_sub_group_block_read_us(A + offset); offset += stride; ret.s1 = intel_sub_group_block_read_us(A + offset); offset += stride; @@ -269,7 +271,7 @@ short4 load_a_rowmajor_d16_m4_k16_sg16(global ushort* A, int rowStart, int colSt { ushort4 ret; - int offset = rowStart * stride + colStart; + uint offset = rowStart * stride + colStart; ret.s0 = intel_sub_group_block_read_us(A + offset); offset += stride; ret.s1 = intel_sub_group_block_read_us(A + offset); offset += stride; ret.s2 = intel_sub_group_block_read_us(A + offset); offset += stride; @@ -284,7 +286,7 @@ short8 load_a_rowmajor_d16_m8_k16_sg16(global ushort* A, int rowStart, int colSt { ushort8 ret; - int offset = rowStart * stride + colStart; + uint offset = rowStart * stride + colStart; ret.s0 = intel_sub_group_block_read_us(A + offset); 
offset += stride; ret.s1 = intel_sub_group_block_read_us(A + offset); offset += stride; ret.s2 = intel_sub_group_block_read_us(A + offset); offset += stride; @@ -304,15 +306,15 @@ short16 load_a_rowmajor_d16_m8_k16v2_sg16(global ushort* A, int rowStart, int co { ushort16 ret; - int offset = rowStart * stride + colStart; - ret.s08 = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; - ret.s19 = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; - ret.s2a = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; - ret.s3b = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; - ret.s4c = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; - ret.s5d = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; - ret.s6e = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; - ret.s7f = intel_sub_group_block_read_us2(A + offset); offset += stride / 2; + uint offset = rowStart * stride + colStart; + ret.s08 = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s19 = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s2a = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s3b = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s4c = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s5d = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s6e = intel_sub_group_block_read_us2(A + offset); offset += stride; + ret.s7f = intel_sub_group_block_read_us2(A + offset); offset += stride; return as_short16(ret); } @@ -324,7 +326,7 @@ int8 load_b_rowmajor_d16_k16_nx(global ushort* B, int rowStart, int colStart, in { int8 ret; - int offset = rowStart * stride + colStart; + uint offset = rowStart * stride + colStart; ushort row0 = intel_sub_group_block_read_us(B + offset); offset += stride; ushort row1 = intel_sub_group_block_read_us(B + offset); offset += stride; @@ -363,7 +365,7 @@ int8 
load_b_vnni_d16_k16_nx(global ushort* B, int rowStart, int colStart, int st int8 ret; global uint* B_ui = (global uint*)B; - int offset_ui = rowStart / 2 * stride + colStart; + uint offset_ui = rowStart / 2 * stride + colStart; ret.s0 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; ret.s1 = intel_sub_group_block_read(B_ui + offset_ui); offset_ui += stride; @@ -382,7 +384,7 @@ void store_c_rowmajor_fp32_m1_nx(global float* C, float v, int rowStart, int col global uint* C_ui = (global uint*)C; uint v_ui = as_uint(v); - int offset = rowStart * stride + colStart; + uint offset = rowStart * stride + colStart; intel_sub_group_block_write(C_ui + offset, v_ui); offset += stride; } @@ -392,7 +394,7 @@ void store_c_rowmajor_fp32_m2_nx(global float* C, float2 v, int rowStart, int co global uint* C_ui = (global uint*)C; uint2 v_ui = as_uint2(v); - int offset = rowStart * stride + colStart; + uint offset = rowStart * stride + colStart; intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride; intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride; @@ -403,7 +405,7 @@ void store_c_rowmajor_fp32_m4_nx(global float* C, float4 v, int rowStart, int co global uint* C_ui = (global uint*)C; uint4 v_ui = as_uint4(v); - int offset = rowStart * stride + colStart; + uint offset = rowStart * stride + colStart; intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride; intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride; @@ -416,7 +418,7 @@ void store_c_rowmajor_fp32_m8_nx(global float* C, float8 v, int rowStart, int co global uint* C_ui = (global uint*)C; uint8 v_ui = as_uint8(v); - int offset = rowStart * stride + colStart; + uint offset = rowStart * stride + colStart; intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride; intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride; diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl 
b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 7874d7b2..c9322350 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -36,7 +36,7 @@ #if HAS_SIMD8 __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, SGS_PER_WG, 1))) -kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) +kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) { const int tM = 8; const int tN = 8; @@ -55,9 +55,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl for (int k = 0; k < K; k += tK * KK) { int8 aData[KK][MM]; - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K); + if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + int16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, k + kk * tK, K); + aData[kk + 0][mm] = aTemp.lo; + aData[kk + 1][mm] = aTemp.hi; + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K); + } } } @@ -90,7 +100,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl } __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, SGS_PER_WG, 1))) -kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) +kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) { const int tM = 8; const int tN = 8; @@ -109,9 +119,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, 
NN)(global float* for (int k = 0; k < K; k += tK * KK) { int8 aData[KK][MM]; - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K); + if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + int16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, k + kk * tK, K); + aData[kk + 0][mm] = aTemp.lo; + aData[kk + 1][mm] = aTemp.hi; + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K); + } } } @@ -146,7 +166,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* #endif // HAS_SIMD8 __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1))) -kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) +kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) { const int tM = 8; const int tN = 16; @@ -165,9 +185,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f for (int k = 0; k < K; k += tK * KK) { short8 aData[KK][MM]; - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k + kk * tK, K); + if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + short16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg16(A, m + mm * tM, k + kk * tK, K); + aData[kk + 0][mm] = aTemp.lo; + aData[kk + 1][mm] = aTemp.hi; + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k + kk * tK, K); + } } } @@ -200,7 +230,7 @@ kernel void 
MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1))) -kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) +kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) { const int tM = 8; const int tN = 16; @@ -219,9 +249,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float for (int k = 0; k < K; k += tK * KK) { short8 aData[KK][MM]; - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k + kk * tK, K); + if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + short16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg16(A, m + mm * tM, k + kk * tK, K); + aData[kk + 0][mm] = aTemp.lo; + aData[kk + 1][mm] = aTemp.hi; + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k + kk * tK, K); + } } } diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl index efc8f4fa..f869a3bb 100644 --- a/samples/99_matrixexperiments/matrix_kernels.cl +++ b/samples/99_matrixexperiments/matrix_kernels.cl @@ -513,6 +513,12 @@ kernel void bfloat16_dpas_blockread_vnni_m8_n16(global float* C, global ushort* // Tiled matrix multiplication kernels, generated from a template: +#define MM 1 +#define NN 1 +#include "matrix_kernel_tiled.cl" +#undef MM +#undef NN + #define MM 2 #define NN 1 #include "matrix_kernel_tiled.cl" From 4ae2d9505d7f3fbf0a31a83b3277848c1536f5f7 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Wed, 24 Jan 2024 12:12:13 -0800 Subject: [PATCH 35/99] add support for wide K block reads --- 
.../matrix_kernel_tiled.cl | 35 ++++++++++++++----- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index c9322350..47a5c958 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -316,9 +316,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN for (int k = 0; k < K; k += tK * KK) { short8 aData[KK][MM]; - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); + if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + short16 aTemp = as_short16(intel_subgroup_block_read_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); + aData[kk + 0][mm] = aTemp.lo; + aData[kk + 1][mm] = aTemp.hi; + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); + } } } @@ -371,9 +381,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl for (int k = 0; k < K; k += tK * KK) { short8 aData[KK][MM]; - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); + if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + short16 aTemp = as_short16(intel_subgroup_block_read_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); + aData[kk + 0][mm] = aTemp.lo; + aData[kk + 1][mm] = aTemp.hi; + } + } + 
} else { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); + } } } @@ -406,6 +426,3 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl } #endif // cl_intel_subgroup_extended_block_read - -#undef KK -#undef SGS_PER_WG From aa95c5ee452a394aa2566eff8a2b314b42024b29 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Wed, 24 Jan 2024 13:56:01 -0800 Subject: [PATCH 36/99] minor improvements to tester program --- samples/99_matrixexperiments/main.cpp | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index a38cf27f..55951e94 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -650,16 +650,28 @@ int main(int argc, char** argv) std::vector platforms; cl::Platform::get(&platforms); + if (platformIndex >= platforms.size()) { + printf("Requested platform index is %d, but only %zu platforms were found.\n", + platformIndex, platforms.size()); + return -1; + } printf("Running on platform: %s\n", platforms[platformIndex].getInfo().c_str() ); std::vector devices; platforms[platformIndex].getDevices(CL_DEVICE_TYPE_ALL, &devices); + if (deviceIndex >= devices.size()) { + printf("Requested device index is %d, but only %zu devices were found.\n", + deviceIndex, devices.size()); + return -1; + } cl::Device& device = devices[deviceIndex]; - printf("Running on device: %s\n", - device.getInfo().c_str() ); + printf("Running on device: %s (%uCUs, %uMHz)\n", + device.getInfo().c_str(), + device.getInfo(), + device.getInfo()); auto minSubGroupSize = findMinSubGroupSize(device); From 40d0d7a98b40d56a9be5e31d63c98bd1865367a0 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Wed, 24 Jan 2024 20:54:13 -0800 Subject: [PATCH 37/99] try a larger B matrix block
read for the VNNI kernel --- samples/99_matrixexperiments/matrix_helpers.cl | 7 ++++++- .../99_matrixexperiments/matrix_kernel_tiled.cl | 16 +++++++++++++--- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_helpers.cl b/samples/99_matrixexperiments/matrix_helpers.cl index b77ede36..e4ad091c 100644 --- a/samples/99_matrixexperiments/matrix_helpers.cl +++ b/samples/99_matrixexperiments/matrix_helpers.cl @@ -487,7 +487,8 @@ ushort4 __builtin_IB_subgroup_block_read_flat_u16_m4k16v1(long baseoffset, int ushort8 __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); ushort16 __builtin_IB_subgroup_block_read_flat_u16_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +uint16 __builtin_IB_subgroup_block_read_flat_u32_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint data); void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint2 data); @@ -519,6 +520,10 @@ uint8 intel_subgroup_block_read_u32_m8k16(const __global void* base_address, int { return __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } +uint16 intel_subgroup_block_read_u32_m16k16(const __global void* base_address, int width, int height, int pitch, int2 coord) +{ + 
return __builtin_IB_subgroup_block_read_flat_u32_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); +} void intel_subgroup_block_write_u32_m1k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint data) { diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 47a5c958..98ddfa91 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -398,9 +398,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl } int8 bData[KK][NN]; - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - bData[kk][nn] = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2))); + if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn++) { + int16 bTemp = as_int16(intel_subgroup_block_read_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2))); + bData[kk + 0][nn] = bTemp.lo; + bData[kk + 1][nn] = bTemp.hi; + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[kk][nn] = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2))); + } } } From 07311ec5875102e4b2356c74767a7013377d674b Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Wed, 7 Feb 2024 22:21:19 -0800 Subject: [PATCH 38/99] add a mask argument to only run a subset of tests --- samples/99_matrixexperiments/main.cpp | 187 +++++++++++++++----------- 1 file changed, 108 insertions(+), 79 deletions(-) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index 55951e94..e0f3a8cb 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -618,6 +618,8 @@ int main(int argc, char** 
argv) std::string buildOptions; size_t matrixSize = 512; + size_t mask = ~0; + { popl::OptionParser op("Supported Options"); op.add>("p", "platform", "Platform Index", platformIndex, &platformIndex); @@ -632,6 +634,7 @@ int main(int argc, char** argv) op.add("", "emulate", "Unconditionally Emulate dpas", &emulate); op.add("", "wallclock", "Measure Wallclock Time", &wallclock); op.add>("", "threshold", "Local Error Threshold", threshold, &threshold); + op.add, popl::Attribute::advanced>("", "mask", "Test Mask", mask, &mask); bool printUsage = false; try { op.parse(argc, argv); @@ -745,85 +748,111 @@ int main(int argc, char** argv) printf("Running tests...\n"); - bfloat16_naive(context, program, queue, C, A, B, M, N, K, C_ref); - - bfloat16_dpas_rowmajor<1, 8>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_rowmajor<2, 8>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_rowmajor<4, 8>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_rowmajor<8, 8>(context, program, queue, C, A, B, M, N, K, C_ref); - - bfloat16_dpas_rowmajor_tiled<8, 8, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_rowmajor_tiled<8, 8, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_rowmajor_tiled<8, 8, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_rowmajor_tiled<8, 8, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_rowmajor_tiled<8, 8, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_rowmajor_tiled<8, 8, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_rowmajor_tiled<8, 8, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); - - bfloat16_dpas_vnni<1, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - bfloat16_dpas_vnni<2, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - bfloat16_dpas_vnni<4, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - bfloat16_dpas_vnni<8, 
8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - - bfloat16_dpas_vnni_tiled<8, 8, 1, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - bfloat16_dpas_vnni_tiled<8, 8, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - bfloat16_dpas_vnni_tiled<8, 8, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - bfloat16_dpas_vnni_tiled<8, 8, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - bfloat16_dpas_vnni_tiled<8, 8, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - bfloat16_dpas_vnni_tiled<8, 8, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - bfloat16_dpas_vnni_tiled<8, 8, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - - bfloat16_dpas_rowmajor<1, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_rowmajor<2, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_rowmajor<4, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_rowmajor<8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - - bfloat16_dpas_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); - - bfloat16_dpas_vnni<1, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - bfloat16_dpas_vnni<2, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - bfloat16_dpas_vnni<4, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - bfloat16_dpas_vnni<8, 
16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - - bfloat16_dpas_vnni_tiled<8, 16, 1, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - bfloat16_dpas_vnni_tiled<8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - bfloat16_dpas_vnni_tiled<8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - bfloat16_dpas_vnni_tiled<8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - bfloat16_dpas_vnni_tiled<8, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - bfloat16_dpas_vnni_tiled<8, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - bfloat16_dpas_vnni_tiled<8, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - - bfloat16_dpas_blockread_rowmajor<1, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_blockread_rowmajor<2, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_blockread_rowmajor<4, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_blockread_rowmajor<8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); - - bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); - - bfloat16_dpas_blockread_vnni<1, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - bfloat16_dpas_blockread_vnni<2, 16>(context, program, queue, C, 
A, Bvnni, M, N, K, C_ref); - bfloat16_dpas_blockread_vnni<4, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - bfloat16_dpas_blockread_vnni<8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - - bfloat16_dpas_blockread_vnni_tiled<8, 16, 1, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - bfloat16_dpas_blockread_vnni_tiled<8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - bfloat16_dpas_blockread_vnni_tiled<8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - bfloat16_dpas_blockread_vnni_tiled<8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - bfloat16_dpas_blockread_vnni_tiled<8, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - bfloat16_dpas_blockread_vnni_tiled<8, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - bfloat16_dpas_blockread_vnni_tiled<8, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + if (mask & 0x1) { + bfloat16_naive(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x2) { + bfloat16_dpas_rowmajor<1, 8>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor<2, 8>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor<4, 8>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor<8, 8>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x4) { + bfloat16_dpas_rowmajor_tiled<8, 8, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 8, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 8, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 8, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 8, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 8, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + 
bfloat16_dpas_rowmajor_tiled<8, 8, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x8) { + bfloat16_dpas_vnni<1, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni<2, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni<4, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni<8, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + } + + if (mask & 0x10) { + bfloat16_dpas_vnni_tiled<8, 8, 1, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 8, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 8, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 8, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 8, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 8, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 8, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + } + + if (mask & 0x20) { + bfloat16_dpas_rowmajor<1, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor<2, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor<4, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor<8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x40) { + bfloat16_dpas_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + 
bfloat16_dpas_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x80) { + bfloat16_dpas_vnni<1, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni<2, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni<4, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni<8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + } + + if (mask & 0x100) { + bfloat16_dpas_vnni_tiled<8, 16, 1, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + } + + if (mask & 0x200) { + bfloat16_dpas_blockread_rowmajor<1, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor<2, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor<4, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor<8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x400) { + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + 
bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x800) { + bfloat16_dpas_blockread_vnni<1, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni<2, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni<4, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni<8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + } + + if (mask & 0x1000) { + bfloat16_dpas_blockread_vnni_tiled<8, 16, 1, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni_tiled<8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni_tiled<8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni_tiled<8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni_tiled<8, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni_tiled<8, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni_tiled<8, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + } printf("Done.\n"); From b19fe5e4f3956f83261d7ab33b8e836002c5e5ca Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Wed, 7 Feb 2024 22:21:40 -0800 Subject: [PATCH 39/99] initial support for prefetching --- .../99_matrixexperiments/matrix_helpers.cl | 57 ++++++++++ .../matrix_kernel_tiled.cl | 104 ++++++++++++++++++ 2 files changed, 161 insertions(+) diff --git a/samples/99_matrixexperiments/matrix_helpers.cl 
b/samples/99_matrixexperiments/matrix_helpers.cl index e4ad091c..55ebcc65 100644 --- a/samples/99_matrixexperiments/matrix_helpers.cl +++ b/samples/99_matrixexperiments/matrix_helpers.cl @@ -240,6 +240,15 @@ int16 load_a_rowmajor_d16_m8_k16v2_sg8(global ushort* A, int rowStart, int colSt return as_int16(ret); } +// M rows x K columns x V tiles (in the K dimension) +void prefetch_a_rowmajor_d16_m8_k16v2_sg8(global ushort* A, int rowStart, int colStart, int stride) +{ +#if defined(PREFETCH_DEFAULT) + uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; + prefetch(A + offset, 1); +#endif // defined(PREFETCH_DEFAULT) +} + // M rows x K columns // This is the SIMD16 version, where each work-item loads one value. short load_a_rowmajor_d16_m1_k16_sg16(global ushort* A, int rowStart, int colStart, int stride) @@ -319,6 +328,15 @@ short16 load_a_rowmajor_d16_m8_k16v2_sg16(global ushort* A, int rowStart, int co return as_short16(ret); } +// M rows x K columns x V tiles (in the M and K dimensions) +void prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(global ushort* A, int rowStart, int colStart, int stride) +{ +#if defined(PREFETCH_DEFAULT) + uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; + prefetch(A + offset, 1); +#endif // defined(PREFETCH_DEFAULT) +} + // K rows x N columns: // Each work-item loads K values and converts to VNNI. // Stride is in units of elements. 
@@ -379,6 +397,45 @@ int8 load_b_vnni_d16_k16_nx(global ushort* B, int rowStart, int colStart, int st return ret; } +// K rows x N columns x V tiles (in the N dimension) +void prefetch_b_rowmajor_d16_k16_n8v4_sg8(global ushort* B, int rowStart, int colStart, int stride) +{ +#if defined(PREFETCH_DEFAULT) + uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; + prefetch(B + offset, 1); offset += 8 * stride; + prefetch(B + offset, 1); offset += 8 * stride; +#endif // defined(PREFETCH_DEFAULT) +} + +// K rows x N columns x V tiles (in the N dimension) +void prefetch_b_rowmajor_d16_k16_n16v2_sg16(global ushort* B, int rowStart, int colStart, int stride) +{ +#if defined(PREFETCH_DEFAULT) + uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; + prefetch(B + offset, 1); +#endif // defined(PREFETCH_DEFAULT) +} + +// K rows x N columns x V tiles (in the N dimension) +void prefetch_b_vnni_d16_k16_n8v2_sg8(global ushort* B, int rowStart, int colStart, int stride) +{ +#if defined(PREFETCH_DEFAULT) + global uint* B_ui = (global uint*)B; + uint offset_ui = colStart + (rowStart / 2 + get_sub_group_local_id()) * stride; + prefetch(B_ui + offset_ui, 1); +#endif // defined(PREFETCH_DEFAULT) +} + +// K rows x N columns x V tiles (in the K dimension) +void prefetch_b_vnni_d16_k16v2_n16_sg16(global ushort* B, int rowStart, int colStart, int stride) +{ +#if defined(PREFETCH_DEFAULT) + global uint* B_ui = (global uint*)B; + uint offset_ui = colStart + (rowStart / 2 + get_sub_group_local_id()) * stride; + prefetch(B_ui + offset_ui, 1); +#endif // defined(PREFETCH_DEFAULT) +} + void store_c_rowmajor_fp32_m1_nx(global float* C, float v, int rowStart, int colStart, int stride) { global uint* C_ui = (global uint*)C; diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 98ddfa91..7fdbe284 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ 
b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -44,6 +44,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl const int m = compute_m(SGS_PER_WG, tM, MM); const int n = get_group_id(0) * tN * NN; + // Initial prefetch: + const int init_k = 0; + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, init_k + kk * tK, K); + } + } + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn+=4) { + prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, init_k + kk * tK, n + nn * tN, N); + } + } + float8 sum[MM][NN]; for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { @@ -54,6 +67,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl split_barrier_arrive(); for (int k = 0; k < K; k += tK * KK) { + // Next prefetch: + const int next_k = k + tK * KK; + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, next_k + kk * tK, K); + } + } + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn+=4) { + prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, next_k + kk * tK, n + nn * tN, N); + } + } + int8 aData[KK][MM]; if (KK % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { @@ -108,6 +134,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* const int m = compute_m(SGS_PER_WG, tM, MM); const int n = get_group_id(0) * tN * NN; + // Initial prefetch: + const int init_k = 0; + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, init_k + kk * tK, K); + } + } + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn+=2) { + prefetch_b_vnni_d16_k16_n8v2_sg8(B, init_k + kk * tK, n + nn * tN, N); + } + } + float8 sum[MM][NN]; for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { @@ -118,6 +157,19 @@ kernel void 
MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* split_barrier_arrive(); for (int k = 0; k < K; k += tK * KK) { + // Next prefetch: + const int next_k = k + tK * KK; + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, next_k + kk * tK, K); + } + } + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn+=2) { + prefetch_b_vnni_d16_k16_n8v2_sg8(B, next_k + kk * tK, n + nn * tN, N); + } + } + int8 aData[KK][MM]; if (KK % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { @@ -174,6 +226,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f const int m = compute_m(SGS_PER_WG, tM, MM); const int n = get_group_id(0) * tN * NN; + // Initial prefetch: + const int init_k = 0; + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=2) { + prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, init_k + kk * tK, K); + } + } + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn+=2) { + prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, init_k + kk * tK, n + nn * tN, N); + } + } + float8 sum[MM][NN]; for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { @@ -184,6 +249,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f split_barrier_arrive(); for (int k = 0; k < K; k += tK * KK) { + // Next prefetch: + const int next_k = k + tK * KK; + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=2) { + prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, next_k + kk * tK, K); + } + } + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn+=2) { + prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, next_k + kk * tK, n + nn * tN, N); + } + } + short8 aData[KK][MM]; if (KK % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { @@ -238,6 +316,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float const int m = compute_m(SGS_PER_WG, tM, MM); const int n 
= get_group_id(0) * tN * NN; + // Initial prefetch: + const int init_k = 0; + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=2) { + prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, init_k + kk * tK, K); + } + } + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn++) { + prefetch_b_vnni_d16_k16v2_n16_sg16(B, init_k + kk * tK, n + nn * tN, N); + } + } + float8 sum[MM][NN]; for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { @@ -248,6 +339,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float split_barrier_arrive(); for (int k = 0; k < K; k += tK * KK) { + // Next prefetch: + const int next_k = k + tK * KK; + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=2) { + prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, next_k + kk * tK, K); + } + } + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn++) { + prefetch_b_vnni_d16_k16v2_n16_sg16(B, next_k + kk * tK, n + nn * tN, N); + } + } + short8 aData[KK][MM]; if (KK % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { From a1efab209e2f749f6e0abe9e85665a63c7ad9d2b Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Wed, 7 Feb 2024 22:41:56 -0800 Subject: [PATCH 40/99] add support for more prefetching --- .../matrix_kernel_tiled.cl | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 7fdbe284..e64d4b21 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -409,6 +409,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN const int m = compute_m(SGS_PER_WG, tM, MM); const int n = get_group_id(0) * tN * NN; + // Initial prefetch: + const int init_k = 0; + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=2) { + prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm 
* tM, init_k + kk * tK, K); + } + } + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn+=2) { + prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, init_k + kk * tK, n + nn * tN, N); + } + } + float8 sum[MM][NN]; for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { @@ -419,6 +432,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN split_barrier_arrive(); for (int k = 0; k < K; k += tK * KK) { + // Next prefetch: + const int next_k = k + tK * KK; + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=2) { + prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, next_k + kk * tK, K); + } + } + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn+=2) { + prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, next_k + kk * tK, n + nn * tN, N); + } + } + short8 aData[KK][MM]; if (KK % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { @@ -474,6 +500,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl const int m = compute_m(SGS_PER_WG, tM, MM); const int n = get_group_id(0) * tN * NN; + // Initial prefetch: + const int init_k = 0; + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=2) { + prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, init_k + kk * tK, K); + } + } + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn++) { + prefetch_b_vnni_d16_k16v2_n16_sg16(B, init_k + kk * tK, n + nn * tN, N); + } + } + float8 sum[MM][NN]; for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { @@ -484,6 +523,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl split_barrier_arrive(); for (int k = 0; k < K; k += tK * KK) { + // Next prefetch: + const int next_k = k + tK * KK; + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=2) { + prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, next_k + kk * tK, K); + } + } + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn++) 
{ + prefetch_b_vnni_d16_k16v2_n16_sg16(B, next_k + kk * tK, n + nn * tN, N); + } + } + short8 aData[KK][MM]; if (KK % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { From 92c90d41872455c207d88a9ecd64c258de040add Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Wed, 7 Feb 2024 23:10:25 -0800 Subject: [PATCH 41/99] add driver version output to tester --- samples/99_matrixexperiments/main.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index e0f3a8cb..db30d2a6 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -674,6 +674,8 @@ int main(int argc, char** argv) device.getInfo().c_str(), device.getInfo(), device.getInfo()); + printf("Running on drivers: %s\n", + device.getInfo().c_str()); auto minSubGroupSize = findMinSubGroupSize(device); From 5f29f593ae75f8c77db03e3083d82f0419182175 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Tue, 13 Feb 2024 15:09:56 -0800 Subject: [PATCH 42/99] add 8x2 tiled versions --- samples/99_matrixexperiments/main.cpp | 6 ++++++ samples/99_matrixexperiments/matrix_kernels.cl | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index db30d2a6..dddb8c7c 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -769,6 +769,7 @@ int main(int argc, char** argv) bfloat16_dpas_rowmajor_tiled<8, 8, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_rowmajor_tiled<8, 8, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_rowmajor_tiled<8, 8, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 8, 8, 2>(context, program, queue, C, A, B, M, N, K, C_ref); } if (mask & 0x8) { @@ -786,6 +787,7 @@ int main(int argc, char** argv) bfloat16_dpas_vnni_tiled<8, 8, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); 
bfloat16_dpas_vnni_tiled<8, 8, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_vnni_tiled<8, 8, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 8, 8, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); } if (mask & 0x20) { @@ -803,6 +805,7 @@ int main(int argc, char** argv) bfloat16_dpas_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_rowmajor_tiled<8, 16, 8, 2>(context, program, queue, C, A, B, M, N, K, C_ref); } if (mask & 0x80) { @@ -820,6 +823,7 @@ int main(int argc, char** argv) bfloat16_dpas_vnni_tiled<8, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_vnni_tiled<8, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_vnni_tiled<8, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_vnni_tiled<8, 16, 8, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); } if (mask & 0x200) { @@ -837,6 +841,7 @@ int main(int argc, char** argv) bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 8, 2>(context, program, queue, C, A, B, M, N, K, C_ref); } if (mask & 0x800) { @@ -854,6 +859,7 @@ int main(int argc, char** argv) bfloat16_dpas_blockread_vnni_tiled<8, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_blockread_vnni_tiled<8, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_blockread_vnni_tiled<8, 16, 4, 4>(context, 
program, queue, C, A, Bvnni, M, N, K, C_ref); + bfloat16_dpas_blockread_vnni_tiled<8, 16, 8, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); } printf("Done.\n"); diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl index f869a3bb..29e0c193 100644 --- a/samples/99_matrixexperiments/matrix_kernels.cl +++ b/samples/99_matrixexperiments/matrix_kernels.cl @@ -555,6 +555,12 @@ kernel void bfloat16_dpas_blockread_vnni_m8_n16(global float* C, global ushort* #undef MM #undef NN +#define MM 8 +#define NN 2 +#include "matrix_kernel_tiled.cl" +#undef MM +#undef NN + #endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size) #undef tK From 8a0ed40c7e19a7e1ef39327275929b5abebed0ea Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Wed, 21 Feb 2024 21:27:50 -0800 Subject: [PATCH 43/99] add tf32 tester --- samples/99_matrixexperiments/main.cpp | 38 +- .../99_matrixexperiments/matrix_helpers.cl | 44 ++ .../99_matrixexperimentstf32/CMakeLists.txt | 11 + samples/99_matrixexperimentstf32/main.cpp | 593 ++++++++++++++++++ .../matrix_helpers_tf32.cl | 352 +++++++++++ .../matrix_kernel_tiled_tf32.cl | 239 +++++++ .../matrix_kernels_tf32.cl | 218 +++++++ samples/CMakeLists.txt | 3 +- 8 files changed, 1488 insertions(+), 10 deletions(-) create mode 100644 samples/99_matrixexperimentstf32/CMakeLists.txt create mode 100644 samples/99_matrixexperimentstf32/main.cpp create mode 100644 samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl create mode 100644 samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl create mode 100644 samples/99_matrixexperimentstf32/matrix_kernels_tf32.cl diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index dddb8c7c..23bf2d87 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -25,6 +25,7 @@ bool fixedData = false; bool validate = false; bool 
emulate = false; bool wallclock = false; +bool skipinit = false; int testIterations = 16; float threshold = 0.01f; @@ -174,7 +175,9 @@ static void bfloat16_naive( kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); - queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } float best = 999.0f; for (int test = 0; test < testIterations; test++) { @@ -222,7 +225,9 @@ static void bfloat16_dpas_rowmajor( kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); - queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } float best = 999.0f; for (int test = 0; test < testIterations; test++) { @@ -276,7 +281,9 @@ static void bfloat16_dpas_rowmajor_tiled( kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); - queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } float best = 999.0f; for (int test = 0; test < testIterations; test++) { @@ -324,7 +331,9 @@ static void bfloat16_dpas_vnni( kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); - queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } float best = 999.0f; for (int test = 0; test < testIterations; test++) { @@ -378,7 +387,9 @@ static void bfloat16_dpas_vnni_tiled( kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); - queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } float best = 999.0f; for (int test = 0; test < testIterations; test++) { @@ -426,7 +437,9 @@ static void bfloat16_dpas_blockread_rowmajor( kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); - queue.enqueueFillBuffer(C, 0, 0, 
C_ref.size() * sizeof(C_ref[0])); + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } float best = 999.0f; for (int test = 0; test < testIterations; test++) { @@ -480,7 +493,9 @@ static void bfloat16_dpas_blockread_rowmajor_tiled( kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); - queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } float best = 999.0f; for (int test = 0; test < testIterations; test++) { @@ -528,7 +543,9 @@ static void bfloat16_dpas_blockread_vnni( kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); - queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } float best = 999.0f; for (int test = 0; test < testIterations; test++) { @@ -582,7 +599,9 @@ static void bfloat16_dpas_blockread_vnni_tiled( kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); - queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } float best = 999.0f; for (int test = 0; test < testIterations; test++) { @@ -633,6 +652,7 @@ int main(int argc, char** argv) op.add("", "fixed", "Use Fixed Data", &fixedData); op.add("", "emulate", "Unconditionally Emulate dpas", &emulate); op.add("", "wallclock", "Measure Wallclock Time", &wallclock); + op.add("", "skipinit", "Do Not Initialize Buffers", &skipinit); op.add>("", "threshold", "Local Error Threshold", threshold, &threshold); op.add, popl::Attribute::advanced>("", "mask", "Test Mask", mask, &mask); bool printUsage = false; diff --git a/samples/99_matrixexperiments/matrix_helpers.cl b/samples/99_matrixexperiments/matrix_helpers.cl index 55ebcc65..0a19797a 100644 --- a/samples/99_matrixexperiments/matrix_helpers.cl +++ b/samples/99_matrixexperiments/matrix_helpers.cl @@ -7,6 +7,50 
@@ float bf16_to_fp32(ushort u) #endif } +__attribute__((overloadable)) +float activation(float f) +{ +#if defined(ACTIVATION_RELU) + return fmax(f, 0); +#else // identity + return f; +#endif +} + +__attribute__((overloadable)) +float2 activation(float2 f) +{ + float2 res; + res.s0 = activation(f.s0); + res.s1 = activation(f.s1); + return res; +} + +__attribute__((overloadable)) +float4 activation(float4 f) +{ + float4 res; + res.s0 = activation(f.s0); + res.s1 = activation(f.s1); + res.s2 = activation(f.s2); + res.s3 = activation(f.s3); + return res; +} + +float8 activation(float8 f) +{ + float8 res; + res.s0 = activation(f.s0); + res.s1 = activation(f.s1); + res.s2 = activation(f.s2); + res.s3 = activation(f.s3); + res.s4 = activation(f.s4); + res.s5 = activation(f.s5); + res.s6 = activation(f.s6); + res.s7 = activation(f.s7); + return res; +} + #if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) typedef global ushort* global_aligned_ushort_ptr __attribute__((align_value(4))); diff --git a/samples/99_matrixexperimentstf32/CMakeLists.txt b/samples/99_matrixexperimentstf32/CMakeLists.txt new file mode 100644 index 00000000..de636108 --- /dev/null +++ b/samples/99_matrixexperimentstf32/CMakeLists.txt @@ -0,0 +1,11 @@ +# Copyright (c) 2019-2024 Ben Ashbaugh +# +# SPDX-License-Identifier: MIT + +add_opencl_sample( + TEST + NUMBER 99 + TARGET matrixexperimentstf32 + VERSION 120 + SOURCES main.cpp + KERNELS matrix_helpers_tf32.cl matrix_kernels_tf32.cl matrix_kernel_tiled_tf32.cl) diff --git a/samples/99_matrixexperimentstf32/main.cpp b/samples/99_matrixexperimentstf32/main.cpp new file mode 100644 index 00000000..3ec2c7cb --- /dev/null +++ b/samples/99_matrixexperimentstf32/main.cpp @@ -0,0 +1,593 @@ +/* +// Copyright (c) 2019-2024 Ben Ashbaugh +// +// SPDX-License-Identifier: MIT +*/ + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "util.hpp" + +using test_clock = 
std::chrono::high_resolution_clock; + +bool identityData = false; +bool fixedData = false; +bool validate = false; +bool emulate = false; +bool wallclock = false; +bool skipinit = false; +int testIterations = 16; +float threshold = 0.01f; + +std::string makeTestName( + const std::string &func, + size_t M, size_t N, size_t K) +{ + std::ostringstream ret; + ret << func; + ret << " (M=" << M << ", N=" << N << ", K=" << K << ")"; + return ret.str(); +} + +std::string makeTestName( + const std::string &func, + int tM, int tN, + size_t M, size_t N, size_t K) +{ + std::ostringstream ret; + ret << func; + ret << ""; + ret << " (M=" << M << ", N=" << N << ", K=" << K << ")"; + return ret.str(); +} + +std::string makeTestName( + const std::string &func, + int tM, int tN, + int MM, int NN, + size_t M, size_t N, size_t K) +{ + std::ostringstream ret; + ret << func; + ret << ""; + ret << " (M=" << M << ", N=" << N << ", K=" << K << ")"; + return ret.str(); +} + +static size_t findMinSubGroupSize(cl::Device& device) +{ + auto s = device.getInfo(); + auto it = std::min_element(std::begin(s), std::end(s)); + if (it != std::end(s)) { + return *it; + } + return 0; +} + +float to_tf32(float f) +{ + union { + uint32_t u; + float f; + } value; + + value.f = f; + value.u &= 0xFFFFE000; + + // Be careful not to convert NAN to INF: + if (std::isnan(f) && !std::isnan(value.f)) { + value.u |= 0x00002000; + } + + return value.f; +} + +template +static void fill_matrix(std::vector& M, size_t numRows, size_t numCols) +{ + if (identityData) { + std::generate(std::begin(M), std::end(M), [&]{ return to_tf32(1.0f); }); + } else if (fixedData) { + for (size_t r = 0; r < numRows; r++) { + for (size_t c = 0; c < numCols; c++) { + M[r * numCols + c] = to_tf32(static_cast(r) + static_cast(c) / 64.0f); + } + } + } else { + std::random_device dev; + std::mt19937 rng(dev()); + std::uniform_real_distribution dist(-1.0, 1.0); + std::generate(std::begin(M), std::end(M), [&]{ return to_tf32(dist(rng)); }); + 
} +} + +template +static void compute_reference( + std::vector& C, + const std::vector& A, const std::vector& B, + size_t M, size_t N, size_t K) +{ + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + DstT sum = 0; + for (size_t k = 0; k < K; k++) { + sum = std::fma(static_cast(A[m * K + k]), + static_cast(B[k * N + n]), sum); + } + C[m * N + n] = sum; + } + } +} + +template +void check_results( + size_t M, + size_t N, + const std::vector& C, + const std::vector& C_ref) +{ + float err = 0.f; + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + auto index = m * N + n; + auto localErr = std::fabs(C[index] - C_ref[index]) / + std::max(std::fabs(C[index]), + std::fabs(C_ref[index])); + err = std::max(localErr, err); + if (localErr >= threshold) { + std::cerr << "Error at m = " << m << ", n = " << n + << ": (local error " << localErr << "): Wanted " + << C_ref[index] << ", got " << C[index] << std::endl; + return; + } + } + } +} + +static float hw_time(cl::Event& event) +{ + auto ns = event.getProfilingInfo() - + event.getProfilingInfo(); + return ns / 1e9f; +} + +static void tf32_naive( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, M, N, K).c_str()); fflush(stdout); + + cl::Kernel kernel{program, "tf32_naive"}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M}, cl::NullRange, nullptr, &event); + queue.finish(); + auto end = 
test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void tf32_dpas_rowmajor( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "tf32_dpas_rowmajor"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M/tM}, cl::NullRange, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... 
"); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void tf32_dpas_rowmajor_tiled( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "tf32_dpas_rowmajor_tiled"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + kernelName += "_" + std::to_string(MM); + kernelName += "x" + std::to_string(NN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else if (tM * MM > M) { + printf("M is too small.\n"); + } else if (tN * NN > N) { + printf("N is too small.\n"); + } else { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... 
"); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void tf32_dpas_blockread_rowmajor( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "tf32_dpas_blockread_rowmajor"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M/tM}, cl::NullRange, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... 
"); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void tf32_dpas_blockread_rowmajor_tiled( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "tf32_dpas_blockread_rowmajor_tiled"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + kernelName += "_" + std::to_string(MM); + kernelName += "x" + std::to_string(NN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else if (tM * MM > M) { + printf("M is too small.\n"); + } else if (tN * NN > N) { + printf("N is too small.\n"); + } else { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... 
"); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +int main(int argc, char** argv) +{ + int platformIndex = 0; + int deviceIndex = 0; + + std::string fileName("matrix_kernels_tf32.cl"); + std::string buildOptions; + size_t matrixSize = 512; + + size_t mask = ~0; + + { + popl::OptionParser op("Supported Options"); + op.add>("p", "platform", "Platform Index", platformIndex, &platformIndex); + op.add>("d", "device", "Device Index", deviceIndex, &deviceIndex); + op.add>("", "file", "Kernel File Name", fileName, &fileName); + op.add>("", "options", "Program Build Options", buildOptions, &buildOptions); + op.add>("m", "matrixsize", "Matrix Size", matrixSize, &matrixSize); + op.add>("i", "iterations", "Test Iterations", testIterations, &testIterations); + op.add("", "validate", "Validate Results", &validate); + op.add("", "identity", "Use Identity Data", &identityData); + op.add("", "fixed", "Use Fixed Data", &fixedData); + op.add("", "emulate", "Unconditionally Emulate dpas", &emulate); + op.add("", "wallclock", "Measure Wallclock Time", &wallclock); + op.add("", "skipinit", "Do Not Initialize Buffers", &skipinit); + op.add>("", "threshold", "Local Error Threshold", threshold, &threshold); + op.add, popl::Attribute::advanced>("", "mask", "Test Mask", mask, &mask); + bool printUsage = false; + try { + op.parse(argc, argv); + } catch (std::exception& e) { + fprintf(stderr, "Error: %s\n\n", e.what()); + printUsage = true; + } + + if (printUsage || !op.unknown_options().empty() || !op.non_option_args().empty()) { + fprintf(stderr, + "Usage: matrixexperimentstf32 [options]\n" + "%s", op.help().c_str()); + return -1; + } + } + + std::vector platforms; + cl::Platform::get(&platforms); + if (platformIndex >= platforms.size()) { + printf("Requested platform index is %d, but only %zu platforms were 
found.\n", + platformIndex, platforms.size()); + return -1; + } + + printf("Running on platform: %s\n", + platforms[platformIndex].getInfo().c_str() ); + + std::vector devices; + platforms[platformIndex].getDevices(CL_DEVICE_TYPE_ALL, &devices); + if (deviceIndex >= devices.size()) { + printf("Requested device index is %d, but only %zu devices were found.\n", + deviceIndex, devices.size()); + } + + cl::Device& device = devices[deviceIndex]; + printf("Running on device: %s (%uCUs, %uMHz)\n", + device.getInfo().c_str(), + device.getInfo(), + device.getInfo()); + printf("Running on drivers: %s\n", + device.getInfo().c_str()); + + auto minSubGroupSize = findMinSubGroupSize(device); + + bool emulate_tN16 = true; + if (!emulate && checkDeviceForExtension(device, "cl_intel_subgroup_matrix_multiply_accumulate")) { + printf("Found support for cl_intel_subgroup_matrix_multiply_accumulate, min sub-group size is: %zu\n", minSubGroupSize); + switch(minSubGroupSize) { + case 16: emulate_tN16 = false; break; + default: break; + } + } + + printf("NOTE: dpas is unconditionally emulated, currently!\n"); + emulate_tN16 = true; + + buildOptions += " -DEMULATE_tN16=" + std::to_string(emulate_tN16); + + printf("Config:\n"); + printf("\tTest Iterations: %d\n", testIterations); + printf("\tValidating data?: %s\n", validate ? "true" : "false"); + printf("\tFixed data?: %s\n", fixedData ? "true" : "false"); + printf("\tWallclock time?: %s\n", wallclock ? "true" : "false"); + printf("\tEmulate dpas for tN=16?: %s\n", emulate_tN16 ? "true" : "false"); + + cl::Context context{device}; + cl::CommandQueue queue{context, device, CL_QUEUE_PROFILING_ENABLE}; + + printf("Reading program source from file: %s\n", fileName.c_str() ); + std::string kernelString = readStringFromFile(fileName.c_str()); + + printf("Building program with build options: %s\n", + buildOptions.empty() ? 
"(none)" : buildOptions.c_str() ); + cl::Program program{ context, kernelString }; + program.build(buildOptions.c_str()); + for( auto& device : program.getInfo() ) + { + printf("Program build log for device %s:\n", + device.getInfo().c_str() ); + printf("%s\n", + program.getBuildInfo(device).c_str() ); + } + + const auto M = matrixSize; + const auto N = matrixSize; + const auto K = matrixSize; + + std::vector A_vec(M * K); + std::vector B_vec(K * N); + + std::vector C_ref(M * N); + + printf("Initializing source matrices...\n"); + fill_matrix(A_vec, M, K); + fill_matrix(B_vec, K, N); + + if (validate) { + printf("Computing reference...\n"); + compute_reference(C_ref, A_vec, B_vec, M, N, K); + } + + printf("Creating source buffers...\n"); + cl::Buffer A{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, A_vec.size() * sizeof(A_vec[0]), A_vec.data()}; + cl::Buffer B{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, B_vec.size() * sizeof(B_vec[0]), B_vec.data()}; + cl::Buffer C{context, CL_MEM_WRITE_ONLY, C_ref.size() * sizeof(C_ref[0])}; + + printf("Running tests...\n"); + + if (mask & 0x1) { + tf32_naive(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x20) { + tf32_dpas_rowmajor<1, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_rowmajor<2, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_rowmajor<4, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_rowmajor<8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x40) { + tf32_dpas_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x200) { + tf32_dpas_blockread_rowmajor<1, 16>(context, program, queue, C, 
A, B, M, N, K, C_ref); + tf32_dpas_blockread_rowmajor<2, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_blockread_rowmajor<4, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_blockread_rowmajor<8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x400) { + tf32_dpas_blockread_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_blockread_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_blockread_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_blockread_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + printf("Done.\n"); + + return 0; +} diff --git a/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl b/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl new file mode 100644 index 00000000..deab63df --- /dev/null +++ b/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl @@ -0,0 +1,352 @@ +#if defined(cl_intel_subgroups) + +inline int compute_m(const int num_sgs, const int tM, const int MM) +{ + const int m_start = get_group_id(1) * num_sgs; + const int m_index = num_sgs > 1 ? 
m_start + get_sub_group_id() : m_start;
+    return m_index * tM * MM;
+}
+
+// Emulated dpas:
+//
+// Every overload computes acc + a * b for one tile with K = 8, using
+// sub_group_broadcast to fetch the A values held by other work-items.
+// b holds one column of the B tile per work-item (b.s0 .. b.s7 are the
+// eight K values).
+
+// M = 1: the eight A values live in work-items 0..7.
+__attribute__((overloadable))
+float emu_sub_group_tf32_tf32_matrix_mad_k8(float a, float8 b, float acc)
+{
+    float res = acc;
+
+    res = fma(sub_group_broadcast(a, 0), b.s0, res);
+    res = fma(sub_group_broadcast(a, 1), b.s1, res);
+    res = fma(sub_group_broadcast(a, 2), b.s2, res);
+    res = fma(sub_group_broadcast(a, 3), b.s3, res);
+    res = fma(sub_group_broadcast(a, 4), b.s4, res);
+    res = fma(sub_group_broadcast(a, 5), b.s5, res);
+    res = fma(sub_group_broadcast(a, 6), b.s6, res);
+    res = fma(sub_group_broadcast(a, 7), b.s7, res);
+
+    return res;
+}
+
+// M = 2: row 0 of A lives in work-items 0..7, row 1 in work-items 8..15.
+__attribute__((overloadable))
+float2 emu_sub_group_tf32_tf32_matrix_mad_k8(float a, float8 b, float2 acc)
+{
+    float2 res = acc;
+
+    res.s0 = fma(sub_group_broadcast(a, 0), b.s0, res.s0);
+    res.s0 = fma(sub_group_broadcast(a, 1), b.s1, res.s0);
+    res.s0 = fma(sub_group_broadcast(a, 2), b.s2, res.s0);
+    res.s0 = fma(sub_group_broadcast(a, 3), b.s3, res.s0);
+    res.s0 = fma(sub_group_broadcast(a, 4), b.s4, res.s0);
+    res.s0 = fma(sub_group_broadcast(a, 5), b.s5, res.s0);
+    res.s0 = fma(sub_group_broadcast(a, 6), b.s6, res.s0);
+    res.s0 = fma(sub_group_broadcast(a, 7), b.s7, res.s0);
+
+    res.s1 = fma(sub_group_broadcast(a, 8), b.s0, res.s1);
+    res.s1 = fma(sub_group_broadcast(a, 9), b.s1, res.s1);
+    res.s1 = fma(sub_group_broadcast(a, 10), b.s2, res.s1);
+    res.s1 = fma(sub_group_broadcast(a, 11), b.s3, res.s1);
+    res.s1 = fma(sub_group_broadcast(a, 12), b.s4, res.s1);
+    res.s1 = fma(sub_group_broadcast(a, 13), b.s5, res.s1);
+    res.s1 = fma(sub_group_broadcast(a, 14), b.s6, res.s1);
+    res.s1 = fma(sub_group_broadcast(a, 15), b.s7, res.s1);
+
+    return res;
+}
+
+// M = 4: two M = 2 steps, one per component of a.
+__attribute__((overloadable))
+float4 emu_sub_group_tf32_tf32_matrix_mad_k8(float2 a, float8 b, float4 acc)
+{
+    float4 res;
+
+    res.s01 = emu_sub_group_tf32_tf32_matrix_mad_k8(a.s0, b, acc.s01);
+    res.s23 = emu_sub_group_tf32_tf32_matrix_mad_k8(a.s1, b, acc.s23);
+
+    return res;
+}
+
+// M = 8: four M = 2 steps, one per component of a.
+__attribute__((overloadable))
+float8 emu_sub_group_tf32_tf32_matrix_mad_k8(float4 a, float8 b, float8 acc)
+{
+    float8 res;
+
+    res.s01 = emu_sub_group_tf32_tf32_matrix_mad_k8(a.s0, b, acc.s01);
+    res.s23 = emu_sub_group_tf32_tf32_matrix_mad_k8(a.s1, b, acc.s23);
+    res.s45 = emu_sub_group_tf32_tf32_matrix_mad_k8(a.s2, b, acc.s45);
+    res.s67 = emu_sub_group_tf32_tf32_matrix_mad_k8(a.s3, b, acc.s67);
+
+    return res;
+}
+
+// M rows x K columns
+//
+// Loads a 1x8 tile of A; stride is the row pitch of A in elements.
+// Work-item i reads column colStart + (i % 8) of row rowStart.
+//
+// A one-row tile has only eight values, so work-items 8..15 re-read the same
+// row as work-items 0..7.  They must not step to row rowStart + 1 the way
+// the two-row loader does: that row is outside the tile, and outside the
+// buffer when rowStart is the last row of A.  The emulated M = 1 multiply
+// only broadcasts from work-items 0..7, so the duplicated values are unused.
+float load_a_rowmajor_d32_m1_k8_sg16(global float* A, int rowStart, int colStart, int stride)
+{
+    float ret;
+
+    uint offset = rowStart * stride + colStart;
+    offset += (get_sub_group_local_id() % 8);
+
+    ret = A[offset];
+
+    return ret;
+}
+
+// M rows x K columns
+//
+// Loads a 2x8 tile of A: work-items 0..7 read row rowStart and work-items
+// 8..15 read row rowStart + 1, each at column colStart + (i % 8).
+float load_a_rowmajor_d32_m2_k8_sg16(global float* A, int rowStart, int colStart, int stride)
+{
+    float ret;
+
+    uint offset = rowStart * stride + colStart;
+    offset += (get_sub_group_local_id() < 8) ? 0 : stride;
+    offset += (get_sub_group_local_id() % 8);
+
+    ret = A[offset];
+
+    return ret;
+}
+
+// M rows x K columns
+//
+// Loads a 4x8 tile of A as two 2x8 slices: component s0 covers rows
+// rowStart and rowStart + 1, component s1 the next two rows.
+float2 load_a_rowmajor_d32_m4_k8_sg16(global float* A, int rowStart, int colStart, int stride)
+{
+    float2 ret;
+
+    uint offset = rowStart * stride + colStart;
+    offset += (get_sub_group_local_id() < 8) ? 0 : stride;
+    offset += (get_sub_group_local_id() % 8);
+
+    ret.s0 = A[offset]; offset += stride * 2;
+    ret.s1 = A[offset]; offset += stride * 2;
+
+    return ret;
+}
+
+// M rows x K columns
+//
+// Loads an 8x8 tile of A as four 2x8 slices, two rows per component.
+float4 load_a_rowmajor_d32_m8_k8_sg16(global float* A, int rowStart, int colStart, int stride)
+{
+    float4 ret;
+
+    uint offset = rowStart * stride + colStart;
+    offset += (get_sub_group_local_id() < 8) ? 0 : stride;
+    offset += (get_sub_group_local_id() % 8);
+
+    ret.s0 = A[offset]; offset += stride * 2;
+    ret.s1 = A[offset]; offset += stride * 2;
+    ret.s2 = A[offset]; offset += stride * 2;
+    ret.s3 = A[offset]; offset += stride * 2;
+
+    return ret;
+}
+
+// M rows x K columns x V tiles (in the M and K dimensions)
+//
+// Issues a prefetch only when PREFETCH_DEFAULT is defined; otherwise this
+// function is empty.
+void prefetch_a_rowmajor_d16_m8v2_k8v2_sg16(global float* A, int rowStart, int colStart, int stride)
+{
+#if defined(PREFETCH_DEFAULT)
+    uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride;
+    prefetch(A + offset, 1);
+#endif // defined(PREFETCH_DEFAULT)
+}
+
+// K rows x N columns:
+// Each work-item loads K values.
+// Stride is in units of elements.
+float8 load_b_rowmajor_d32_k8_nx(global float* B, int rowStart, int colStart, int stride)
+{
+    float8 ret;
+
+    uint offset = rowStart * stride + colStart;
+
+    // One sub-group block read per row of the tile.
+    ret.s0 = as_float(intel_sub_group_block_read((global uint*)B + offset)); offset += stride;
+    ret.s1 = as_float(intel_sub_group_block_read((global uint*)B + offset)); offset += stride;
+    ret.s2 = as_float(intel_sub_group_block_read((global uint*)B + offset)); offset += stride;
+    ret.s3 = as_float(intel_sub_group_block_read((global uint*)B + offset)); offset += stride;
+    ret.s4 = as_float(intel_sub_group_block_read((global uint*)B + offset)); offset += stride;
+    ret.s5 = as_float(intel_sub_group_block_read((global uint*)B + offset)); offset += stride;
+    ret.s6 = as_float(intel_sub_group_block_read((global uint*)B + offset)); offset += stride;
+    ret.s7 = as_float(intel_sub_group_block_read((global uint*)B + offset)); offset += stride;
+
+    return ret;
+}
+
+// K rows x N columns x V tiles (in the K and N dimensions)
+//
+// Issues a prefetch only when PREFETCH_DEFAULT is defined; otherwise this
+// function is empty.
+void prefetch_b_rowmajor_d32_k8v2_n16v2_sg16(global float* B, int rowStart, int colStart, int stride)
+{
+#if defined(PREFETCH_DEFAULT)
+    uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride;
+    prefetch(B + offset, 1);
+#endif // defined(PREFETCH_DEFAULT)
+}
+
+// Stores one row of C with a sub-group block write; stride is in elements.
+void store_c_rowmajor_fp32_m1_nx(global float* C, float v, int rowStart, int colStart, int stride)
+{
+    global uint* C_ui = (global uint*)C;
+    uint v_ui = as_uint(v);
+
+    uint offset = rowStart * stride + colStart;
+
+    intel_sub_group_block_write(C_ui + offset, v_ui); offset += stride;
+}
+
+// Stores two rows of C, one sub-group block write per row.
+void store_c_rowmajor_fp32_m2_nx(global float* C, float2 v, int rowStart, int colStart, int stride)
+{
+    global uint* C_ui = (global uint*)C;
+    uint2 v_ui = as_uint2(v);
+
+    uint offset = rowStart * stride + colStart;
+
+    intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride;
+    intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride;
+}
+
+// Stores four rows of C, one sub-group block write per row.
+void store_c_rowmajor_fp32_m4_nx(global float* C, float4 v, int rowStart, int colStart, int stride)
+{
+    global uint* C_ui = (global uint*)C;
+    uint4 v_ui = as_uint4(v);
+
+    uint offset = rowStart * stride + colStart;
+
+    intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride;
+    intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride;
+    intel_sub_group_block_write(C_ui + offset, v_ui.s2); offset += stride;
+    intel_sub_group_block_write(C_ui + offset, v_ui.s3); offset += stride;
+}
+
+// Stores eight rows of C, one sub-group block write per row.
+void store_c_rowmajor_fp32_m8_nx(global float* C, float8 v, int rowStart, int colStart, int stride)
+{
+    global uint* C_ui = (global uint*)C;
+    uint8 v_ui = as_uint8(v);
+
+    uint offset = rowStart * stride + colStart;
+
+    intel_sub_group_block_write(C_ui + offset, v_ui.s0); offset += stride;
+    intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride;
+    intel_sub_group_block_write(C_ui + offset, v_ui.s2); offset += stride;
+    intel_sub_group_block_write(C_ui + offset, v_ui.s3); offset += stride;
+    intel_sub_group_block_write(C_ui + offset, v_ui.s4); offset += stride;
+    intel_sub_group_block_write(C_ui + offset, v_ui.s5); offset += stride;
+    intel_sub_group_block_write(C_ui + offset, v_ui.s6); offset += stride;
+    intel_sub_group_block_write(C_ui + offset, v_ui.s7); offset += stride;
+}
+
+#endif // defined(cl_intel_subgroups)
+
+#ifdef cl_intel_subgroup_extended_block_read
+
+// Note for 2D block reads:
+//  - the tile width and height are encoded into the function name.
+//  - base_address is the byte address.  Must be 64B aligned.
+//  - width is the width of the entire matrix, in bytes.  Must be >= 64B.  Must be 4B aligned.
+//  - height is the height of the entire matrix, or equivalently the number of rows.
+//  - pitch is the number of bytes between rows of the entire matrix.  Must be >= 64B.  Must be a multiple of 8 bytes.
+//  - coord is the number of elements (x coord) and row (y coord) to read from.  X coord must be a multiple of 4 for 1B data and of 2 for 2B data.
+
+// Built-in functions are:
+
+// #ifdef cl_intel_subgroup_extended_block_read
+// ushort2 intel_subgroup_block_read_u8_m1k32v2(__global void *base_address, int width, int height, int pitch, int2 coord);
+// ushort4 intel_subgroup_block_read_u8_m2k32v2(__global void *base_address, int width, int height, int pitch, int2 coord);
+// ushort8 intel_subgroup_block_read_u8_m4k32v2(__global void *base_address, int width, int height, int pitch, int2 coord);
+// ushort16 intel_subgroup_block_read_u8_m8k32v2(__global void *base_address, int width, int height, int pitch, int2 coord);
+// ushort2 intel_subgroup_block_read_u16_m1k16v2(__global void *base_address, int width, int height, int pitch, int2 coord);
+// ushort4 intel_subgroup_block_read_u16_m2k16v2(__global void *base_address, int width, int height, int pitch, int2 coord);
+// ushort8 intel_subgroup_block_read_u16_m4k16v2(__global void *base_address, int width, int height, int pitch, int2 coord);
+// ushort16 intel_subgroup_block_read_u16_m8k16v2(__global void *base_address, int width, int height, int pitch, int2 coord);
+// uint8 intel_subgroup_block_read_transform_u8_k32(__global void *base_address, int width, int height, int pitch, int2 coord);
+// uint8 intel_subgroup_block_read_transform_u16_k16(__global void *base_address, int width, int height, int pitch, int2 coord);
+// uint8 intel_subgroup_block_read_transpose_u32_k8(__global void *base_address, int width, int height, int pitch, int2 coord);
+// ulong4 intel_subgroup_block_read_transpose_u64_k4(__global void *base_address, int width, int height, int pitch, int2 coord);
+// #endif //defined(cl_intel_subgroup_extended_block_read)
+
+
+// For intrinsics, the pattern is:
+//  - prefix: __builtin_IB_subgroup_block_read_flat or __builtin_IB_subgroup_block_write_flat
+//  - operation (optional): _transpose or _transform
+//  - for no transpose or transform:
+//      - type / element size: _u8 or _u16 or _u32 or _u64
+//      - number of tile rows: _m32 or _m16 or _m8 or _m4 or _m2 or _m1
+//      - tile width: _k64 or _k32 or _k16 or _k8
+//      - number of tiles: _v2 or _v1
+//  - for transpose:
+//      - type / element size: _u64 or _u32
+//      - number of tile rows: subgroup size (16)
+//      - tile width: _k4 (for _u64) or _k8 (for _u32)
+//      - number of tiles: 1
+//  - for transform:
+//      - type / element size: _u16 or _u8
+//      - number of tile rows: _k32 (for _u8) or _k16 (for _u16)
+//      - tile width: subgroup size (16)
+//      - number of tiles: 1
+
+// Define additional "non-vector" block reads and writes. 
These are supported by the hardware but are not in the headers:
+
+// Forward declarations of the compiler intrinsics.  As the parameter names
+// say, width, height and pitch are passed as (value - 1); the wrappers below
+// take the plain values and do the subtraction.
+
+// 8-element-wide tiles:
+uint __builtin_IB_subgroup_block_read_flat_u32_m1k8v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
+uint2 __builtin_IB_subgroup_block_read_flat_u32_m2k8v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
+uint4 __builtin_IB_subgroup_block_read_flat_u32_m4k8v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
+uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k8v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
+
+// 16-element-wide tiles:
+uint __builtin_IB_subgroup_block_read_flat_u32_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
+uint2 __builtin_IB_subgroup_block_read_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
+uint4 __builtin_IB_subgroup_block_read_flat_u32_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
+uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
+
+// Read wrappers for 8-wide tiles.  Except for m1k8, each returns only the
+// low half (.lo) of what the intrinsic is declared to return.
+// NOTE(review): presumably an 8-wide tile fills only half of a 16-wide
+// sub-group, making the upper half redundant -- confirm against the
+// intrinsic's actual data layout.
+uint intel_subgroup_block_read_u32_m1k8(const __global void *base_address, int width, int height, int pitch, int2 coord)
+{
+    return __builtin_IB_subgroup_block_read_flat_u32_m1k8v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
+}
+uint intel_subgroup_block_read_u32_m2k8(const __global void *base_address, int width, int height, int pitch, int2 coord)
+{
+    return __builtin_IB_subgroup_block_read_flat_u32_m2k8v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord).lo;
+}
+uint2 intel_subgroup_block_read_u32_m4k8(const __global void *base_address, int width, int height, int pitch, int2 coord)
+{
+    return __builtin_IB_subgroup_block_read_flat_u32_m4k8v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord).lo;
+}
+uint4 intel_subgroup_block_read_u32_m8k8(const __global void *base_address, int width, int height, int pitch, int2 coord)
+{
+    return __builtin_IB_subgroup_block_read_flat_u32_m8k8v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord).lo;
+}
+
+// Read wrappers for 16-wide tiles: the intrinsic's result is returned as is.
+uint intel_subgroup_block_read_u32_m1k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
+{
+    return __builtin_IB_subgroup_block_read_flat_u32_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
+}
+uint2 intel_subgroup_block_read_u32_m2k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
+{
+    return __builtin_IB_subgroup_block_read_flat_u32_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
+}
+uint4 intel_subgroup_block_read_u32_m4k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
+{
+    return __builtin_IB_subgroup_block_read_flat_u32_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
+}
+uint8 intel_subgroup_block_read_u32_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
+{
+    return __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
+}
+
+// Two-tile (v2) read of 8-wide tiles.
+uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k8v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
+
+uint8 intel_subgroup_block_read_u32_m8k8v2(const __global void* base_address, int width, int height, int pitch, int2 coord)
+{
+    return __builtin_IB_subgroup_block_read_flat_u32_m8k8v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
+}
+
+// Write intrinsics and wrappers for 16-wide tiles; same (value - 1)
+// convention as the reads.
+void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint data);
+void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint2 data);
+void __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint4 data);
+void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint8 data);
+
+void intel_subgroup_block_write_u32_m1k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint data)
+{
+    __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data);
+}
+void intel_subgroup_block_write_u32_m2k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint2 data)
+{
+    __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data);
+}
+void intel_subgroup_block_write_u32_m4k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint4 data)
+{
+    __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data);
+}
+void intel_subgroup_block_write_u32_m8k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint8 data)
+{
+    __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data);
+}
+
+#endif // cl_intel_subgroup_extended_block_read
diff --git a/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl b/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl
new file mode 100644
index 00000000..aab546af
--- /dev/null
+++ b/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl
@@ -0,0 +1,239 @@
+#if !defined(tK)
+#error "tK is undefined! This should be defined as the K dimension of the matrix tiles, which is dependent on the elemement type, likely 16 or 32."
+#endif
+
+#if !defined(MM)
+#error "MM is undefined! 
This should be defined as the number of matrix tiles in the M dimension."
+#endif
+
+#if !defined(NN)
+#error "NN is undefined! This should be defined as the number of matrix tiles in the N dimension."
+#endif
+
+// KK is the number of K tiles processed per loop iteration; defaults to one.
+#if !defined(KK)
+#define KK 1
+#endif
+
+// split_barrier_arrive() / split_barrier_wait() expand to the Intel split
+// work-group barrier when the extension is available and NO_SPLIT_BARRIERS
+// is not defined, and to nothing otherwise.
+#if !defined(cl_intel_split_work_group_barrier) || defined(NO_SPLIT_BARRIERS)
+#if !defined(cl_intel_split_work_group_barrier)
+#warning "Unexpected: cl_intel_split_work_group_barrier is not supported?"
+#endif
+#define split_barrier_arrive()
+#define split_barrier_wait()
+#else
+#define split_barrier_arrive() intel_work_group_barrier_arrive(0)
+#define split_barrier_wait() intel_work_group_barrier_wait(0)
+#endif
+
+// Builds the kernel name PREFIX_m<tM>_n<tN>_<MM>x<NN>.  The two-level macro
+// makes MM and NN expand to their values before token pasting.
+#define MM_KERNEL_NAMEX(PREFIX, tM, tN, MM, NN) PREFIX ## _m ## tM ## _n ## tN ## _ ## MM ## x ## NN
+#define MM_KERNEL_NAME(PREFIX, tM, tN, MM, NN) MM_KERNEL_NAMEX(PREFIX, tM, tN, MM, NN)
+
+#if !defined(SGS_PER_WG)
+// Launch four subgroups per work-group, to maximize cache reuse.
+#define SGS_PER_WG 4
+#endif
+
+// Tiled matrix multiply using the hand-written load and store helpers.
+// Each sub-group accumulates an MM x NN grid of 8x16 output tiles, walking
+// K in steps of tK * KK, and stores the grid to C at the end.
+__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1)))
+kernel void MM_KERNEL_NAME(tf32_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global float* A, global float* B, int K)
+{
+    const int tM = 8;
+    const int tN = 16;
+    const int N = get_global_size(0) * NN;
+    const int m = compute_m(SGS_PER_WG, tM, MM);
+    const int n = get_group_id(0) * tN * NN;
+
+    // Disabled: the prefetch helpers named here are not defined in this file.
+#if 0
+    // Initial prefetch:
+    const int init_k = 0;
+    for (int kk = 0; kk < KK; kk+=2) {
+        for (int mm = 0; mm < MM; mm+=2) {
+            prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, init_k + kk * tK, K);
+        }
+    }
+    for (int kk = 0; kk < KK; kk++) {
+        for (int nn = 0; nn < NN; nn+=2) {
+            prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, init_k + kk * tK, n + nn * tN, N);
+        }
+    }
+#endif
+
+    // One float8 accumulator (8 rows, one column per work-item) per tile.
+    float8 sum[MM][NN];
+    for (int mm = 0; mm < MM; mm++) {
+        for (int nn = 0; nn < NN; nn++) {
+            sum[mm][nn] = 0;
+        }
+    }
+
+    split_barrier_arrive();
+
+    for (int k = 0; k < K; k += tK * KK) {
+#if 0
+        // Next prefetch:
+        const int next_k = k + tK * KK;
+        for (int kk = 0; kk < KK; kk+=2) {
+            for (int mm = 0; mm < MM; mm+=2) {
+                prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, next_k + kk * tK, K);
+            }
+        }
+        for (int kk = 0; kk < KK; kk++) {
+            for (int nn = 0; nn < NN; nn+=2) {
+                prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, next_k + kk * tK, n + nn * tN, N);
+            }
+        }
+#endif
+
+        float4 aData[KK][MM];
+#if 0
+        if (KK % 2 == 0) {
+            for (int kk = 0; kk < KK; kk+=2) {
+                for (int mm = 0; mm < MM; mm++) {
+                    short16 aTemp = load_a_rowmajor_d32_m8_k8v2_sg16(A, m + mm * tM, k + kk * tK, K);
+                    aData[kk + 0][mm] = aTemp.lo;
+                    aData[kk + 1][mm] = aTemp.hi;
+                }
+            }
+        } else {
+#endif
+            for (int kk = 0; kk < KK; kk++) {
+                for (int mm = 0; mm < MM; mm++) {
+                    aData[kk][mm] = load_a_rowmajor_d32_m8_k8_sg16(A, m + mm * tM, k + kk * tK, K);
+                }
+            }
+#if 0
+        }
+#endif
+
+        float8 bData[KK][NN];
+        for (int kk = 0; kk < KK; kk++) {
+            for (int nn = 0; nn < NN; nn++) {
+                bData[kk][nn] = load_b_rowmajor_d32_k8_nx(B, k + kk * tK, n + nn * tN, N);
+            }
+        }
+
+        for (int kk = 0; kk < KK; kk++) {
+            for (int nn = 0; nn < NN; nn++) {
+                for (int mm = 0; mm < MM; mm++) {
+                    sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[kk][nn], sum[mm][nn]);
+                }
+            }
+        }
+
+        // Complete the barrier started before this iteration, then start
+        // the one for the next iteration.
+        split_barrier_wait();
+        split_barrier_arrive();
+    }
+
+    split_barrier_wait();
+
+    for (int mm = 0; mm < MM; mm++) {
+        for (int nn = 0; nn < NN; nn++) {
+            store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N);
+        }
+    }
+}
+
+#ifdef cl_intel_subgroup_extended_block_read
+
+// Same tiling as the kernel above, but A, B and C are accessed with the 2D
+// block read / write wrappers, which take the matrix width and pitch in
+// bytes and the height in rows.
+__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1)))
+kernel void MM_KERNEL_NAME(tf32_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global float* A, global float* B, int K)
+{
+    const int tM = 8;
+    const int tN = 16;
+    const int M = get_global_size(1) * tM * MM;
+    const int N = get_global_size(0) * NN;
+    const int m = compute_m(SGS_PER_WG, tM, MM);
+    const int n = get_group_id(0) * tN * NN;
+
+    // Disabled: the prefetch helpers named here are not defined in this file.
+#if 0
+    // Initial prefetch:
+    const int init_k = 0;
+    for (int kk = 0; kk < KK; kk+=2) {
+        for (int mm = 0; mm < MM; mm+=2) {
+            prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, init_k + kk * tK, K);
+        }
+    }
+    for (int kk = 0; kk < KK; kk++) {
+        for (int nn = 0; nn < NN; nn+=2) {
+            prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, init_k + kk * tK, n + nn * tN, N);
+        }
+    }
+#endif
+
+    float8 sum[MM][NN];
+    for (int mm = 0; mm < MM; mm++) {
+        for (int nn = 0; nn < NN; nn++) {
+            sum[mm][nn] = 0;
+        }
+    }
+
+    split_barrier_arrive();
+
+    for (int k = 0; k < K; k += tK * KK) {
+#if 0
+        // Next prefetch:
+        const int next_k = k + tK * KK;
+        for (int kk = 0; kk < KK; kk+=2) {
+            for (int mm = 0; mm < MM; mm+=2) {
+                prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, next_k + kk * tK, K);
+            }
+        }
+        for (int kk = 0; kk < KK; kk++) {
+            for (int nn = 0; nn < NN; nn+=2) {
+                prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, next_k + kk * tK, n + nn * tN, N);
+            }
+        }
+#endif
+
+        float4 aData[KK][MM];
+#if 0
+        if (KK % 2 == 0) {
+            for (int kk = 0; kk < KK; kk+=2) {
+                for (int mm = 0; mm < MM; mm++) {
+                    short16 aTemp = as_short16(intel_subgroup_block_read_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)));
+                    aData[kk + 0][mm] = aTemp.lo;
+                    aData[kk + 1][mm] = aTemp.hi;
+                }
+            }
+        } else {
+#endif
+            for (int kk = 0; kk < KK; kk++) {
+                for (int mm = 0; mm < MM; mm++) {
+                    //aData[kk][mm] = as_float8(intel_subgroup_block_read_u32_m8k16(A, K * sizeof(float), M, K * sizeof(float), (int2)(k + kk * tK, m + mm * tM))); // works
+                    //aData[kk][mm] = as_float8(intel_subgroup_block_read_u32_m8k8(A, K * sizeof(float), M, K * sizeof(float), (int2)(k + kk * tK, m + mm * tM))); // doesn't work
+                    aData[kk][mm] = as_float4(intel_subgroup_block_read_u32_m8k8(A, K * sizeof(float), M, K * sizeof(float), (int2)(k + kk * tK, m + mm * tM)));
+                    //float8 good = load_a_rowmajor_d32_m8_k8v2_sg16(A, m + mm * tM, k + kk * tK, K);
+                    //printf("sglid = %u: test = %v8f, good= %v8f\n", get_sub_group_local_id(), aData[kk][mm], good);
+                }
+            }
+#if 0
+        }
+#endif
+
+        float8 bData[KK][NN];
+        for (int kk = 0; kk < KK; kk++) {
+            for (int nn = 0; nn < NN; nn++) {
+                bData[kk][nn] = as_float8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(float), K, N * sizeof(float), (int2)(n + nn * tN, k + kk * tK)));
+            }
+        }
+
+        for (int kk = 0; kk < KK; kk++) {
+            for (int nn = 0; nn < NN; nn++) {
+                for (int mm = 0; mm < MM; mm++) {
+                    sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[kk][nn], sum[mm][nn]);
+                }
+            }
+        }
+
+        // Complete the barrier started before this iteration, then start
+        // the one for the next iteration.
+        split_barrier_wait();
+        split_barrier_arrive();
+    }
+
+    split_barrier_wait();
+
+    for (int mm = 0; mm < MM; mm++) {
+        for (int nn = 0; nn < NN; nn++) {
+            intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[mm][nn]));
+        }
+    }
+}
+
+#endif // cl_intel_subgroup_extended_block_read
diff --git a/samples/99_matrixexperimentstf32/matrix_kernels_tf32.cl b/samples/99_matrixexperimentstf32/matrix_kernels_tf32.cl
new file mode 100644
index 00000000..6ebdb41f
--- /dev/null
+++ b/samples/99_matrixexperimentstf32/matrix_kernels_tf32.cl
@@ -0,0 +1,218 @@
+#include "matrix_helpers_tf32.cl"
+
+// mat_mul_sg16 is the matrix multiply-accumulate used by every kernel: the
+// emulated helper when EMULATE_tN16 is non-zero, the built-in otherwise.
+#if EMULATE_tN16
+#define mat_mul_sg16 emu_sub_group_tf32_tf32_matrix_mad_k8
+#else
+#define mat_mul_sg16 intel_sub_group_tf32_tf32_matrix_mad_k8
+#endif
+
+// Reference kernel: one work-item per element of C.  Work-item (n, m)
+// computes the dot product of row m of A with column n of B.  N is taken
+// from the global size in dimension 0.
+kernel void tf32_naive(global float* C, global float* A, global float* B, int K)
+{
+    const int N = get_global_size(0);
+    const int m = get_global_id(1);
+    const int n = get_global_id(0);
+
+    float sum = 0;
+    for (int k = 0; k < K; k++) {
+        sum = fma(A[m * K + k], B[k * N + n], sum);
+    }
+
+    C[m * N + n] = sum;
+}
+
+// For all tf32 kernels tK == 8:
+#define tK 8
+
+#if defined(cl_intel_subgroups) && defined(cl_intel_required_subgroup_size)
+
+// rowmajor kernels:
+
+__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
+kernel void tf32_dpas_rowmajor_m1_n16(global float* C, global float* A, global float* B, 
int K)
+{
+    // One sub-group computes a 1x16 tile of C, stepping through K in chunks
+    // of tK and accumulating with mat_mul_sg16.
+    const int tM = 1;
+    const int tN = 16;
+    const int N = get_global_size(0);
+    const int m = get_group_id(1) * tM;
+    const int n = get_group_id(0) * get_local_size(0);
+
+    float sum = 0;
+    for (int k = 0; k < K; k += tK) {
+        float aData = load_a_rowmajor_d32_m1_k8_sg16(A, m, k, K);
+        float8 bData = load_b_rowmajor_d32_k8_nx(B, k, n, N);
+        sum = mat_mul_sg16(aData, bData, sum);
+    }
+
+    store_c_rowmajor_fp32_m1_nx(C, sum, m, n, N);
+}
+
+// As above for a 2x16 tile of C.
+__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
+kernel void tf32_dpas_rowmajor_m2_n16(global float* C, global float* A, global float* B, int K)
+{
+    const int tM = 2;
+    const int tN = 16;
+    const int N = get_global_size(0);
+    const int m = get_group_id(1) * tM;
+    const int n = get_group_id(0) * get_local_size(0);
+
+    float2 sum = 0;
+    for (int k = 0; k < K; k += tK) {
+        float aData = load_a_rowmajor_d32_m2_k8_sg16(A, m, k, K);
+        float8 bData = load_b_rowmajor_d32_k8_nx(B, k, n, N);
+        sum = mat_mul_sg16(aData, bData, sum);
+    }
+
+    store_c_rowmajor_fp32_m2_nx(C, sum, m, n, N);
+}
+
+// As above for a 4x16 tile of C.
+__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
+kernel void tf32_dpas_rowmajor_m4_n16(global float* C, global float* A, global float* B, int K)
+{
+    const int tM = 4;
+    const int tN = 16;
+    const int N = get_global_size(0);
+    const int m = get_group_id(1) * tM;
+    const int n = get_group_id(0) * get_local_size(0);
+
+    float4 sum = 0;
+    for (int k = 0; k < K; k += tK) {
+        float2 aData = load_a_rowmajor_d32_m4_k8_sg16(A, m, k, K);
+        float8 bData = load_b_rowmajor_d32_k8_nx(B, k, n, N);
+        sum = mat_mul_sg16(aData, bData, sum);
+    }
+
+    store_c_rowmajor_fp32_m4_nx(C, sum, m, n, N);
+}
+
+// As above for an 8x16 tile of C.
+__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
+kernel void tf32_dpas_rowmajor_m8_n16(global float* C, global float* A, global float* B, int K)
+{
+    const int tM = 8;
+    const int tN = 16;
+    const int N = get_global_size(0);
+    const int m = get_group_id(1) * tM;
+    const int n = get_group_id(0) * get_local_size(0);
+
+    float8 sum = 0;
+    for (int k = 0; k < K; k += tK) {
+        float4 aData = load_a_rowmajor_d32_m8_k8_sg16(A, m, k, K);
+        float8 bData = load_b_rowmajor_d32_k8_nx(B, k, n, N);
+        sum = mat_mul_sg16(aData, bData, sum);
+    }
+
+    store_c_rowmajor_fp32_m8_nx(C, sum, m, n, N);
+}
+
+#ifdef cl_intel_subgroup_extended_block_read
+
+// Block-read variants: the same tile shapes as the kernels above, but A, B
+// and C are accessed through the 2D block read / write wrappers.  Those take
+// the matrix width and pitch in bytes and the height in rows, so M is
+// reconstructed from the global size in dimension 1.
+
+__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
+kernel void tf32_dpas_blockread_rowmajor_m1_n16(global float* C, global float* A, global float* B, int K)
+{
+    const int tM = 1;
+    const int tN = 16;
+    const int M = get_global_size(1);
+    const int N = get_global_size(0);
+    const int m = get_group_id(1) * tM;
+    const int n = get_group_id(0) * tN;
+
+    float sum = 0;
+    for (int k = 0; k < K; k += tK) {
+        float aData = as_float(intel_subgroup_block_read_u32_m1k8(A, K * sizeof(float), M, K * sizeof(float), (int2)(k, m)));
+        float8 bData = as_float8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(float), K, N * sizeof(float), (int2)(n, k)));
+        sum = mat_mul_sg16(aData, bData, sum);
+    }
+
+    intel_subgroup_block_write_u32_m1k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum));
+}
+
+__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
+kernel void tf32_dpas_blockread_rowmajor_m2_n16(global float* C, global float* A, global float* B, int K)
+{
+    const int tM = 2;
+    const int tN = 16;
+    const int M = get_global_size(1) * tM;
+    const int N = get_global_size(0);
+    const int m = get_group_id(1) * tM;
+    const int n = get_group_id(0) * tN;
+
+    float2 sum = 0;
+    for (int k = 0; k < K; k += tK) {
+        float aData = as_float(intel_subgroup_block_read_u32_m2k8(A, K * sizeof(float), M, K * sizeof(float), (int2)(k, m)));
+        float8 bData = as_float8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(float), K, N * sizeof(float), (int2)(n, k)));
+        sum = mat_mul_sg16(aData, bData, sum);
+    }
+
+    intel_subgroup_block_write_u32_m2k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum));
+}
+
+__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
+kernel void tf32_dpas_blockread_rowmajor_m4_n16(global float* C, global float* A, global float* B, int K)
+{
+    const int tM = 4;
+    const int tN = 16;
+    const int M = get_global_size(1) * tM;
+    const int N = get_global_size(0);
+    const int m = get_group_id(1) * tM;
+    const int n = get_group_id(0) * tN;
+
+    float4 sum = 0;
+    for (int k = 0; k < K; k += tK) {
+        float2 aData = as_float2(intel_subgroup_block_read_u32_m4k8(A, K * sizeof(float), M, K * sizeof(float), (int2)(k, m)));
+        float8 bData = as_float8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(float), K, N * sizeof(float), (int2)(n, k)));
+        sum = mat_mul_sg16(aData, bData, sum);
+    }
+
+    intel_subgroup_block_write_u32_m4k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum));
+}
+
+__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1)))
+kernel void tf32_dpas_blockread_rowmajor_m8_n16(global float* C, global float* A, global float* B, int K)
+{
+    const int tM = 8;
+    const int tN = 16;
+    const int M = get_global_size(1) * tM;
+    const int N = get_global_size(0);
+    const int m = get_group_id(1) * tM;
+    const int n = get_group_id(0) * tN;
+
+    float8 sum = 0;
+    for (int k = 0; k < K; k += tK) {
+        float4 aData = as_float4(intel_subgroup_block_read_u32_m8k8(A, K * sizeof(float), M, K * sizeof(float), (int2)(k, m)));
+        float8 bData = as_float8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(float), K, N * sizeof(float), (int2)(n, k)));
+        sum = mat_mul_sg16(aData, bData, sum);
+    }
+
+    intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum));
+}
+
+#endif // cl_intel_subgroup_extended_block_read
+
+// Tiled matrix multiplication kernels, generated from a template:
+
+#define 
MM 1 +#define NN 1 +#include "matrix_kernel_tiled_tf32.cl" +#undef MM +#undef NN + +#define MM 2 +#define NN 1 +#include "matrix_kernel_tiled_tf32.cl" +#undef MM +#undef NN + +#define MM 1 +#define NN 2 +#include "matrix_kernel_tiled_tf32.cl" +#undef MM +#undef NN + +#define MM 2 +#define NN 2 +#include "matrix_kernel_tiled_tf32.cl" +#undef MM +#undef NN + +#endif // defined(cl_intel_subgroups) && defined(cl_intel_required_subgroup_size) + +#undef tK diff --git a/samples/CMakeLists.txt b/samples/CMakeLists.txt index 6e9995f1..30f877b4 100644 --- a/samples/CMakeLists.txt +++ b/samples/CMakeLists.txt @@ -89,4 +89,5 @@ if(BUILD_EXTENSION_SAMPLES) add_subdirectory( 14_ooqcommandbuffers ) endif() -add_subdirectory( 99_matrixexperiments ) \ No newline at end of file +add_subdirectory( 99_matrixexperiments ) +add_subdirectory( 99_matrixexperimentstf32 ) From 4712db353ced609eb082a85a00e2f333bb2da384 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Wed, 21 Feb 2024 22:24:08 -0800 Subject: [PATCH 44/99] add more tf32 variants and enable prefetching --- samples/99_matrixexperimentstf32/main.cpp | 3 + .../matrix_helpers_tf32.cl | 4 +- .../matrix_kernel_tiled_tf32.cl | 74 +++++-------------- .../matrix_kernels_tf32.cl | 18 +++++ 4 files changed, 40 insertions(+), 59 deletions(-) diff --git a/samples/99_matrixexperimentstf32/main.cpp b/samples/99_matrixexperimentstf32/main.cpp index 3ec2c7cb..2cb984c8 100644 --- a/samples/99_matrixexperimentstf32/main.cpp +++ b/samples/99_matrixexperimentstf32/main.cpp @@ -585,6 +585,9 @@ int main(int argc, char** argv) tf32_dpas_blockread_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); tf32_dpas_blockread_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); tf32_dpas_blockread_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_blockread_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + 
tf32_dpas_blockread_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_blockread_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); } printf("Done.\n"); diff --git a/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl b/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl index deab63df..d7592527 100644 --- a/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl +++ b/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl @@ -136,7 +136,7 @@ float4 load_a_rowmajor_d32_m8_k8_sg16(global float* A, int rowStart, int colStar } // M rows x K columns x V tiles (in the M and K dimensions) -void prefetch_a_rowmajor_d16_m8v2_k8v2_sg16(global float* A, int rowStart, int colStart, int stride) +void prefetch_a_rowmajor_d32_m8v2_k8v2_sg16(global float* A, int rowStart, int colStart, int stride) { #if defined(PREFETCH_DEFAULT) uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; @@ -166,7 +166,7 @@ float8 load_b_rowmajor_d32_k8_nx(global float* B, int rowStart, int colStart, in } // K rows x N columns x V tiles (in the K and N dimensions) -void prefetch_b_rowmajor_d32_k8v2_n16v2_sg16(global float* B, int rowStart, int colStart, int stride) +void prefetch_b_rowmajor_d32_k8v2_n8v2_sg16(global float* B, int rowStart, int colStart, int stride) { #if defined(PREFETCH_DEFAULT) uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; diff --git a/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl b/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl index aab546af..ff9583cb 100644 --- a/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl +++ b/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl @@ -42,20 +42,18 @@ kernel void MM_KERNEL_NAME(tf32_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float const int m = compute_m(SGS_PER_WG, tM, MM); const int n = get_group_id(0) * tN * NN; -#if 0 // Initial prefetch: const int init_k = 0; for (int 
kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm+=2) { - prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, init_k + kk * tK, K); + prefetch_a_rowmajor_d32_m8v2_k8v2_sg16(A, m + mm * tM, init_k + kk * tK, K); } } - for (int kk = 0; kk < KK; kk++) { + for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, init_k + kk * tK, n + nn * tN, N); + prefetch_b_rowmajor_d32_k8v2_n8v2_sg16(B, init_k + kk * tK, n + nn * tN, N); } } -#endif float8 sum[MM][NN]; for (int mm = 0; mm < MM; mm++) { @@ -67,41 +65,25 @@ kernel void MM_KERNEL_NAME(tf32_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float split_barrier_arrive(); for (int k = 0; k < K; k += tK * KK) { -#if 0 // Next prefetch: const int next_k = k + tK * KK; for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm+=2) { - prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, next_k + kk * tK, K); + prefetch_a_rowmajor_d32_m8v2_k8v2_sg16(A, m + mm * tM, next_k + kk * tK, K); } } for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, next_k + kk * tK, n + nn * tN, N); + prefetch_b_rowmajor_d32_k8v2_n8v2_sg16(B, next_k + kk * tK, n + nn * tN, N); } } -#endif float4 aData[KK][MM]; -#if 0 - if (KK % 2 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm++) { - short16 aTemp = load_a_rowmajor_d32_m8_k8v2_sg16(A, m + mm * tM, k + kk * tK, K); - aData[kk + 0][mm] = aTemp.lo; - aData[kk + 1][mm] = aTemp.hi; - } - } - } else { -#endif - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = load_a_rowmajor_d32_m8_k8_sg16(A, m + mm * tM, k + kk * tK, K); - } + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = load_a_rowmajor_d32_m8_k8_sg16(A, m + mm * tM, k + kk * tK, K); } -#if 0 } -#endif float8 bData[KK][NN]; for (int kk = 0; kk < KK; kk++) { @@ -143,20 +125,18 @@ kernel void 
MM_KERNEL_NAME(tf32_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(gl const int m = compute_m(SGS_PER_WG, tM, MM); const int n = get_group_id(0) * tN * NN; -#if 0 // Initial prefetch: const int init_k = 0; for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm+=2) { - prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, init_k + kk * tK, K); + prefetch_a_rowmajor_d32_m8v2_k8v2_sg16(A, m + mm * tM, init_k + kk * tK, K); } } - for (int kk = 0; kk < KK; kk++) { + for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, init_k + kk * tK, n + nn * tN, N); + prefetch_b_rowmajor_d32_k8v2_n8v2_sg16(B, init_k + kk * tK, n + nn * tN, N); } } -#endif float8 sum[MM][NN]; for (int mm = 0; mm < MM; mm++) { @@ -168,45 +148,25 @@ kernel void MM_KERNEL_NAME(tf32_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(gl split_barrier_arrive(); for (int k = 0; k < K; k += tK * KK) { -#if 0 // Next prefetch: const int next_k = k + tK * KK; for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm+=2) { - prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, next_k + kk * tK, K); + prefetch_a_rowmajor_d32_m8v2_k8v2_sg16(A, m + mm * tM, next_k + kk * tK, K); } } - for (int kk = 0; kk < KK; kk++) { + for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, next_k + kk * tK, n + nn * tN, N); + prefetch_b_rowmajor_d32_k8v2_n8v2_sg16(B, next_k + kk * tK, n + nn * tN, N); } } -#endif float4 aData[KK][MM]; -#if 0 - if (KK % 2 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm++) { - short16 aTemp = as_short16(intel_subgroup_block_read_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); - aData[kk + 0][mm] = aTemp.lo; - aData[kk + 1][mm] = aTemp.hi; - } - } - } else { -#endif - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - //aData[kk][mm] = 
as_float8(intel_subgroup_block_read_u32_m8k16(A, K * sizeof(float), M, K * sizeof(float), (int2)(k + kk * tK, m + mm * tM))); // works - //aData[kk][mm] = as_float8(intel_subgroup_block_read_u32_m8k8(A, K * sizeof(float), M, K * sizeof(float), (int2)(k + kk * tK, m + mm * tM))); // doesn't work - aData[kk][mm] = as_float4(intel_subgroup_block_read_u32_m8k8(A, K * sizeof(float), M, K * sizeof(float), (int2)(k + kk * tK, m + mm * tM))); - //float8 good = load_a_rowmajor_d32_m8_k8v2_sg16(A, m + mm * tM, k + kk * tK, K); - //printf("sglid = %u: test = %v8f, good= %v8f\n", get_sub_group_local_id(), aData[kk][mm], good); - } + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = as_float4(intel_subgroup_block_read_u32_m8k8(A, K * sizeof(float), M, K * sizeof(float), (int2)(k + kk * tK, m + mm * tM))); } -#if 0 } -#endif float8 bData[KK][NN]; for (int kk = 0; kk < KK; kk++) { diff --git a/samples/99_matrixexperimentstf32/matrix_kernels_tf32.cl b/samples/99_matrixexperimentstf32/matrix_kernels_tf32.cl index 6ebdb41f..67f0b242 100644 --- a/samples/99_matrixexperimentstf32/matrix_kernels_tf32.cl +++ b/samples/99_matrixexperimentstf32/matrix_kernels_tf32.cl @@ -213,6 +213,24 @@ kernel void tf32_dpas_blockread_rowmajor_m8_n16(global float* C, global float* A #undef MM #undef NN +#define MM 4 +#define NN 2 +#include "matrix_kernel_tiled_tf32.cl" +#undef MM +#undef NN + +#define MM 2 +#define NN 4 +#include "matrix_kernel_tiled_tf32.cl" +#undef MM +#undef NN + +#define MM 4 +#define NN 4 +#include "matrix_kernel_tiled_tf32.cl" +#undef MM +#undef NN + #endif // defined(cl_intel_subgroups) && defined(cl_intel_required_subgroup_size) #undef tK From 29d8bf7b02e0a19772085dbbf7c00bd38850a6c6 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Wed, 21 Feb 2024 22:30:24 -0800 Subject: [PATCH 45/99] add a few more non-blockread tiled tests --- samples/99_matrixexperimentstf32/main.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git 
a/samples/99_matrixexperimentstf32/main.cpp b/samples/99_matrixexperimentstf32/main.cpp index 2cb984c8..7c664841 100644 --- a/samples/99_matrixexperimentstf32/main.cpp +++ b/samples/99_matrixexperimentstf32/main.cpp @@ -571,6 +571,9 @@ int main(int argc, char** argv) tf32_dpas_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); tf32_dpas_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); tf32_dpas_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + tf32_dpas_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); } if (mask & 0x200) { From c88dc58e52cd8acd52ac263a63c0e029ccebb2bc Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Thu, 22 Feb 2024 08:55:26 -0800 Subject: [PATCH 46/99] add basic activation function support --- .../matrix_kernel_tiled.cl | 6 +++++ .../99_matrixexperiments/matrix_kernels.cl | 25 +++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index e64d4b21..29d00317 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -120,6 +120,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl for (int nn = 0; nn < NN; nn++) { for (int mm = 0; mm < MM; mm++) { + sum[mm][nn] = activation(sum[mm][nn]); store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); } } @@ -210,6 +211,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = activation(sum[mm][nn]); store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, 
N); } } @@ -302,6 +304,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = activation(sum[mm][nn]); store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); } } @@ -392,6 +395,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = activation(sum[mm][nn]); store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); } } @@ -485,6 +489,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = activation(sum[mm][nn]); intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[mm][nn])); } } @@ -586,6 +591,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { + sum[mm][nn] = activation(sum[mm][nn]); intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[mm][nn])); } } diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl index 29e0c193..ddd05382 100644 --- a/samples/99_matrixexperiments/matrix_kernels.cl +++ b/samples/99_matrixexperiments/matrix_kernels.cl @@ -23,6 +23,7 @@ kernel void bfloat16_naive(global float* C, global ushort* A, global ushort* B, sum = fma(bf16_to_fp32(A[m * K + k]), bf16_to_fp32(B[k * N + n]), sum); } + sum = activation(sum); C[m * N + n] = sum; } @@ -51,6 +52,7 @@ kernel void bfloat16_dpas_rowmajor_m1_n8(global float* C, global ushort* A, glob sum = mat_mul_sg8(aData, bData, sum); } + sum = activation(sum); store_c_rowmajor_fp32_m1_nx(C, sum, m, n, N); } @@ -70,6 +72,7 @@ kernel void 
bfloat16_dpas_rowmajor_m2_n8(global float* C, global ushort* A, glob sum = mat_mul_sg8(aData, bData, sum); } + sum = activation(sum); store_c_rowmajor_fp32_m2_nx(C, sum, m, n, N); } @@ -89,6 +92,7 @@ kernel void bfloat16_dpas_rowmajor_m4_n8(global float* C, global ushort* A, glob sum = mat_mul_sg8(aData, bData, sum); } + sum = activation(sum); store_c_rowmajor_fp32_m4_nx(C, sum, m, n, N); } @@ -108,6 +112,7 @@ kernel void bfloat16_dpas_rowmajor_m8_n8(global float* C, global ushort* A, glob sum = mat_mul_sg8(aData, bData, sum); } + sum = activation(sum); store_c_rowmajor_fp32_m8_nx(C, sum, m, n, N); } @@ -129,6 +134,7 @@ kernel void bfloat16_dpas_vnni_m1_n8(global float* C, global ushort* A, global u sum = mat_mul_sg8(aData, bData, sum); } + sum = activation(sum); store_c_rowmajor_fp32_m1_nx(C, sum, m, n, N); } @@ -148,6 +154,7 @@ kernel void bfloat16_dpas_vnni_m2_n8(global float* C, global ushort* A, global u sum = mat_mul_sg8(aData, bData, sum); } + sum = activation(sum); store_c_rowmajor_fp32_m2_nx(C, sum, m, n, N); } @@ -167,6 +174,7 @@ kernel void bfloat16_dpas_vnni_m4_n8(global float* C, global ushort* A, global u sum = mat_mul_sg8(aData, bData, sum); } + sum = activation(sum); store_c_rowmajor_fp32_m4_nx(C, sum, m, n, N); } @@ -186,6 +194,7 @@ kernel void bfloat16_dpas_vnni_m8_n8(global float* C, global ushort* A, global u sum = mat_mul_sg8(aData, bData, sum); } + sum = activation(sum); store_c_rowmajor_fp32_m8_nx(C, sum, m, n, N); } @@ -209,6 +218,7 @@ kernel void bfloat16_dpas_rowmajor_m1_n16(global float* C, global ushort* A, glo sum = mat_mul_sg16(aData, bData, sum); } + sum = activation(sum); store_c_rowmajor_fp32_m1_nx(C, sum, m, n, N); } @@ -228,6 +238,7 @@ kernel void bfloat16_dpas_rowmajor_m2_n16(global float* C, global ushort* A, glo sum = mat_mul_sg16(aData, bData, sum); } + sum = activation(sum); store_c_rowmajor_fp32_m2_nx(C, sum, m, n, N); } @@ -247,6 +258,7 @@ kernel void bfloat16_dpas_rowmajor_m4_n16(global float* C, global ushort* A, glo sum 
= mat_mul_sg16(aData, bData, sum); } + sum = activation(sum); store_c_rowmajor_fp32_m4_nx(C, sum, m, n, N); } @@ -266,6 +278,7 @@ kernel void bfloat16_dpas_rowmajor_m8_n16(global float* C, global ushort* A, glo sum = mat_mul_sg16(aData, bData, sum); } + sum = activation(sum); store_c_rowmajor_fp32_m8_nx(C, sum, m, n, N); } @@ -287,6 +300,7 @@ kernel void bfloat16_dpas_vnni_m1_n16(global float* C, global ushort* A, global sum = mat_mul_sg16(aData, bData, sum); } + sum = activation(sum); store_c_rowmajor_fp32_m1_nx(C, sum, m, n, N); } @@ -306,6 +320,7 @@ kernel void bfloat16_dpas_vnni_m2_n16(global float* C, global ushort* A, global sum = mat_mul_sg16(aData, bData, sum); } + sum = activation(sum); store_c_rowmajor_fp32_m2_nx(C, sum, m, n, N); } @@ -325,6 +340,7 @@ kernel void bfloat16_dpas_vnni_m4_n16(global float* C, global ushort* A, global sum = mat_mul_sg16(aData, bData, sum); } + sum = activation(sum); store_c_rowmajor_fp32_m4_nx(C, sum, m, n, N); } @@ -344,6 +360,7 @@ kernel void bfloat16_dpas_vnni_m8_n16(global float* C, global ushort* A, global sum = mat_mul_sg16(aData, bData, sum); } + sum = activation(sum); store_c_rowmajor_fp32_m8_nx(C, sum, m, n, N); } @@ -366,6 +383,7 @@ kernel void bfloat16_dpas_blockread_rowmajor_m1_n16(global float* C, global usho sum = mat_mul_sg16(aData, bData, sum); } + sum = activation(sum); intel_subgroup_block_write_u32_m1k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum)); } @@ -386,6 +404,7 @@ kernel void bfloat16_dpas_blockread_rowmajor_m2_n16(global float* C, global usho sum = mat_mul_sg16(aData, bData, sum); } + sum = activation(sum); intel_subgroup_block_write_u32_m2k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum)); } @@ -406,6 +425,7 @@ kernel void bfloat16_dpas_blockread_rowmajor_m4_n16(global float* C, global usho sum = mat_mul_sg16(aData, bData, sum); } + sum = activation(sum); intel_subgroup_block_write_u32_m4k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, 
m), as_uint4(sum)); } @@ -426,6 +446,7 @@ kernel void bfloat16_dpas_blockread_rowmajor_m8_n16(global float* C, global usho sum = mat_mul_sg16(aData, bData, sum); } + sum = activation(sum); intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); } @@ -446,6 +467,7 @@ kernel void bfloat16_dpas_blockread_vnni_m1_n16(global float* C, global ushort* sum = mat_mul_sg16(aData, bData, sum); } + sum = activation(sum); intel_subgroup_block_write_u32_m1k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum)); } @@ -466,6 +488,7 @@ kernel void bfloat16_dpas_blockread_vnni_m2_n16(global float* C, global ushort* sum = mat_mul_sg16(aData, bData, sum); } + sum = activation(sum); intel_subgroup_block_write_u32_m2k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum)); } @@ -486,6 +509,7 @@ kernel void bfloat16_dpas_blockread_vnni_m4_n16(global float* C, global ushort* sum = mat_mul_sg16(aData, bData, sum); } + sum = activation(sum); intel_subgroup_block_write_u32_m4k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum)); } @@ -506,6 +530,7 @@ kernel void bfloat16_dpas_blockread_vnni_m8_n16(global float* C, global ushort* sum = mat_mul_sg16(aData, bData, sum); } + sum = activation(sum); intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); } From 083a94610844afcc55c7ac6e17e476238544ff48 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Thu, 22 Feb 2024 09:19:08 -0800 Subject: [PATCH 47/99] add support for prefetching multiple iterations ahead --- .../matrix_kernel_tiled.cl | 108 +++++++++++------- 1 file changed, 64 insertions(+), 44 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 29d00317..5f63ef0e 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ 
-33,6 +33,10 @@ #define SGS_PER_WG 4 #endif +#if !defined(PREFETCH_DISTANCE) +#define PREFETCH_DISTANCE 1 +#endif + #if HAS_SIMD8 __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, SGS_PER_WG, 1))) @@ -229,16 +233,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f const int n = get_group_id(0) * tN * NN; // Initial prefetch: - const int init_k = 0; - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm+=2) { - prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, init_k + kk * tK, K); + int prefetch_k = 0; + for (int p = 0; p < PREFETCH_DISTANCE; p++) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=2) { + prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K); + } } - } - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, init_k + kk * tK, n + nn * tN, N); + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn+=2) { + prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); + } } + prefetch_k += tK * KK; } float8 sum[MM][NN]; @@ -252,17 +259,18 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f for (int k = 0; k < K; k += tK * KK) { // Next prefetch: - const int next_k = k + tK * KK; + // TODO: skip prefetch on the last iterations. 
for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm+=2) { - prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, next_k + kk * tK, K); + prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K); } } for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, next_k + kk * tK, n + nn * tN, N); + prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); } } + prefetch_k += tK * KK; short8 aData[KK][MM]; if (KK % 2 == 0) { @@ -320,16 +328,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float const int n = get_group_id(0) * tN * NN; // Initial prefetch: - const int init_k = 0; - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm+=2) { - prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, init_k + kk * tK, K); + int prefetch_k = 0; + for (int p = 0; p < PREFETCH_DISTANCE; p++) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=2) { + prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K); + } } - } - for (int kk = 0; kk < KK; kk+=2) { - for (int nn = 0; nn < NN; nn++) { - prefetch_b_vnni_d16_k16v2_n16_sg16(B, init_k + kk * tK, n + nn * tN, N); + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn++) { + prefetch_b_vnni_d16_k16v2_n16_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); + } } + prefetch_k += tK * KK; } float8 sum[MM][NN]; @@ -343,17 +354,18 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float for (int k = 0; k < K; k += tK * KK) { // Next prefetch: - const int next_k = k + tK * KK; + // TODO: skip prefetch on the last iterations. 
for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm+=2) { - prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, next_k + kk * tK, K); + prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K); } } for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn++) { - prefetch_b_vnni_d16_k16v2_n16_sg16(B, next_k + kk * tK, n + nn * tN, N); + prefetch_b_vnni_d16_k16v2_n16_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); } } + prefetch_k += tK * KK; short8 aData[KK][MM]; if (KK % 2 == 0) { @@ -414,16 +426,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN const int n = get_group_id(0) * tN * NN; // Initial prefetch: - const int init_k = 0; - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm+=2) { - prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, init_k + kk * tK, K); + int prefetch_k = 0; + for (int p = 0; p < PREFETCH_DISTANCE; p++) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=2) { + prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K); + } } - } - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, init_k + kk * tK, n + nn * tN, N); + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn+=2) { + prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); + } } + prefetch_k += tK * KK; } float8 sum[MM][NN]; @@ -437,17 +452,18 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN for (int k = 0; k < K; k += tK * KK) { // Next prefetch: - const int next_k = k + tK * KK; + // TODO: skip prefetch on the last iterations. 
for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm+=2) { - prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, next_k + kk * tK, K); + prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K); } } for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, next_k + kk * tK, n + nn * tN, N); + prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); } } + prefetch_k += tK * KK; short8 aData[KK][MM]; if (KK % 2 == 0) { @@ -506,16 +522,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl const int n = get_group_id(0) * tN * NN; // Initial prefetch: - const int init_k = 0; - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm+=2) { - prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, init_k + kk * tK, K); + int prefetch_k = 0; + for (int p = 0; p < PREFETCH_DISTANCE; p++) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=2) { + prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K); + } } - } - for (int kk = 0; kk < KK; kk+=2) { - for (int nn = 0; nn < NN; nn++) { - prefetch_b_vnni_d16_k16v2_n16_sg16(B, init_k + kk * tK, n + nn * tN, N); + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn++) { + prefetch_b_vnni_d16_k16v2_n16_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); + } } + prefetch_k += tK * KK; } float8 sum[MM][NN]; @@ -529,17 +548,18 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl for (int k = 0; k < K; k += tK * KK) { // Next prefetch: - const int next_k = k + tK * KK; + // TODO: skip prefetch on the last iterations. 
for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm+=2) { - prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, next_k + kk * tK, K); + prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K); } } for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn++) { - prefetch_b_vnni_d16_k16v2_n16_sg16(B, next_k + kk * tK, n + nn * tN, N); + prefetch_b_vnni_d16_k16v2_n16_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); } } + prefetch_k += tK * KK; short8 aData[KK][MM]; if (KK % 2 == 0) { From d5c3d6d542f239692a84627403749f038f9faad8 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Tue, 27 Feb 2024 21:35:13 -0800 Subject: [PATCH 48/99] add support for even bigger block reads --- .../99_matrixexperiments/matrix_helpers.cl | 38 +++++++++- .../matrix_kernel_tiled.cl | 72 ++++++++++++++++++- 2 files changed, 107 insertions(+), 3 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_helpers.cl b/samples/99_matrixexperiments/matrix_helpers.cl index 0a19797a..34857c0a 100644 --- a/samples/99_matrixexperiments/matrix_helpers.cl +++ b/samples/99_matrixexperiments/matrix_helpers.cl @@ -580,13 +580,20 @@ void store_c_rowmajor_fp32_m8_nx(global float* C, float8 v, int rowStart, int co // - tile width: subgroup size (16) // - number of tiles: 1 -// Define additional "non-vector" block read and writes. These are supported by the hardware but are not in the headers: +typedef ushort __attribute__((ext_vector_type(32))) ushort32; +typedef ushort __attribute__((ext_vector_type(64))) ushort64; + +// Define block reads and writes. 
These are supported by the hardware but are not in the headers: ushort __builtin_IB_subgroup_block_read_flat_u16_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); ushort2 __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); ushort4 __builtin_IB_subgroup_block_read_flat_u16_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); ushort8 __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); ushort16 __builtin_IB_subgroup_block_read_flat_u16_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort32 __builtin_IB_subgroup_block_read_flat_u16_m32k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); + +ushort32 __builtin_IB_subgroup_block_read_flat_u16_m16k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort64 __builtin_IB_subgroup_block_read_flat_u16_m32k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); uint16 __builtin_IB_subgroup_block_read_flat_u32_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); @@ -616,6 +623,35 @@ ushort16 intel_subgroup_block_read_u16_m16k16(const __global void *base_address, { return __builtin_IB_subgroup_block_read_flat_u16_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } +void intel_subgroup_block_read_u16_m32k16(const __global void *base_address, int width, int height, int pitch, int2 coord, ushort8 dst[4]) +{ + ushort32 tmp = 
__builtin_IB_subgroup_block_read_flat_u16_m32k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); + dst[0] = tmp.lo.lo; + dst[1] = tmp.lo.hi; + dst[2] = tmp.hi.lo; + dst[3] = tmp.hi.hi; +} + +void intel_subgroup_block_read_u16_m16k16v2(const __global void *base_address, int width, int height, int pitch, int2 coord, ushort8 dst[2][2]) +{ + ushort32 tmp = __builtin_IB_subgroup_block_read_flat_u16_m16k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord); + dst[0][0] = tmp.lo.lo; + dst[0][1] = tmp.lo.hi; + dst[1][0] = tmp.hi.lo; + dst[1][1] = tmp.hi.hi; +} +void intel_subgroup_block_read_u16_m32k16v2(const __global void *base_address, int width, int height, int pitch, int2 coord, ushort8 dst[2][4]) +{ + ushort64 tmp = __builtin_IB_subgroup_block_read_flat_u16_m32k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord); + dst[0][0] = tmp.lo.lo.lo; + dst[0][1] = tmp.lo.lo.hi; + dst[0][2] = tmp.lo.hi.lo; + dst[0][3] = tmp.lo.hi.hi; + dst[1][0] = tmp.hi.lo.lo; + dst[1][1] = tmp.hi.lo.hi; + dst[1][2] = tmp.hi.hi.lo; + dst[1][3] = tmp.hi.hi.hi; +} uint8 intel_subgroup_block_read_u32_m8k16(const __global void* base_address, int width, int height, int pitch, int2 coord) { diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 5f63ef0e..963a2cc2 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -466,7 +466,31 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN prefetch_k += tK * KK; short8 aData[KK][MM]; - if (KK % 2 == 0) { + if (KK % 2 == 0 & MM % 4 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=4) { + ushort8 tmp[2][4]; + intel_subgroup_block_read_u16_m32k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); + for (int tkk = 0; tkk < 2; tkk++) { + for (int tmm = 0; tmm < 4; tmm++) { + 
aData[kk + tkk][mm + tmm] = as_short8(tmp[tkk][tmm]); + } + } + } + } + } else if (KK % 2 == 0 & MM % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=2) { + ushort8 tmp[2][2]; + intel_subgroup_block_read_u16_m16k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); + for (int tkk = 0; tkk < 2; tkk++) { + for (int tmm = 0; tmm < 2; tmm++) { + aData[kk + tkk][mm + tmm] = as_short8(tmp[tkk][tmm]); + } + } + } + } + } else if (KK % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm++) { short16 aTemp = as_short16(intel_subgroup_block_read_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); @@ -474,6 +498,16 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN aData[kk + 1][mm] = aTemp.hi; } } + } else if (MM % 4 == 0) { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm+=4) { + ushort8 tmp[4]; + intel_subgroup_block_read_u16_m32k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); + for (int tmm = 0; tmm < 4; tmm++) { + aData[kk][mm + tmm] = as_short8(tmp[tmm]); + } + } + } } else { for (int kk = 0; kk < KK; kk++) { for (int mm = 0; mm < MM; mm++) { @@ -562,7 +596,31 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl prefetch_k += tK * KK; short8 aData[KK][MM]; - if (KK % 2 == 0) { + if (KK % 2 == 0 & MM % 4 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=4) { + ushort8 tmp[2][4]; + intel_subgroup_block_read_u16_m32k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); + for (int tkk = 0; tkk < 2; tkk++) { + for (int tmm = 0; tmm < 4; tmm++) { + aData[kk + tkk][mm + tmm] = as_short8(tmp[tkk][tmm]); + } + } + } + } + } else if (KK % 2 == 0 & MM % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=2) { + ushort8 tmp[2][2]; + 
intel_subgroup_block_read_u16_m16k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); + for (int tkk = 0; tkk < 2; tkk++) { + for (int tmm = 0; tmm < 2; tmm++) { + aData[kk + tkk][mm + tmm] = as_short8(tmp[tkk][tmm]); + } + } + } + } + } else if (KK % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm++) { short16 aTemp = as_short16(intel_subgroup_block_read_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); @@ -570,6 +628,16 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl aData[kk + 1][mm] = aTemp.hi; } } + } else if (MM % 4 == 0) { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm+=4) { + ushort8 tmp[4]; + intel_subgroup_block_read_u16_m32k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); + for (int tmm = 0; tmm < 4; tmm++) { + aData[kk][mm + tmm] = as_short8(tmp[tmm]); + } + } + } } else { for (int kk = 0; kk < KK; kk++) { for (int mm = 0; mm < MM; mm++) { From 18096ee9e9539e3cb80032d562f729529c9e2f83 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Fri, 1 Mar 2024 12:57:49 -0800 Subject: [PATCH 49/99] add a way to generate tf32 dpas currently (disabled by default) --- samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl b/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl index d7592527..6660340c 100644 --- a/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl +++ b/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl @@ -13,6 +13,7 @@ float emu_sub_group_tf32_tf32_matrix_mad_k8(float a, float8 b, float acc) { float res = acc; +#if 1 res = fma(sub_group_broadcast(a, 0), b.s0, res); res = fma(sub_group_broadcast(a, 1), b.s1, res); res = fma(sub_group_broadcast(a, 2), b.s2, res); @@ -21,6 +22,12 @@ float 
emu_sub_group_tf32_tf32_matrix_mad_k8(float a, float8 b, float acc) res = fma(sub_group_broadcast(a, 5), b.s5, res); res = fma(sub_group_broadcast(a, 6), b.s6, res); res = fma(sub_group_broadcast(a, 7), b.s7, res); +#else +float __attribute__((overloadable)) intel_sub_group_tf32_tf32_matrix_mad_k8_f32(short a, int8 b, float acc); + uint a_ui = as_uint(sub_group_shuffle(a, get_sub_group_local_id() / 2)); + short aData = get_sub_group_local_id() % 2 ? as_short2(a_ui).hi : as_short2(a_ui).lo; + res = intel_sub_group_tf32_tf32_matrix_mad_k8_f32(aData, as_int8(b), res); +#endif return res; } From 459c109dc032a4f5fe33038851755583a8612fa9 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Fri, 1 Mar 2024 17:14:54 -0800 Subject: [PATCH 50/99] increase prefetch distance add helper functions for tiled kernels --- .../99_matrixexperiments/matrix_helpers.cl | 21 +- .../matrix_kernel_tiled.cl | 490 ++++++++---------- 2 files changed, 244 insertions(+), 267 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_helpers.cl b/samples/99_matrixexperiments/matrix_helpers.cl index 34857c0a..5b323e14 100644 --- a/samples/99_matrixexperiments/matrix_helpers.cl +++ b/samples/99_matrixexperiments/matrix_helpers.cl @@ -289,7 +289,8 @@ void prefetch_a_rowmajor_d16_m8_k16v2_sg8(global ushort* A, int rowStart, int co { #if defined(PREFETCH_DEFAULT) uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; - prefetch(A + offset, 1); + __builtin_assume((ulong)(A + offset) % 4 == 0); + prefetch(A + offset, 8); #endif // defined(PREFETCH_DEFAULT) } @@ -377,7 +378,8 @@ void prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(global ushort* A, int rowStart, int { #if defined(PREFETCH_DEFAULT) uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; - prefetch(A + offset, 1); + __builtin_assume((ulong)(A + offset) % 4 == 0); + prefetch(A + offset, 8); #endif // defined(PREFETCH_DEFAULT) } @@ -446,8 +448,10 @@ void prefetch_b_rowmajor_d16_k16_n8v4_sg8(global ushort* B, int 
rowStart, int co { #if defined(PREFETCH_DEFAULT) uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; - prefetch(B + offset, 1); offset += 8 * stride; - prefetch(B + offset, 1); offset += 8 * stride; + __builtin_assume((ulong)(B + offset) % 4 == 0); + prefetch(B + offset, 8); offset += 8 * stride; + __builtin_assume((ulong)(B + offset) % 4 == 0); + prefetch(B + offset, 8); offset += 8 * stride; #endif // defined(PREFETCH_DEFAULT) } @@ -456,7 +460,8 @@ void prefetch_b_rowmajor_d16_k16_n16v2_sg16(global ushort* B, int rowStart, int { #if defined(PREFETCH_DEFAULT) uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; - prefetch(B + offset, 1); + __builtin_assume((ulong)(B + offset) % 4 == 0); + prefetch(B + offset, 8); #endif // defined(PREFETCH_DEFAULT) } @@ -466,7 +471,8 @@ void prefetch_b_vnni_d16_k16_n8v2_sg8(global ushort* B, int rowStart, int colSta #if defined(PREFETCH_DEFAULT) global uint* B_ui = (global uint*)B; uint offset_ui = colStart + (rowStart / 2 + get_sub_group_local_id()) * stride; - prefetch(B_ui + offset_ui, 1); + __builtin_assume((ulong)(B_ui + offset_ui) % 4 == 0); + prefetch(B_ui + offset_ui, 4); #endif // defined(PREFETCH_DEFAULT) } @@ -476,7 +482,8 @@ void prefetch_b_vnni_d16_k16v2_n16_sg16(global ushort* B, int rowStart, int colS #if defined(PREFETCH_DEFAULT) global uint* B_ui = (global uint*)B; uint offset_ui = colStart + (rowStart / 2 + get_sub_group_local_id()) * stride; - prefetch(B_ui + offset_ui, 1); + __builtin_assume((ulong)(B_ui + offset_ui) % 4 == 0); + prefetch(B_ui + offset_ui, 4); #endif // defined(PREFETCH_DEFAULT) } diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 963a2cc2..0655d109 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -28,6 +28,9 @@ #define MM_KERNEL_NAMEX(PREFIX, tM, tN, MM, NN) PREFIX ## _m ## tM ## _n ## tN ## _ ## MM ## x 
## NN #define MM_KERNEL_NAME(PREFIX, tM, tN, MM, NN) MM_KERNEL_NAMEX(PREFIX, tM, tN, MM, NN) +#define HELPER_NAMEX(PREFIX, MM, NN) PREFIX ## _m ## MM ## _n ## NN +#define HELPER_NAME(PREFIX, MM, NN) HELPER_NAMEX(PREFIX, MM, NN) + #if !defined(SGS_PER_WG) // Launch four subgroups per work-group, to maximize cache reuse. #define SGS_PER_WG 4 @@ -39,6 +42,33 @@ #if HAS_SIMD8 +void HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(global ushort* A, int tM, int K, int m, int prefetch_k) +{ + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, init_k + kk * tK, K); + } + } +} + +void HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(global ushort* B, int tN, int N, int prefetch_k, int n) +{ + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn+=4) { + prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, init_k + kk * tK, n + nn * tN, N); + } + } +} + +void HELPER_NAME(btile_prefetch_vnni_sg8, MM, NN)(global ushort* B, int tN, int N, int prefetch_k, int n) +{ + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn+=2) { + prefetch_b_vnni_d16_k16_n8v2_sg8(B, init_k + kk * tK, n + nn * tN, N); + } + } +} + __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, SGS_PER_WG, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) { @@ -49,16 +79,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl const int n = get_group_id(0) * tN * NN; // Initial prefetch: - const int init_k = 0; - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm++) { - prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, init_k + kk * tK, K); + int prefetch_k = 0; + for (int p = 0; p < PREFETCH_DISTANCE; p++) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, init_k + 
kk * tK, K); + } } - } - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn+=4) { - prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, init_k + kk * tK, n + nn * tN, N); + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn+=4) { + prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, init_k + kk * tK, n + nn * tN, N); + } } + prefetch_k += tK * KK; } float8 sum[MM][NN]; @@ -72,17 +105,18 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl for (int k = 0; k < K; k += tK * KK) { // Next prefetch: - const int next_k = k + tK * KK; + // TODO: skip prefetch on the last iterations. for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm++) { - prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, next_k + kk * tK, K); + prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K); } } for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn+=4) { - prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, next_k + kk * tK, n + nn * tN, N); + prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, prefetch_k + kk * tK, n + nn * tN, N); } } + prefetch_k += tK * KK; int8 aData[KK][MM]; if (KK % 2 == 0) { @@ -140,16 +174,19 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* const int n = get_group_id(0) * tN * NN; // Initial prefetch: - const int init_k = 0; - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm++) { - prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, init_k + kk * tK, K); + int prefetch_k = 0; + for (int p = 0; p < PREFETCH_DISTANCE; p++) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, init_k + kk * tK, K); + } } - } - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_vnni_d16_k16_n8v2_sg8(B, init_k + kk * tK, n + nn * tN, N); + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn+=2) { + prefetch_b_vnni_d16_k16_n8v2_sg8(B, init_k + kk * tK, n + nn * 
tN, N); + } } + prefetch_k += tK * KK; } float8 sum[MM][NN]; @@ -163,17 +200,18 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* for (int k = 0; k < K; k += tK * KK) { // Next prefetch: - const int next_k = k + tK * KK; + // TODO: skip prefetch on the last iterations. for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm++) { - prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, next_k + kk * tK, K); + prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K); } } for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_vnni_d16_k16_n8v2_sg8(B, next_k + kk * tK, n + nn * tN, N); + prefetch_b_vnni_d16_k16_n8v2_sg8(B, prefetch_k + kk * tK, n + nn * tN, N); } } + prefetch_k += tK * KK; int8 aData[KK][MM]; if (KK % 2 == 0) { @@ -223,6 +261,70 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* #endif // HAS_SIMD8 +void HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(global ushort* A, int tM, int K, int m, int prefetch_k) +{ + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=2) { + prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K); + } + } +} + +void HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(global ushort* B, int tN, int N, int prefetch_k, int n) +{ + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn+=2) { + prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); + } + } +} + +void HELPER_NAME(btile_prefetch_vnni, MM, NN)(global ushort* B, int tN, int N, int prefetch_k, int n) +{ + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn++) { + prefetch_b_vnni_d16_k16v2_n16_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); + } + } +} + +void HELPER_NAME(atile_load_rowmajor, MM, NN)(global ushort* A, int tM, int K, int m, int k, short8 aData[KK][MM]) +{ + if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + short16 
aTemp = load_a_rowmajor_d16_m8_k16v2_sg16(A, m + mm * tM, k + kk * tK, K); + aData[kk + 0][mm] = aTemp.lo; + aData[kk + 1][mm] = aTemp.hi; + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k + kk * tK, K); + } + } + } +} + +void HELPER_NAME(btile_load_rowmajor, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[KK][NN]) +{ + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[kk][nn] = load_b_rowmajor_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); + } + } +} + +void HELPER_NAME(btile_load_vnni, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[KK][NN]) +{ + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[kk][nn] = load_b_vnni_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); + } + } +} + __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) { @@ -235,16 +337,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f // Initial prefetch: int prefetch_k = 0; for (int p = 0; p < PREFETCH_DISTANCE; p++) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm+=2) { - prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K); - } - } - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); - } - } + HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n); prefetch_k += tK * KK; } @@ -260,41 +354,15 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f for (int k = 0; k < K; k += tK * KK) { // Next prefetch: // 
TODO: skip prefetch on the last iterations. - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm+=2) { - prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K); - } - } - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); - } - } + HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n); prefetch_k += tK * KK; short8 aData[KK][MM]; - if (KK % 2 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm++) { - short16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg16(A, m + mm * tM, k + kk * tK, K); - aData[kk + 0][mm] = aTemp.lo; - aData[kk + 1][mm] = aTemp.hi; - } - } - } else { - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k + kk * tK, K); - } - } - } + HELPER_NAME(atile_load_rowmajor, MM, NN)(A, tM, K, m, k, aData); int8 bData[KK][NN]; - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - bData[kk][nn] = load_b_rowmajor_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); - } - } + HELPER_NAME(btile_load_rowmajor, MM, NN)(B, tN, N, k, n, bData); for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { @@ -330,16 +398,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float // Initial prefetch: int prefetch_k = 0; for (int p = 0; p < PREFETCH_DISTANCE; p++) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm+=2) { - prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K); - } - } - for (int kk = 0; kk < KK; kk+=2) { - for (int nn = 0; nn < NN; nn++) { - prefetch_b_vnni_d16_k16v2_n16_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); - } - } + HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_vnni, 
MM, NN)(B, tN, N, prefetch_k, n); prefetch_k += tK * KK; } @@ -355,41 +415,15 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float for (int k = 0; k < K; k += tK * KK) { // Next prefetch: // TODO: skip prefetch on the last iterations. - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm+=2) { - prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K); - } - } - for (int kk = 0; kk < KK; kk+=2) { - for (int nn = 0; nn < NN; nn++) { - prefetch_b_vnni_d16_k16v2_n16_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); - } - } + HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n); prefetch_k += tK * KK; short8 aData[KK][MM]; - if (KK % 2 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm++) { - short16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg16(A, m + mm * tM, k + kk * tK, K); - aData[kk + 0][mm] = aTemp.lo; - aData[kk + 1][mm] = aTemp.hi; - } - } - } else { - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k + kk * tK, K); - } - } - } + HELPER_NAME(atile_load_rowmajor, MM, NN)(A, tM, K, m, k, aData); int8 bData[KK][NN]; - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - bData[kk][nn] = load_b_vnni_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); - } - } + HELPER_NAME(btile_load_vnni, MM, NN)(B, tN, N, k, n, bData); for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { @@ -415,8 +449,90 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float #ifdef cl_intel_subgroup_extended_block_read +void HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(global ushort* A, int tM, int M, int K, int m, int k, short8 aData[KK][MM]) +{ + if (KK % 2 == 0 & MM % 4 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=4) { + ushort8 tmp[2][4]; + 
intel_subgroup_block_read_u16_m32k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); + for (int tkk = 0; tkk < 2; tkk++) { + for (int tmm = 0; tmm < 4; tmm++) { + aData[kk + tkk][mm + tmm] = as_short8(tmp[tkk][tmm]); + } + } + } + } + } else if (KK % 2 == 0 & MM % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=2) { + ushort8 tmp[2][2]; + intel_subgroup_block_read_u16_m16k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); + for (int tkk = 0; tkk < 2; tkk++) { + for (int tmm = 0; tmm < 2; tmm++) { + aData[kk + tkk][mm + tmm] = as_short8(tmp[tkk][tmm]); + } + } + } + } + } else if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + short16 aTemp = as_short16(intel_subgroup_block_read_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); + aData[kk + 0][mm] = aTemp.lo; + aData[kk + 1][mm] = aTemp.hi; + } + } + } else if (MM % 4 == 0) { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm+=4) { + ushort8 tmp[4]; + intel_subgroup_block_read_u16_m32k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); + for (int tmm = 0; tmm < 4; tmm++) { + aData[kk][mm + tmm] = as_short8(tmp[tmm]); + } + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); + } + } + } +} + +void HELPER_NAME(btile_load_blockread_rowmajor, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n, int8 bData[KK][NN]) +{ + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[kk][nn] = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK))); + } + } +} + +void 
HELPER_NAME(btile_load_blockread_vnni, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n, int8 bData[KK][NN]) +{ + if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn++) { + int16 bTemp = as_int16(intel_subgroup_block_read_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2))); + bData[kk + 0][nn] = bTemp.lo; + bData[kk + 1][nn] = bTemp.hi; + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[kk][nn] = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2))); + } + } + } +} + __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) + { const int tM = 8; const int tN = 16; @@ -428,16 +544,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN // Initial prefetch: int prefetch_k = 0; for (int p = 0; p < PREFETCH_DISTANCE; p++) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm+=2) { - prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K); - } - } - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); - } - } + HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n); prefetch_k += tK * KK; } @@ -453,75 +561,15 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN for (int k = 0; k < K; k += tK * KK) { // Next prefetch: // TODO: skip prefetch on the last iterations. 
- for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm+=2) { - prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K); - } - } - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); - } - } + HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n); prefetch_k += tK * KK; short8 aData[KK][MM]; - if (KK % 2 == 0 & MM % 4 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm+=4) { - ushort8 tmp[2][4]; - intel_subgroup_block_read_u16_m32k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); - for (int tkk = 0; tkk < 2; tkk++) { - for (int tmm = 0; tmm < 4; tmm++) { - aData[kk + tkk][mm + tmm] = as_short8(tmp[tkk][tmm]); - } - } - } - } - } else if (KK % 2 == 0 & MM % 2 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm+=2) { - ushort8 tmp[2][2]; - intel_subgroup_block_read_u16_m16k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); - for (int tkk = 0; tkk < 2; tkk++) { - for (int tmm = 0; tmm < 2; tmm++) { - aData[kk + tkk][mm + tmm] = as_short8(tmp[tkk][tmm]); - } - } - } - } - } else if (KK % 2 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm++) { - short16 aTemp = as_short16(intel_subgroup_block_read_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); - aData[kk + 0][mm] = aTemp.lo; - aData[kk + 1][mm] = aTemp.hi; - } - } - } else if (MM % 4 == 0) { - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm+=4) { - ushort8 tmp[4]; - intel_subgroup_block_read_u16_m32k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); - for (int tmm = 0; tmm < 4; tmm++) { - aData[kk][mm + tmm] = as_short8(tmp[tmm]); - } - } - } 
- } else { - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); - } - } - } + HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); int8 bData[KK][NN]; - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - bData[kk][nn] = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK))); - } - } + HELPER_NAME(btile_load_blockread_rowmajor, MM, NN)(B, tN, K, N, k, n, bData); for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { @@ -558,16 +606,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl // Initial prefetch: int prefetch_k = 0; for (int p = 0; p < PREFETCH_DISTANCE; p++) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm+=2) { - prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K); - } - } - for (int kk = 0; kk < KK; kk+=2) { - for (int nn = 0; nn < NN; nn++) { - prefetch_b_vnni_d16_k16v2_n16_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); - } - } + HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n); prefetch_k += tK * KK; } @@ -583,85 +623,15 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl for (int k = 0; k < K; k += tK * KK) { // Next prefetch: // TODO: skip prefetch on the last iterations. 
- for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm+=2) { - prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K); - } - } - for (int kk = 0; kk < KK; kk+=2) { - for (int nn = 0; nn < NN; nn++) { - prefetch_b_vnni_d16_k16v2_n16_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); - } - } + HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n); prefetch_k += tK * KK; short8 aData[KK][MM]; - if (KK % 2 == 0 & MM % 4 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm+=4) { - ushort8 tmp[2][4]; - intel_subgroup_block_read_u16_m32k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); - for (int tkk = 0; tkk < 2; tkk++) { - for (int tmm = 0; tmm < 4; tmm++) { - aData[kk + tkk][mm + tmm] = as_short8(tmp[tkk][tmm]); - } - } - } - } - } else if (KK % 2 == 0 & MM % 2 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm+=2) { - ushort8 tmp[2][2]; - intel_subgroup_block_read_u16_m16k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); - for (int tkk = 0; tkk < 2; tkk++) { - for (int tmm = 0; tmm < 2; tmm++) { - aData[kk + tkk][mm + tmm] = as_short8(tmp[tkk][tmm]); - } - } - } - } - } else if (KK % 2 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm++) { - short16 aTemp = as_short16(intel_subgroup_block_read_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); - aData[kk + 0][mm] = aTemp.lo; - aData[kk + 1][mm] = aTemp.hi; - } - } - } else if (MM % 4 == 0) { - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm+=4) { - ushort8 tmp[4]; - intel_subgroup_block_read_u16_m32k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); - for (int tmm = 0; tmm < 4; tmm++) { - aData[kk][mm + tmm] = as_short8(tmp[tmm]); - } - } - } - } else 
{ - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); - } - } - } + HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); int8 bData[KK][NN]; - if (KK % 2 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int nn = 0; nn < NN; nn++) { - int16 bTemp = as_int16(intel_subgroup_block_read_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2))); - bData[kk + 0][nn] = bTemp.lo; - bData[kk + 1][nn] = bTemp.hi; - } - } - } else { - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - bData[kk][nn] = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2))); - } - } - } + HELPER_NAME(btile_load_blockread_vnni, MM, NN)(B, tN, K, N, k, n, bData); for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { From b034c1d0c6987b23758e9e960cd8cc2c3877485e Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Fri, 1 Mar 2024 17:19:32 -0800 Subject: [PATCH 51/99] fix DG2 prefetches --- .../99_matrixexperiments/matrix_kernel_tiled.cl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 0655d109..75ab95b5 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -46,7 +46,7 @@ void HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(global ushort* A, int tM, { for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm++) { - prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, init_k + kk * tK, K); + prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K); } } } @@ -55,7 +55,7 @@ void HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(global 
ushort* B, int tN, { for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn+=4) { - prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, init_k + kk * tK, n + nn * tN, N); + prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, prefetch_k + kk * tK, n + nn * tN, N); } } } @@ -64,7 +64,7 @@ void HELPER_NAME(btile_prefetch_vnni_sg8, MM, NN)(global ushort* B, int tN, int { for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_vnni_d16_k16_n8v2_sg8(B, init_k + kk * tK, n + nn * tN, N); + prefetch_b_vnni_d16_k16_n8v2_sg8(B, prefetch_k + kk * tK, n + nn * tN, N); } } } @@ -83,12 +83,12 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl for (int p = 0; p < PREFETCH_DISTANCE; p++) { for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm++) { - prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, init_k + kk * tK, K); + prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K); } } for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn+=4) { - prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, init_k + kk * tK, n + nn * tN, N); + prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, prefetch_k + kk * tK, n + nn * tN, N); } } prefetch_k += tK * KK; @@ -178,12 +178,12 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* for (int p = 0; p < PREFETCH_DISTANCE; p++) { for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm++) { - prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, init_k + kk * tK, K); + prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K); } } for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_vnni_d16_k16_n8v2_sg8(B, init_k + kk * tK, n + nn * tN, N); + prefetch_b_vnni_d16_k16_n8v2_sg8(B, prefetch_k + kk * tK, n + nn * tN, N); } } prefetch_k += tK * KK; From 697754d45f8109124e68f9e7d051ecf129dc78c3 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Fri, 1 Mar 2024 17:31:09 -0800 Subject: [PATCH 52/99] use more 
helper functions for DG2 tiled kernels --- .../matrix_kernel_tiled.cl | 147 ++++++------------ 1 file changed, 49 insertions(+), 98 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 75ab95b5..4e7db7b2 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -40,6 +40,24 @@ #define PREFETCH_DISTANCE 1 #endif +void HELPER_NAME(btile_load_rowmajor, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[KK][NN]) +{ + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[kk][nn] = load_b_rowmajor_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); + } + } +} + +void HELPER_NAME(btile_load_vnni, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[KK][NN]) +{ + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[kk][nn] = load_b_vnni_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); + } + } +} + #if HAS_SIMD8 void HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(global ushort* A, int tM, int K, int m, int prefetch_k) @@ -69,6 +87,25 @@ void HELPER_NAME(btile_prefetch_vnni_sg8, MM, NN)(global ushort* B, int tN, int } } +void HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(global ushort* A, int tM, int K, int m, int k, int8 aData[KK][MM]) +{ + if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + int16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, k + kk * tK, K); + aData[kk + 0][mm] = aTemp.lo; + aData[kk + 1][mm] = aTemp.hi; + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K); + } + } + } +} + __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, SGS_PER_WG, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global float* C, 
global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) { @@ -81,16 +118,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl // Initial prefetch: int prefetch_k = 0; for (int p = 0; p < PREFETCH_DISTANCE; p++) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm++) { - prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K); - } - } - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn+=4) { - prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, prefetch_k + kk * tK, n + nn * tN, N); - } - } + HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(B, tN, N, prefetch_k, n); prefetch_k += tK * KK; } @@ -106,41 +135,15 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl for (int k = 0; k < K; k += tK * KK) { // Next prefetch: // TODO: skip prefetch on the last iterations. - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm++) { - prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K); - } - } - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn+=4) { - prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, prefetch_k + kk * tK, n + nn * tN, N); - } - } + HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(B, tN, N, prefetch_k, n); prefetch_k += tK * KK; int8 aData[KK][MM]; - if (KK % 2 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm++) { - int16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, k + kk * tK, K); - aData[kk + 0][mm] = aTemp.lo; - aData[kk + 1][mm] = aTemp.hi; - } - } - } else { - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K); - } - } - } + HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(A, tM, K, m, k, aData); int8 
bData[KK][NN]; - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - bData[kk][nn] = load_b_rowmajor_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); - } - } + HELPER_NAME(btile_load_rowmajor, MM, NN)(B, tN, N, k, n, bData); for (int kk = 0; kk < KK; kk++) { for (int mm = 0; mm < MM; mm++) { @@ -176,16 +179,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* // Initial prefetch: int prefetch_k = 0; for (int p = 0; p < PREFETCH_DISTANCE; p++) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm++) { - prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K); - } - } - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_vnni_d16_k16_n8v2_sg8(B, prefetch_k + kk * tK, n + nn * tN, N); - } - } + HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_vnni_sg8, MM, NN)(B, tN, N, prefetch_k, n); prefetch_k += tK * KK; } @@ -201,41 +196,15 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* for (int k = 0; k < K; k += tK * KK) { // Next prefetch: // TODO: skip prefetch on the last iterations. 
- for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm++) { - prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K); - } - } - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_vnni_d16_k16_n8v2_sg8(B, prefetch_k + kk * tK, n + nn * tN, N); - } - } + HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(B, tN, N, prefetch_k, n); prefetch_k += tK * KK; int8 aData[KK][MM]; - if (KK % 2 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm++) { - int16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, k + kk * tK, K); - aData[kk + 0][mm] = aTemp.lo; - aData[kk + 1][mm] = aTemp.hi; - } - } - } else { - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K); - } - } - } + HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(A, tM, K, m, k, aData); int8 bData[KK][NN]; - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - bData[kk][nn] = load_b_vnni_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); - } - } + HELPER_NAME(btile_load_vnni, MM, NN)(B, tN, N, k, n, bData); for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { @@ -307,24 +276,6 @@ void HELPER_NAME(atile_load_rowmajor, MM, NN)(global ushort* A, int tM, int K, i } } -void HELPER_NAME(btile_load_rowmajor, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[KK][NN]) -{ - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - bData[kk][nn] = load_b_rowmajor_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); - } - } -} - -void HELPER_NAME(btile_load_vnni, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[KK][NN]) -{ - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - bData[kk][nn] = load_b_vnni_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); - } - } -} - 
__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) { From cbbdcb6ae8c6394ed7058bd44a2cfd4a48e7e7b2 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Fri, 1 Mar 2024 18:12:37 -0800 Subject: [PATCH 53/99] switch back to a smaller prefetch The smaller prefetch is faster. Need to understand why. --- samples/99_matrixexperiments/matrix_helpers.cl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_helpers.cl b/samples/99_matrixexperiments/matrix_helpers.cl index 5b323e14..ee716819 100644 --- a/samples/99_matrixexperiments/matrix_helpers.cl +++ b/samples/99_matrixexperiments/matrix_helpers.cl @@ -290,7 +290,7 @@ void prefetch_a_rowmajor_d16_m8_k16v2_sg8(global ushort* A, int rowStart, int co #if defined(PREFETCH_DEFAULT) uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; __builtin_assume((ulong)(A + offset) % 4 == 0); - prefetch(A + offset, 8); + prefetch(A + offset, 1); #endif // defined(PREFETCH_DEFAULT) } @@ -379,7 +379,7 @@ void prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(global ushort* A, int rowStart, int #if defined(PREFETCH_DEFAULT) uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; __builtin_assume((ulong)(A + offset) % 4 == 0); - prefetch(A + offset, 8); + prefetch(A + offset, 1); #endif // defined(PREFETCH_DEFAULT) } @@ -449,9 +449,9 @@ void prefetch_b_rowmajor_d16_k16_n8v4_sg8(global ushort* B, int rowStart, int co #if defined(PREFETCH_DEFAULT) uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; __builtin_assume((ulong)(B + offset) % 4 == 0); - prefetch(B + offset, 8); offset += 8 * stride; + prefetch(B + offset, 1); offset += 8 * stride; __builtin_assume((ulong)(B + offset) % 4 == 0); - prefetch(B + offset, 8); offset += 8 * 
stride; + prefetch(B + offset, 1); offset += 8 * stride; #endif // defined(PREFETCH_DEFAULT) } @@ -461,7 +461,7 @@ void prefetch_b_rowmajor_d16_k16_n16v2_sg16(global ushort* B, int rowStart, int #if defined(PREFETCH_DEFAULT) uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; __builtin_assume((ulong)(B + offset) % 4 == 0); - prefetch(B + offset, 8); + prefetch(B + offset, 1); #endif // defined(PREFETCH_DEFAULT) } @@ -472,7 +472,7 @@ void prefetch_b_vnni_d16_k16_n8v2_sg8(global ushort* B, int rowStart, int colSta global uint* B_ui = (global uint*)B; uint offset_ui = colStart + (rowStart / 2 + get_sub_group_local_id()) * stride; __builtin_assume((ulong)(B_ui + offset_ui) % 4 == 0); - prefetch(B_ui + offset_ui, 4); + prefetch(B_ui + offset_ui, 1); #endif // defined(PREFETCH_DEFAULT) } @@ -483,7 +483,7 @@ void prefetch_b_vnni_d16_k16v2_n16_sg16(global ushort* B, int rowStart, int colS global uint* B_ui = (global uint*)B; uint offset_ui = colStart + (rowStart / 2 + get_sub_group_local_id()) * stride; __builtin_assume((ulong)(B_ui + offset_ui) % 4 == 0); - prefetch(B_ui + offset_ui, 4); + prefetch(B_ui + offset_ui, 1); #endif // defined(PREFETCH_DEFAULT) } From 2d23f76891e51a98dd64e62078c4613f44fa2693 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Sat, 2 Mar 2024 17:33:30 -0800 Subject: [PATCH 54/99] add support for 2D block prefetches --- .../99_matrixexperiments/matrix_helpers.cl | 76 +++++++++++++++++-- .../matrix_kernel_tiled.cl | 53 ++++++++++--- 2 files changed, 113 insertions(+), 16 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_helpers.cl b/samples/99_matrixexperiments/matrix_helpers.cl index ee716819..97b2ce5e 100644 --- a/samples/99_matrixexperiments/matrix_helpers.cl +++ b/samples/99_matrixexperiments/matrix_helpers.cl @@ -290,7 +290,7 @@ void prefetch_a_rowmajor_d16_m8_k16v2_sg8(global ushort* A, int rowStart, int co #if defined(PREFETCH_DEFAULT) uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; 
__builtin_assume((ulong)(A + offset) % 4 == 0); - prefetch(A + offset, 1); + prefetch(A + offset, 2); #endif // defined(PREFETCH_DEFAULT) } @@ -379,7 +379,7 @@ void prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(global ushort* A, int rowStart, int #if defined(PREFETCH_DEFAULT) uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; __builtin_assume((ulong)(A + offset) % 4 == 0); - prefetch(A + offset, 1); + prefetch(A + offset, 2); #endif // defined(PREFETCH_DEFAULT) } @@ -449,9 +449,9 @@ void prefetch_b_rowmajor_d16_k16_n8v4_sg8(global ushort* B, int rowStart, int co #if defined(PREFETCH_DEFAULT) uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; __builtin_assume((ulong)(B + offset) % 4 == 0); - prefetch(B + offset, 1); offset += 8 * stride; + prefetch(B + offset, 2); offset += 8 * stride; __builtin_assume((ulong)(B + offset) % 4 == 0); - prefetch(B + offset, 1); offset += 8 * stride; + prefetch(B + offset, 2); offset += 8 * stride; #endif // defined(PREFETCH_DEFAULT) } @@ -461,7 +461,7 @@ void prefetch_b_rowmajor_d16_k16_n16v2_sg16(global ushort* B, int rowStart, int #if defined(PREFETCH_DEFAULT) uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; __builtin_assume((ulong)(B + offset) % 4 == 0); - prefetch(B + offset, 1); + prefetch(B + offset, 2); #endif // defined(PREFETCH_DEFAULT) } @@ -587,10 +587,21 @@ void store_c_rowmajor_fp32_m8_nx(global float* C, float8 v, int rowStart, int co // - tile width: subgroup size (16) // - number of tiles: 1 +enum LSC_LDCC { + LSC_LDCC_DEFAULT = 0, + LSC_LDCC_L1UC_L3UC = 1, // Override to L1 uncached and L3 uncached + LSC_LDCC_L1UC_L3C = 2, // Override to L1 uncached and L3 cached + LSC_LDCC_L1C_L3UC = 3, // Override to L1 cached and L3 uncached + LSC_LDCC_L1C_L3C = 4, // Override to L1 cached and L3 cached + LSC_LDCC_L1S_L3UC = 5, // Override to L1 streaming load and L3 uncached + LSC_LDCC_L1S_L3C = 6, // Override to L1 streaming load and L3 cached + LSC_LDCC_L1IAR_L3C = 7, 
// Override to L1 invalidate-after-read, and L3 cached +}; + typedef ushort __attribute__((ext_vector_type(32))) ushort32; typedef ushort __attribute__((ext_vector_type(64))) ushort64; -// Define block reads and writes. These are supported by the hardware but are not in the headers: +// Define block reads, prefetches, and writes. These are supported by the hardware but are not in the headers: ushort __builtin_IB_subgroup_block_read_flat_u16_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); ushort2 __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); @@ -605,6 +616,19 @@ ushort64 __builtin_IB_subgroup_block_read_flat_u16_m32k16v2(long baseoffset, int uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); uint16 __builtin_IB_subgroup_block_read_flat_u32_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); + +void __builtin_IB_subgroup_block_read_prefetch_u16_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); +void __builtin_IB_subgroup_block_read_prefetch_u16_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); +void __builtin_IB_subgroup_block_read_prefetch_u16_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); +void __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); +void __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum 
LSC_LDCC cache_control); +void __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); +void __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); + +void __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); +void __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); + + void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint data); void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint2 data); void __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint4 data); @@ -669,6 +693,46 @@ uint16 intel_subgroup_block_read_u32_m16k16(const __global void* base_address, i return __builtin_IB_subgroup_block_read_flat_u32_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } +#define BLOCK_PREFETCH_CACHE_TYPE LSC_LDCC_L1C_L3C + +void intel_subgroup_block_prefetch_u16_m1k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + __builtin_IB_subgroup_block_read_prefetch_u16_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +} +void intel_subgroup_block_prefetch_u16_m2k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + 
__builtin_IB_subgroup_block_read_prefetch_u16_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +} +void intel_subgroup_block_prefetch_u16_m4k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + __builtin_IB_subgroup_block_read_prefetch_u16_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +} +void intel_subgroup_block_prefetch_u16_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +} +void intel_subgroup_block_prefetch_u16_m8k16v2(__global void *base_address, int width, int height, int pitch, int2 coord) +{ + __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +} +void intel_subgroup_block_prefetch_u16_m16k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +} +void intel_subgroup_block_prefetch_u16_m32k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +} +void intel_subgroup_block_prefetch_u16_m16k16v2(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +} +void intel_subgroup_block_prefetch_u16_m32k16v2(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ + 
__builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +} + + void intel_subgroup_block_write_u32_m1k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint data) { __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 4e7db7b2..f1bb1893 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -481,6 +481,41 @@ void HELPER_NAME(btile_load_blockread_vnni, MM, NN)(global ushort* B, int tN, in } } +void HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(global ushort* A, int tM, int M, int K, int m, int k) +{ + if (KK % 2 == 0 & MM % 4 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=4) { + intel_subgroup_block_prefetch_u16_m32k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } + } + } else if (KK % 2 == 0 & MM % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=2) { + intel_subgroup_block_prefetch_u16_m16k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } + } + } else if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + intel_subgroup_block_prefetch_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } + } + } else if (MM % 4 == 0) { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm+=4) { + intel_subgroup_block_prefetch_u16_m32k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + intel_subgroup_block_prefetch_u16_m8k16(A, K * 
sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } + } + } +} + __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) @@ -492,11 +527,10 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN const int m = compute_m(SGS_PER_WG, tM, MM); const int n = get_group_id(0) * tN * NN; - // Initial prefetch: int prefetch_k = 0; for (int p = 0; p < PREFETCH_DISTANCE; p++) { - HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); prefetch_k += tK * KK; } @@ -510,11 +544,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN split_barrier_arrive(); for (int k = 0; k < K; k += tK * KK) { - // Next prefetch: // TODO: skip prefetch on the last iterations. 
- HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n); - prefetch_k += tK * KK; short8 aData[KK][MM]; HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); @@ -522,6 +553,9 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN int8 bData[KK][NN]; HELPER_NAME(btile_load_blockread_rowmajor, MM, NN)(B, tN, K, N, k, n, bData); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); + prefetch_k += tK * KK; + for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { for (int mm = 0; mm < MM; mm++) { @@ -554,11 +588,10 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl const int m = compute_m(SGS_PER_WG, tM, MM); const int n = get_group_id(0) * tN * NN; - // Initial prefetch: int prefetch_k = 0; for (int p = 0; p < PREFETCH_DISTANCE; p++) { - HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); prefetch_k += tK * KK; } @@ -572,11 +605,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl split_barrier_arrive(); for (int k = 0; k < K; k += tK * KK) { - // Next prefetch: // TODO: skip prefetch on the last iterations. 
- HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n); - prefetch_k += tK * KK; short8 aData[KK][MM]; HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); @@ -584,6 +614,9 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl int8 bData[KK][NN]; HELPER_NAME(btile_load_blockread_vnni, MM, NN)(B, tN, K, N, k, n, bData); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); + prefetch_k += tK * KK; + for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { for (int mm = 0; mm < MM; mm++) { From 14bc83e8a18329aa6b92fbfe94f823653f81bd63 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Mon, 4 Mar 2024 11:52:02 -0800 Subject: [PATCH 55/99] add support for bigger transformed block reads Note: this support requires very recent drivers! --- .../99_matrixexperiments/matrix_helpers.cl | 49 +++++++++++++++++++ .../matrix_kernel_tiled.cl | 37 ++++++++++++-- 2 files changed, 83 insertions(+), 3 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_helpers.cl b/samples/99_matrixexperiments/matrix_helpers.cl index 97b2ce5e..1164c24c 100644 --- a/samples/99_matrixexperiments/matrix_helpers.cl +++ b/samples/99_matrixexperiments/matrix_helpers.cl @@ -601,6 +601,8 @@ enum LSC_LDCC { typedef ushort __attribute__((ext_vector_type(32))) ushort32; typedef ushort __attribute__((ext_vector_type(64))) ushort64; +typedef uint __attribute__((ext_vector_type(32))) uint32; + // Define block reads, prefetches, and writes. 
These are supported by the hardware but are not in the headers: ushort __builtin_IB_subgroup_block_read_flat_u16_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); @@ -616,6 +618,11 @@ ushort64 __builtin_IB_subgroup_block_read_flat_u16_m32k16v2(long baseoffset, int uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); uint16 __builtin_IB_subgroup_block_read_flat_u32_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +uint16 __builtin_IB_subgroup_block_read_flat_transform_u16_k32(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); + +uint16 __builtin_IB_subgroup_block_read_flat_transform_u16_k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +uint32 __builtin_IB_subgroup_block_read_flat_transform_u16_k32v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); + void __builtin_IB_subgroup_block_read_prefetch_u16_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); void __builtin_IB_subgroup_block_read_prefetch_u16_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); @@ -693,43 +700,85 @@ uint16 intel_subgroup_block_read_u32_m16k16(const __global void* base_address, i return __builtin_IB_subgroup_block_read_flat_u32_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } +// Each block is K rows x N columns, where the K rows have been VNNI transformed. 
+int8 intel_subgroup_block_read_transform_u16_k16n16(__global void *base_address, int width, int height, int pitch, int2 coord) +{ + // Note: this function is in the headers, but is named confusingly and returns unsigned integers rather than signed integers: + return as_int8(intel_subgroup_block_read_transform_u16_k16(base_address, width, height, pitch, coord)); +} +int16 intel_subgroup_block_read_transform_u16_k32n16(__global void *base_address, int width, int height, int pitch, int2 coord) +{ + return as_int16(__builtin_IB_subgroup_block_read_flat_transform_u16_k32(as_long(base_address), width - 1, height - 1, pitch - 1, coord)); +} +int16 intel_subgroup_block_read_transform_u16_k16n16v2(__global void *base_address, int width, int height, int pitch, int2 coord) +{ + return as_int16(__builtin_IB_subgroup_block_read_flat_transform_u16_k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord)); +} +void intel_subgroup_block_read_transform_u16_k32n16v2(__global void *base_address, int width, int height, int pitch, int2 coord, int8 dst[2][2]) +{ + uint32 tmp = __builtin_IB_subgroup_block_read_flat_transform_u16_k32v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord); + dst[0][0] = as_int8(tmp.lo.lo); + dst[0][1] = as_int8(tmp.lo.hi); + dst[1][0] = as_int8(tmp.hi.lo); + dst[1][1] = as_int8(tmp.hi.hi); +} + + #define BLOCK_PREFETCH_CACHE_TYPE LSC_LDCC_L1C_L3C void intel_subgroup_block_prefetch_u16_m1k16(const __global void *base_address, int width, int height, int pitch, int2 coord) { +#if defined(PREFETCH_DEFAULT) __builtin_IB_subgroup_block_read_prefetch_u16_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +#endif // defined(PREFETCH_DEFAULT) } void intel_subgroup_block_prefetch_u16_m2k16(const __global void *base_address, int width, int height, int pitch, int2 coord) { +#if defined(PREFETCH_DEFAULT) __builtin_IB_subgroup_block_read_prefetch_u16_m2k16v1(as_long(base_address), width - 1, 
height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +#endif // defined(PREFETCH_DEFAULT) } void intel_subgroup_block_prefetch_u16_m4k16(const __global void *base_address, int width, int height, int pitch, int2 coord) { +#if defined(PREFETCH_DEFAULT) __builtin_IB_subgroup_block_read_prefetch_u16_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +#endif // defined(PREFETCH_DEFAULT) } void intel_subgroup_block_prefetch_u16_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord) { +#if defined(PREFETCH_DEFAULT) __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +#endif // defined(PREFETCH_DEFAULT) } void intel_subgroup_block_prefetch_u16_m8k16v2(__global void *base_address, int width, int height, int pitch, int2 coord) { +#if defined(PREFETCH_DEFAULT) __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +#endif // defined(PREFETCH_DEFAULT) } void intel_subgroup_block_prefetch_u16_m16k16(const __global void *base_address, int width, int height, int pitch, int2 coord) { +#if defined(PREFETCH_DEFAULT) __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +#endif // defined(PREFETCH_DEFAULT) } void intel_subgroup_block_prefetch_u16_m32k16(const __global void *base_address, int width, int height, int pitch, int2 coord) { +#if defined(PREFETCH_DEFAULT) __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +#endif // defined(PREFETCH_DEFAULT) } void intel_subgroup_block_prefetch_u16_m16k16v2(const __global void *base_address, int width, int height, int pitch, int2 coord) { +#if defined(PREFETCH_DEFAULT) 
__builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +#endif // defined(PREFETCH_DEFAULT) } void intel_subgroup_block_prefetch_u16_m32k16v2(const __global void *base_address, int width, int height, int pitch, int2 coord) { +#if defined(PREFETCH_DEFAULT) __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +#endif // defined(PREFETCH_DEFAULT) } diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index f1bb1893..fda52ca4 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -453,11 +453,42 @@ void HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(global ushort* A, int tM } } +// TODO: consider swapping KK and NN order! void HELPER_NAME(btile_load_blockread_rowmajor, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n, int8 bData[KK][NN]) { - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - bData[kk][nn] = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK))); + if (KK % 2 == 0 & NN % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn+=2) { + int8 tmp[2][2]; + intel_subgroup_block_read_transform_u16_k32n16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK), tmp); + for (int tnn = 0; tnn < 2; tnn++) { + for (int tkk = 0; tkk < 2; tkk++) { + bData[kk + tkk][nn + tnn] = tmp[tnn][tkk]; + } + } + } + } + } else if (NN % 2 == 0) { + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn+=2) { + int16 bTemp = intel_subgroup_block_read_transform_u16_k16n16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + bData[kk][nn + 0] = bTemp.lo; + 
bData[kk][nn + 1] = bTemp.hi; + } + } + } else if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn++) { + int16 bTemp = intel_subgroup_block_read_transform_u16_k32n16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + bData[kk + 0][nn] = bTemp.lo; + bData[kk + 1][nn] = bTemp.hi; + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[kk][nn] = intel_subgroup_block_read_transform_u16_k16n16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + } } } } From 02a6045c60486130dbae86c84e2922d105dd7a65 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Mon, 4 Mar 2024 12:58:06 -0800 Subject: [PATCH 56/99] swap the B matrix NN and KK tiling dimensions This will keep the dimensions in a consistent row and column order --- .../matrix_kernel_tiled.cl | 55 +++++++++---------- 1 file changed, 27 insertions(+), 28 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index fda52ca4..2b3e7beb 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -40,20 +40,20 @@ #define PREFETCH_DISTANCE 1 #endif -void HELPER_NAME(btile_load_rowmajor, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[KK][NN]) +void HELPER_NAME(btile_load_rowmajor, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[NN][KK]) { for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { - bData[kk][nn] = load_b_rowmajor_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); + bData[nn][kk] = load_b_rowmajor_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); } } } -void HELPER_NAME(btile_load_vnni, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[KK][NN]) +void HELPER_NAME(btile_load_vnni, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[NN][KK]) { for (int kk = 0; kk < KK; kk++) 
{ for (int nn = 0; nn < NN; nn++) { - bData[kk][nn] = load_b_vnni_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); + bData[nn][kk] = load_b_vnni_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); } } } @@ -142,13 +142,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl int8 aData[KK][MM]; HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(A, tM, K, m, k, aData); - int8 bData[KK][NN]; + int8 bData[NN][KK]; HELPER_NAME(btile_load_rowmajor, MM, NN)(B, tN, N, k, n, bData); for (int kk = 0; kk < KK; kk++) { for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = mat_mul_sg8(aData[kk][mm], bData[kk][nn], sum[mm][nn]); + sum[mm][nn] = mat_mul_sg8(aData[kk][mm], bData[nn][kk], sum[mm][nn]); } } } @@ -203,13 +203,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* int8 aData[KK][MM]; HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(A, tM, K, m, k, aData); - int8 bData[KK][NN]; + int8 bData[NN][KK]; HELPER_NAME(btile_load_vnni, MM, NN)(B, tN, N, k, n, bData); for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { for (int mm = 0; mm < MM; mm++) { - sum[mm][nn] = mat_mul_sg8(aData[kk][mm], bData[kk][nn], sum[mm][nn]); + sum[mm][nn] = mat_mul_sg8(aData[kk][mm], bData[nn][kk], sum[mm][nn]); } } } @@ -312,13 +312,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f short8 aData[KK][MM]; HELPER_NAME(atile_load_rowmajor, MM, NN)(A, tM, K, m, k, aData); - int8 bData[KK][NN]; + int8 bData[NN][KK]; HELPER_NAME(btile_load_rowmajor, MM, NN)(B, tN, N, k, n, bData); for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { for (int mm = 0; mm < MM; mm++) { - sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[kk][nn], sum[mm][nn]); + sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[mm][nn]); } } } @@ -373,13 +373,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float short8 aData[KK][MM]; 
HELPER_NAME(atile_load_rowmajor, MM, NN)(A, tM, K, m, k, aData); - int8 bData[KK][NN]; + int8 bData[NN][KK]; HELPER_NAME(btile_load_vnni, MM, NN)(B, tN, N, k, n, bData); for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { for (int mm = 0; mm < MM; mm++) { - sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[kk][nn], sum[mm][nn]); + sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[mm][nn]); } } } @@ -453,8 +453,7 @@ void HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(global ushort* A, int tM } } -// TODO: consider swapping KK and NN order! -void HELPER_NAME(btile_load_blockread_rowmajor, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n, int8 bData[KK][NN]) +void HELPER_NAME(btile_load_blockread_rowmajor, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n, int8 bData[NN][KK]) { if (KK % 2 == 0 & NN % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { @@ -463,7 +462,7 @@ void HELPER_NAME(btile_load_blockread_rowmajor, MM, NN)(global ushort* B, int tN intel_subgroup_block_read_transform_u16_k32n16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK), tmp); for (int tnn = 0; tnn < 2; tnn++) { for (int tkk = 0; tkk < 2; tkk++) { - bData[kk + tkk][nn + tnn] = tmp[tnn][tkk]; + bData[nn + tnn][kk + tkk] = tmp[tnn][tkk]; } } } @@ -472,41 +471,41 @@ void HELPER_NAME(btile_load_blockread_rowmajor, MM, NN)(global ushort* B, int tN for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn+=2) { int16 bTemp = intel_subgroup_block_read_transform_u16_k16n16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); - bData[kk][nn + 0] = bTemp.lo; - bData[kk][nn + 1] = bTemp.hi; + bData[nn + 0][kk] = bTemp.lo; + bData[nn + 1][kk] = bTemp.hi; } } } else if (KK % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn++) { int16 bTemp = intel_subgroup_block_read_transform_u16_k32n16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); - 
bData[kk + 0][nn] = bTemp.lo; - bData[kk + 1][nn] = bTemp.hi; + bData[nn][kk + 0] = bTemp.lo; + bData[nn][kk + 1] = bTemp.hi; } } } else { for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { - bData[kk][nn] = intel_subgroup_block_read_transform_u16_k16n16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + bData[nn][kk] = intel_subgroup_block_read_transform_u16_k16n16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); } } } } -void HELPER_NAME(btile_load_blockread_vnni, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n, int8 bData[KK][NN]) +void HELPER_NAME(btile_load_blockread_vnni, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n, int8 bData[NN][KK]) { if (KK % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn++) { int16 bTemp = as_int16(intel_subgroup_block_read_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2))); - bData[kk + 0][nn] = bTemp.lo; - bData[kk + 1][nn] = bTemp.hi; + bData[nn][kk + 0] = bTemp.lo; + bData[nn][kk + 1] = bTemp.hi; } } } else { for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { - bData[kk][nn] = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2))); + bData[nn][kk] = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2))); } } } @@ -581,7 +580,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN short8 aData[KK][MM]; HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); - int8 bData[KK][NN]; + int8 bData[NN][KK]; HELPER_NAME(btile_load_blockread_rowmajor, MM, NN)(B, tN, K, N, k, n, bData); HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); @@ -590,7 +589,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, 
MM, NN for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { for (int mm = 0; mm < MM; mm++) { - sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[kk][nn], sum[mm][nn]); + sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[mm][nn]); } } } @@ -642,7 +641,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl short8 aData[KK][MM]; HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); - int8 bData[KK][NN]; + int8 bData[NN][KK]; HELPER_NAME(btile_load_blockread_vnni, MM, NN)(B, tN, K, N, k, n, bData); HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); @@ -651,7 +650,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { for (int mm = 0; mm < MM; mm++) { - sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[kk][nn], sum[mm][nn]); + sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[mm][nn]); } } } From a0c2e53571adc089671579b2d996fab3b694d5c5 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Mon, 4 Mar 2024 16:09:52 -0800 Subject: [PATCH 57/99] remove the 8x2 tiled kernels The 4x4 tiled kernels should be the best performing. 
--- samples/99_matrixexperiments/main.cpp | 6 ------ samples/99_matrixexperiments/matrix_kernels.cl | 6 ------ 2 files changed, 12 deletions(-) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index 23bf2d87..8b8385ee 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -789,7 +789,6 @@ int main(int argc, char** argv) bfloat16_dpas_rowmajor_tiled<8, 8, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_rowmajor_tiled<8, 8, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_rowmajor_tiled<8, 8, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_rowmajor_tiled<8, 8, 8, 2>(context, program, queue, C, A, B, M, N, K, C_ref); } if (mask & 0x8) { @@ -807,7 +806,6 @@ int main(int argc, char** argv) bfloat16_dpas_vnni_tiled<8, 8, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_vnni_tiled<8, 8, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_vnni_tiled<8, 8, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - bfloat16_dpas_vnni_tiled<8, 8, 8, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); } if (mask & 0x20) { @@ -825,7 +823,6 @@ int main(int argc, char** argv) bfloat16_dpas_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_rowmajor_tiled<8, 16, 8, 2>(context, program, queue, C, A, B, M, N, K, C_ref); } if (mask & 0x80) { @@ -843,7 +840,6 @@ int main(int argc, char** argv) bfloat16_dpas_vnni_tiled<8, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_vnni_tiled<8, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_vnni_tiled<8, 16, 4, 4>(context, program, queue, C, A, 
Bvnni, M, N, K, C_ref); - bfloat16_dpas_vnni_tiled<8, 16, 8, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); } if (mask & 0x200) { @@ -861,7 +857,6 @@ int main(int argc, char** argv) bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 8, 2>(context, program, queue, C, A, B, M, N, K, C_ref); } if (mask & 0x800) { @@ -879,7 +874,6 @@ int main(int argc, char** argv) bfloat16_dpas_blockread_vnni_tiled<8, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_blockread_vnni_tiled<8, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); bfloat16_dpas_blockread_vnni_tiled<8, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - bfloat16_dpas_blockread_vnni_tiled<8, 16, 8, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); } printf("Done.\n"); diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl index ddd05382..42e31814 100644 --- a/samples/99_matrixexperiments/matrix_kernels.cl +++ b/samples/99_matrixexperiments/matrix_kernels.cl @@ -580,12 +580,6 @@ kernel void bfloat16_dpas_blockread_vnni_m8_n16(global float* C, global ushort* #undef MM #undef NN -#define MM 8 -#define NN 2 -#include "matrix_kernel_tiled.cl" -#undef MM -#undef NN - #endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size) #undef tK From 873c6ab36eb9c707a6c1e1b3efce09af058d425b Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Mon, 4 Mar 2024 16:34:00 -0800 Subject: [PATCH 58/99] try a different order for prefetches and loads For the block read kernels, the order is now: 1. Prefetch A 2. Load B 3. Load A 4. 
Prefetch B --- .../matrix_kernel_tiled.cl | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 2b3e7beb..ef065c95 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -559,8 +559,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN int prefetch_k = 0; for (int p = 0; p < PREFETCH_DISTANCE; p++) { - HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n); HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n); prefetch_k += tK * KK; } @@ -575,16 +575,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN for (int k = 0; k < K; k += tK * KK) { // TODO: skip prefetch on the last iterations. - HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n); - - short8 aData[KK][MM]; - HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); int8 bData[NN][KK]; HELPER_NAME(btile_load_blockread_rowmajor, MM, NN)(B, tN, K, N, k, n, bData); - HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); - prefetch_k += tK * KK; + short8 aData[KK][MM]; + HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { @@ -592,6 +589,10 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[mm][nn]); } } + if (kk == 0) { + HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n); + prefetch_k += tK * KK; + } } split_barrier_wait(); @@ -636,16 +637,13 @@ kernel void 
MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl for (int k = 0; k < K; k += tK * KK) { // TODO: skip prefetch on the last iterations. - HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n); - - short8 aData[KK][MM]; - HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); int8 bData[NN][KK]; HELPER_NAME(btile_load_blockread_vnni, MM, NN)(B, tN, K, N, k, n, bData); - HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); - prefetch_k += tK * KK; + short8 aData[KK][MM]; + HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { @@ -653,6 +651,10 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[mm][nn]); } } + if (kk == 0) { + HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n); + prefetch_k += tK * KK; + } } split_barrier_wait(); From ef205e3fe680d517d1f2be6ca24337292227c8e0 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Mon, 4 Mar 2024 17:56:09 -0800 Subject: [PATCH 59/99] add support for more block prefetches --- .../99_matrixexperiments/matrix_helpers.cl | 15 ++++++ .../matrix_kernel_tiled.cl | 54 +++++++++++++++++-- 2 files changed, 65 insertions(+), 4 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_helpers.cl b/samples/99_matrixexperiments/matrix_helpers.cl index 1164c24c..e4ce38df 100644 --- a/samples/99_matrixexperiments/matrix_helpers.cl +++ b/samples/99_matrixexperiments/matrix_helpers.cl @@ -635,6 +635,9 @@ void __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1(long baseoffset, int void __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); void 
__builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); +void __builtin_IB_subgroup_block_read_prefetch_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); +void __builtin_IB_subgroup_block_read_prefetch_u32_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); + void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint data); void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint2 data); @@ -780,6 +783,18 @@ void intel_subgroup_block_prefetch_u16_m32k16v2(const __global void *base_addres __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); #endif // defined(PREFETCH_DEFAULT) } +void intel_subgroup_block_prefetch_u32_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ +#if defined(PREFETCH_DEFAULT) + __builtin_IB_subgroup_block_read_prefetch_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +#endif // defined(PREFETCH_DEFAULT) +} +void intel_subgroup_block_prefetch_u32_m16k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ +#if defined(PREFETCH_DEFAULT) + __builtin_IB_subgroup_block_read_prefetch_u32_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +#endif // defined(PREFETCH_DEFAULT) +} void intel_subgroup_block_write_u32_m1k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint data) diff --git 
a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index ef065c95..5d94faf6 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -546,6 +546,52 @@ void HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(global ushort* A, int tM } } +void HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n) +{ + if (KK % 2 == 0 & NN % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn += 2) { + intel_subgroup_block_prefetch_u16_m32k16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + } + } + } else if (NN % 2 == 0) { + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn+=2) { + intel_subgroup_block_prefetch_u16_m16k16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + } + } + } else if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn++) { + intel_subgroup_block_prefetch_u16_m32k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + intel_subgroup_block_prefetch_u16_m16k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + } + } + } +} + +void HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n) +{ + if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn++) { + intel_subgroup_block_prefetch_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + intel_subgroup_block_prefetch_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); + } + } + } +} + 
__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) @@ -560,7 +606,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN int prefetch_k = 0; for (int p = 0; p < PREFETCH_DISTANCE; p++) { HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); - HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n); + HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(B, tN, K, N, prefetch_k, n); prefetch_k += tK * KK; } @@ -590,7 +636,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN } } if (kk == 0) { - HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n); + HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(B, tN, K, N, prefetch_k, n); prefetch_k += tK * KK; } } @@ -621,8 +667,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl int prefetch_k = 0; for (int p = 0; p < PREFETCH_DISTANCE; p++) { - HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n); HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); + HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(B, tN, K, N, prefetch_k, n); prefetch_k += tK * KK; } @@ -652,7 +698,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl } } if (kk == 0) { - HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n); + HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(B, tN, K, N, prefetch_k, n); prefetch_k += tK * KK; } } From 2ea4d56f138fad051a96191553cb18329862bd6c Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Wed, 6 Mar 2024 10:12:47 -0800 Subject: [PATCH 60/99] switch the prefetch order back for now At least for now, we will follow the order: 1. Load B Matrix Tile 2. Load A matrix Tile 3. Prefetch Next B Matrix Tile 4. 
Prefetch Next A Matrix Tile 5. Compute 6. Loop back to 1. --- .../matrix_kernel_tiled.cl | 27 ++++++++----------- 1 file changed, 11 insertions(+), 16 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 5d94faf6..3421ee40 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -605,8 +605,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN int prefetch_k = 0; for (int p = 0; p < PREFETCH_DISTANCE; p++) { - HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(B, tN, K, N, prefetch_k, n); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); prefetch_k += tK * KK; } @@ -620,25 +620,22 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN split_barrier_arrive(); for (int k = 0; k < K; k += tK * KK) { - // TODO: skip prefetch on the last iterations. 
- HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); - int8 bData[NN][KK]; HELPER_NAME(btile_load_blockread_rowmajor, MM, NN)(B, tN, K, N, k, n, bData); short8 aData[KK][MM]; HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); + HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(B, tN, K, N, prefetch_k, n); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); + prefetch_k += tK * KK; + for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { for (int mm = 0; mm < MM; mm++) { sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[mm][nn]); } } - if (kk == 0) { - HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(B, tN, K, N, prefetch_k, n); - prefetch_k += tK * KK; - } } split_barrier_wait(); @@ -667,8 +664,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl int prefetch_k = 0; for (int p = 0; p < PREFETCH_DISTANCE; p++) { - HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(B, tN, K, N, prefetch_k, n); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); prefetch_k += tK * KK; } @@ -682,25 +679,23 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl split_barrier_arrive(); for (int k = 0; k < K; k += tK * KK) { - // TODO: skip prefetch on the last iterations. - HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); - int8 bData[NN][KK]; HELPER_NAME(btile_load_blockread_vnni, MM, NN)(B, tN, K, N, k, n, bData); short8 aData[KK][MM]; HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); + // TODO: skip prefetch on the last iterations. 
+ HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(B, tN, K, N, prefetch_k, n); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); + prefetch_k += tK * KK; + for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { for (int mm = 0; mm < MM; mm++) { sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[mm][nn]); } } - if (kk == 0) { - HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(B, tN, K, N, prefetch_k, n); - prefetch_k += tK * KK; - } } split_barrier_wait(); From f60930a6bffd4b123c7fa216fa1cd6c22203980f Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Wed, 6 Mar 2024 11:01:35 -0800 Subject: [PATCH 61/99] add support for larger work-groups in both dimensions --- .../99_matrixexperiments/matrix_helpers.cl | 13 ++++-- .../matrix_kernel_tiled.cl | 45 ++++++++++--------- 2 files changed, 34 insertions(+), 24 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_helpers.cl b/samples/99_matrixexperiments/matrix_helpers.cl index e4ce38df..e8a08ffb 100644 --- a/samples/99_matrixexperiments/matrix_helpers.cl +++ b/samples/99_matrixexperiments/matrix_helpers.cl @@ -55,13 +55,20 @@ float8 activation(float8 f) typedef global ushort* global_aligned_ushort_ptr __attribute__((align_value(4))); -inline int compute_m(const int num_sgs, const int tM, const int MM) +inline int compute_m(const int num_sgs_x, const int num_sgs_y, const int tM, const int MM) { - const int m_start = get_group_id(1) * num_sgs; - const int m_index = num_sgs > 1 ? m_start + get_sub_group_id() : m_start; + const int m_start = get_group_id(1) * num_sgs_y; + const int m_index = num_sgs_y > 1 ? m_start + get_sub_group_id() / num_sgs_x : m_start; return m_index * tM * MM; } +inline int compute_n(const int num_sgs_x, const int num_sgs_y, const int tN, const int NN) +{ + const int n_start = get_group_id(0) * num_sgs_x; + const int n_index = num_sgs_x > 1 ? 
n_start + get_sub_group_id() % num_sgs_x : n_start; + return n_index * tN * NN; +} + // Emulated SIMD8 dpas: __attribute__((overloadable)) float emu_sub_group_bf16_bf16_matrix_mad_k16(int a, int8 b, float acc) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 3421ee40..beea5abd 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -31,9 +31,12 @@ #define HELPER_NAMEX(PREFIX, MM, NN) PREFIX ## _m ## MM ## _n ## NN #define HELPER_NAME(PREFIX, MM, NN) HELPER_NAMEX(PREFIX, MM, NN) -#if !defined(SGS_PER_WG) -// Launch four subgroups per work-group, to maximize cache reuse. -#define SGS_PER_WG 4 +#if !defined(SGS_PER_WG_X) +#define SGS_PER_WG_X 1 +#endif + +#if !defined(SGS_PER_WG_Y) +#define SGS_PER_WG_Y 4 #endif #if !defined(PREFETCH_DISTANCE) @@ -106,14 +109,14 @@ void HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(global ushort* A, int tM, int } } -__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, SGS_PER_WG, 1))) +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) { const int tM = 8; const int tN = 8; const int N = get_global_size(0) * NN; - const int m = compute_m(SGS_PER_WG, tM, MM); - const int n = get_group_id(0) * tN * NN; + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); // Initial prefetch: int prefetch_k = 0; @@ -167,14 +170,14 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl } } -__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, SGS_PER_WG, 1))) +__attribute__((intel_reqd_sub_group_size(8))) 
__attribute__((reqd_work_group_size(8 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) { const int tM = 8; const int tN = 8; const int N = get_global_size(0) * NN; - const int m = compute_m(SGS_PER_WG, tM, MM); - const int n = get_group_id(0) * tN * NN; + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); // Initial prefetch: int prefetch_k = 0; @@ -276,14 +279,14 @@ void HELPER_NAME(atile_load_rowmajor, MM, NN)(global ushort* A, int tM, int K, i } } -__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1))) +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) { const int tM = 8; const int tN = 16; const int N = get_global_size(0) * NN; - const int m = compute_m(SGS_PER_WG, tM, MM); - const int n = get_group_id(0) * tN * NN; + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); // Initial prefetch: int prefetch_k = 0; @@ -337,14 +340,14 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f } } -__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1))) +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) { const int tM = 8; const int tN = 16; const int N = get_global_size(0) * NN; - const int m = compute_m(SGS_PER_WG, tM, MM); - const int 
n = get_group_id(0) * tN * NN; + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); // Initial prefetch: int prefetch_k = 0; @@ -592,7 +595,7 @@ void HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(global ushort* B, int tN, in } } -__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1))) +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) { @@ -600,8 +603,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN const int tN = 16; const int M = get_global_size(1) * tM * MM; const int N = get_global_size(0) * NN; - const int m = compute_m(SGS_PER_WG, tM, MM); - const int n = get_group_id(0) * tN * NN; + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); int prefetch_k = 0; for (int p = 0; p < PREFETCH_DISTANCE; p++) { @@ -652,15 +655,15 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN } } -__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1))) +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) { const int tM = 8; const int tN = 16; const int M = get_global_size(1) * tM * MM; const int N = get_global_size(0) * NN; - const int m = compute_m(SGS_PER_WG, tM, MM); - const int n = get_group_id(0) * tN * NN; + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); int prefetch_k = 0; for (int p 
= 0; p < PREFETCH_DISTANCE; p++) { From 54b9366e673a2514123d4680bf1ca9bd0a126be3 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Wed, 6 Mar 2024 23:40:20 -0800 Subject: [PATCH 62/99] add support for initializing matrices with zero data --- samples/99_matrixexperiments/main.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index 8b8385ee..d79ca8e0 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -20,6 +20,7 @@ using test_clock = std::chrono::high_resolution_clock; +bool zeroData = false; bool identityData = false; bool fixedData = false; bool validate = false; @@ -77,7 +78,10 @@ static size_t findMinSubGroupSize(cl::Device& device) template static void fill_matrix(std::vector& M, size_t numRows, size_t numCols) { - if (identityData) { + if (zeroData) { + std::generate(std::begin(M), std::end(M), [&]{ return 0.0f; }); + } + else if (identityData) { std::generate(std::begin(M), std::end(M), [&]{ return 1.0f; }); } else if (fixedData) { for (size_t r = 0; r < numRows; r++) { @@ -648,6 +652,7 @@ int main(int argc, char** argv) op.add>("m", "matrixsize", "Matrix Size", matrixSize, &matrixSize); op.add>("i", "iterations", "Test Iterations", testIterations, &testIterations); op.add("", "validate", "Validate Results", &validate); + op.add("", "zero", "Use Zero Data", &zeroData); op.add("", "identity", "Use Identity Data", &identityData); op.add("", "fixed", "Use Fixed Data", &fixedData); op.add("", "emulate", "Unconditionally Emulate dpas", &emulate); From 8d23a8c7ebe496824471caf24227a8c046b210e7 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Thu, 7 Mar 2024 08:14:36 -0800 Subject: [PATCH 63/99] add support for setting round robin scheduling (disabled by default) --- samples/99_matrixexperiments/CMakeLists.txt | 2 +- samples/99_matrixexperiments/main.cpp | 26 +++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 
deletion(-) diff --git a/samples/99_matrixexperiments/CMakeLists.txt b/samples/99_matrixexperiments/CMakeLists.txt index 86599fbf..9fe36d84 100644 --- a/samples/99_matrixexperiments/CMakeLists.txt +++ b/samples/99_matrixexperiments/CMakeLists.txt @@ -6,6 +6,6 @@ add_opencl_sample( TEST NUMBER 99 TARGET matrixexperiments - VERSION 120 + VERSION 200 # for clSetKernelExecInfo SOURCES main.cpp KERNELS matrix_helpers.cl matrix_kernels.cl matrix_kernel_tiled.cl) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index d79ca8e0..fae01b3f 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -27,6 +27,7 @@ bool validate = false; bool emulate = false; bool wallclock = false; bool skipinit = false; +bool roundRobin = false; int testIterations = 16; float threshold = 0.01f; @@ -75,6 +76,18 @@ static size_t findMinSubGroupSize(cl::Device& device) return 0; } +static void setRoundRobin(cl::Kernel& kernel) +{ + constexpr cl_kernel_exec_info CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_INTEL = 0x10025; + constexpr cl_uint CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_ROUND_ROBIN_INTEL = 0x10023; + const cl_uint policy = CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_ROUND_ROBIN_INTEL; + clSetKernelExecInfo( + kernel(), + CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_INTEL, + sizeof(policy), + &policy); +} + template static void fill_matrix(std::vector& M, size_t numRows, size_t numCols) { @@ -440,6 +453,9 @@ static void bfloat16_dpas_blockread_rowmajor( kernel.setArg(1, A); kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); + if (roundRobin) { + setRoundRobin(kernel); + } if (!skipinit) { queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); @@ -496,6 +512,9 @@ static void bfloat16_dpas_blockread_rowmajor_tiled( kernel.setArg(1, A); kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); + if (roundRobin) { + setRoundRobin(kernel); + } if (!skipinit) { queue.enqueueFillBuffer(C, 0, 
0, C_ref.size() * sizeof(C_ref[0])); @@ -546,6 +565,9 @@ static void bfloat16_dpas_blockread_vnni( kernel.setArg(1, A); kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); + if (roundRobin) { + setRoundRobin(kernel); + } if (!skipinit) { queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); @@ -602,6 +624,9 @@ static void bfloat16_dpas_blockread_vnni_tiled( kernel.setArg(1, A); kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); + if (roundRobin) { + setRoundRobin(kernel); + } if (!skipinit) { queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); @@ -658,6 +683,7 @@ int main(int argc, char** argv) op.add("", "emulate", "Unconditionally Emulate dpas", &emulate); op.add("", "wallclock", "Measure Wallclock Time", &wallclock); op.add("", "skipinit", "Do Not Initialize Buffers", &skipinit); + op.add("", "roundrobin", "Use Round Robin Scheduling", &roundRobin); op.add>("", "threshold", "Local Error Threshold", threshold, &threshold); op.add, popl::Attribute::advanced>("", "mask", "Test Mask", mask, &mask); bool printUsage = false; From 133fb58a22d9a1a2d2e1f59b99976b29a4737bc5 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Thu, 7 Mar 2024 14:59:55 -0800 Subject: [PATCH 64/99] tell the compiler K is always greater than zero This probably won't affect performance, but it does enable more concise code, because the compiler will know the K loop will always execute at least once. 
--- .../99_matrixexperiments/matrix_helpers.cl | 7 ++++++ .../matrix_kernel_tiled.cl | 6 +++++ .../99_matrixexperiments/matrix_kernels.cl | 24 +++++++++++++++++++ 3 files changed, 37 insertions(+) diff --git a/samples/99_matrixexperiments/matrix_helpers.cl b/samples/99_matrixexperiments/matrix_helpers.cl index e8a08ffb..dd7dbba4 100644 --- a/samples/99_matrixexperiments/matrix_helpers.cl +++ b/samples/99_matrixexperiments/matrix_helpers.cl @@ -51,6 +51,13 @@ float8 activation(float8 f) return res; } +#ifndef __has_builtin +#define __has_builtin(x) 0 +#endif +#if __has_builtin(__builtin_expect) == 0 +#define __builtin_expect(x) +#endif + #if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) typedef global ushort* global_aligned_ushort_ptr __attribute__((align_value(4))); diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index beea5abd..80274741 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -112,6 +112,7 @@ void HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(global ushort* A, int tM, int __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. 
const int tM = 8; const int tN = 8; const int N = get_global_size(0) * NN; @@ -173,6 +174,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. const int tM = 8; const int tN = 8; const int N = get_global_size(0) * NN; @@ -282,6 +284,7 @@ void HELPER_NAME(atile_load_rowmajor, MM, NN)(global ushort* A, int tM, int K, i __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. const int tM = 8; const int tN = 16; const int N = get_global_size(0) * NN; @@ -343,6 +346,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. const int tM = 8; const int tN = 16; const int N = get_global_size(0) * NN; @@ -599,6 +603,7 @@ __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_si kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. 
const int tM = 8; const int tN = 16; const int M = get_global_size(1) * tM * MM; @@ -658,6 +663,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. const int tM = 8; const int tN = 16; const int M = get_global_size(1) * tM * MM; diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl index 42e31814..f2254553 100644 --- a/samples/99_matrixexperiments/matrix_kernels.cl +++ b/samples/99_matrixexperiments/matrix_kernels.cl @@ -39,6 +39,7 @@ kernel void bfloat16_naive(global float* C, global ushort* A, global ushort* B, __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) kernel void bfloat16_dpas_rowmajor_m1_n8(global float* C, global ushort* A, global ushort* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. const int tM = 1; const int tN = 8; const int N = get_global_size(0); @@ -59,6 +60,7 @@ kernel void bfloat16_dpas_rowmajor_m1_n8(global float* C, global ushort* A, glob __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) kernel void bfloat16_dpas_rowmajor_m2_n8(global float* C, global ushort* A, global ushort* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. 
const int tM = 2; const int tN = 8; const int N = get_global_size(0); @@ -79,6 +81,7 @@ kernel void bfloat16_dpas_rowmajor_m2_n8(global float* C, global ushort* A, glob __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) kernel void bfloat16_dpas_rowmajor_m4_n8(global float* C, global ushort* A, global ushort* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. const int tM = 4; const int tN = 8; const int N = get_global_size(0); @@ -99,6 +102,7 @@ kernel void bfloat16_dpas_rowmajor_m4_n8(global float* C, global ushort* A, glob __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) kernel void bfloat16_dpas_rowmajor_m8_n8(global float* C, global ushort* A, global ushort* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. const int tM = 8; const int tN = 8; const int N = get_global_size(0); @@ -121,6 +125,7 @@ kernel void bfloat16_dpas_rowmajor_m8_n8(global float* C, global ushort* A, glob __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) kernel void bfloat16_dpas_vnni_m1_n8(global float* C, global ushort* A, global ushort* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. const int tM = 1; const int tN = 8; const int N = get_global_size(0); @@ -141,6 +146,7 @@ kernel void bfloat16_dpas_vnni_m1_n8(global float* C, global ushort* A, global u __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) kernel void bfloat16_dpas_vnni_m2_n8(global float* C, global ushort* A, global ushort* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. 
const int tM = 2; const int tN = 8; const int N = get_global_size(0); @@ -161,6 +167,7 @@ kernel void bfloat16_dpas_vnni_m2_n8(global float* C, global ushort* A, global u __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) kernel void bfloat16_dpas_vnni_m4_n8(global float* C, global ushort* A, global ushort* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. const int tM = 4; const int tN = 8; const int N = get_global_size(0); @@ -181,6 +188,7 @@ kernel void bfloat16_dpas_vnni_m4_n8(global float* C, global ushort* A, global u __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) kernel void bfloat16_dpas_vnni_m8_n8(global float* C, global ushort* A, global ushort* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. const int tM = 8; const int tN = 8; const int N = get_global_size(0); @@ -205,6 +213,7 @@ kernel void bfloat16_dpas_vnni_m8_n8(global float* C, global ushort* A, global u __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_rowmajor_m1_n16(global float* C, global ushort* A, global ushort* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. const int tM = 1; const int tN = 16; const int N = get_global_size(0); @@ -225,6 +234,7 @@ kernel void bfloat16_dpas_rowmajor_m1_n16(global float* C, global ushort* A, glo __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_rowmajor_m2_n16(global float* C, global ushort* A, global ushort* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. 
const int tM = 2; const int tN = 16; const int N = get_global_size(0); @@ -245,6 +255,7 @@ kernel void bfloat16_dpas_rowmajor_m2_n16(global float* C, global ushort* A, glo __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_rowmajor_m4_n16(global float* C, global ushort* A, global ushort* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. const int tM = 4; const int tN = 16; const int N = get_global_size(0); @@ -265,6 +276,7 @@ kernel void bfloat16_dpas_rowmajor_m4_n16(global float* C, global ushort* A, glo __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_rowmajor_m8_n16(global float* C, global ushort* A, global ushort* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. const int tM = 8; const int tN = 16; const int N = get_global_size(0); @@ -287,6 +299,7 @@ kernel void bfloat16_dpas_rowmajor_m8_n16(global float* C, global ushort* A, glo __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_vnni_m1_n16(global float* C, global ushort* A, global ushort* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. const int tM = 1; const int tN = 16; const int N = get_global_size(0); @@ -307,6 +320,7 @@ kernel void bfloat16_dpas_vnni_m1_n16(global float* C, global ushort* A, global __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_vnni_m2_n16(global float* C, global ushort* A, global ushort* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. 
const int tM = 2; const int tN = 16; const int N = get_global_size(0); @@ -327,6 +341,7 @@ kernel void bfloat16_dpas_vnni_m2_n16(global float* C, global ushort* A, global __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_vnni_m4_n16(global float* C, global ushort* A, global ushort* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. const int tM = 4; const int tN = 16; const int N = get_global_size(0); @@ -347,6 +362,7 @@ kernel void bfloat16_dpas_vnni_m4_n16(global float* C, global ushort* A, global __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_vnni_m8_n16(global float* C, global ushort* A, global ushort* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. const int tM = 8; const int tN = 16; const int N = get_global_size(0); @@ -369,6 +385,7 @@ kernel void bfloat16_dpas_vnni_m8_n16(global float* C, global ushort* A, global __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_blockread_rowmajor_m1_n16(global float* C, global ushort* A, global ushort* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. const int tM = 1; const int tN = 16; const int M = get_global_size(1); @@ -390,6 +407,7 @@ kernel void bfloat16_dpas_blockread_rowmajor_m1_n16(global float* C, global usho __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_blockread_rowmajor_m2_n16(global float* C, global ushort* A, global ushort* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. 
const int tM = 2; const int tN = 16; const int M = get_global_size(1) * tM; @@ -411,6 +429,7 @@ kernel void bfloat16_dpas_blockread_rowmajor_m2_n16(global float* C, global usho __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_blockread_rowmajor_m4_n16(global float* C, global ushort* A, global ushort* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. const int tM = 4; const int tN = 16; const int M = get_global_size(1) * tM; @@ -432,6 +451,7 @@ kernel void bfloat16_dpas_blockread_rowmajor_m4_n16(global float* C, global usho __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_blockread_rowmajor_m8_n16(global float* C, global ushort* A, global ushort* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. const int tM = 8; const int tN = 16; const int M = get_global_size(1) * tM; @@ -453,6 +473,7 @@ kernel void bfloat16_dpas_blockread_rowmajor_m8_n16(global float* C, global usho __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_blockread_vnni_m1_n16(global float* C, global ushort* A, global ushort* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. const int tM = 1; const int tN = 16; const int M = get_global_size(1) * tM; @@ -474,6 +495,7 @@ kernel void bfloat16_dpas_blockread_vnni_m1_n16(global float* C, global ushort* __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_blockread_vnni_m2_n16(global float* C, global ushort* A, global ushort* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. 
const int tM = 2; const int tN = 16; const int M = get_global_size(1) * tM; @@ -495,6 +517,7 @@ kernel void bfloat16_dpas_blockread_vnni_m2_n16(global float* C, global ushort* __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_blockread_vnni_m4_n16(global float* C, global ushort* A, global ushort* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. const int tM = 4; const int tN = 16; const int M = get_global_size(1) * tM; @@ -516,6 +539,7 @@ kernel void bfloat16_dpas_blockread_vnni_m4_n16(global float* C, global ushort* __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_blockread_vnni_m8_n16(global float* C, global ushort* A, global ushort* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. const int tM = 8; const int tN = 16; const int M = get_global_size(1) * tM; From 3989cae48c43731d69f275d5ef083f00a95b6d0f Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Thu, 7 Mar 2024 18:11:10 -0800 Subject: [PATCH 65/99] switch the sum dimensions for consistency --- .../99_matrixexperiments/matrix_helpers.cl | 5 ++ .../matrix_kernel_tiled.cl | 60 +++++++++---------- 2 files changed, 35 insertions(+), 30 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_helpers.cl b/samples/99_matrixexperiments/matrix_helpers.cl index dd7dbba4..acf1219a 100644 --- a/samples/99_matrixexperiments/matrix_helpers.cl +++ b/samples/99_matrixexperiments/matrix_helpers.cl @@ -657,6 +657,7 @@ void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(long baseoffset, int wid void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint2 data); void __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint4 data); void 
__builtin_IB_subgroup_block_write_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint8 data); +void __builtin_IB_subgroup_block_write_flat_u32_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint16 data); ushort intel_subgroup_block_read_u16_m1k16(const __global void *base_address, int width, int height, int pitch, int2 coord) { @@ -827,5 +828,9 @@ void intel_subgroup_block_write_u32_m8k16(__global void* base_address, int width { __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); } +void intel_subgroup_block_write_u32_m16k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint16 data) +{ + __builtin_IB_subgroup_block_write_flat_u32_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); +} #endif // cl_intel_subgroup_extended_block_read diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 80274741..808796b1 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -127,10 +127,10 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl prefetch_k += tK * KK; } - float8 sum[MM][NN]; + float8 sum[NN][MM]; for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = 0; + sum[nn][mm] = 0; } } @@ -152,7 +152,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl for (int kk = 0; kk < KK; kk++) { for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = mat_mul_sg8(aData[kk][mm], bData[nn][kk], sum[mm][nn]); + sum[nn][mm] = mat_mul_sg8(aData[kk][mm], bData[nn][kk], sum[nn][mm]); } } } @@ -165,8 +165,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl for (int 
nn = 0; nn < NN; nn++) { for (int mm = 0; mm < MM; mm++) { - sum[mm][nn] = activation(sum[mm][nn]); - store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + sum[nn][mm] = activation(sum[nn][mm]); + store_c_rowmajor_fp32_m8_nx(C, sum[nn][mm], m + mm * tM, n + nn * tN, N); } } } @@ -189,10 +189,10 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* prefetch_k += tK * KK; } - float8 sum[MM][NN]; + float8 sum[NN][MM]; for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = 0; + sum[nn][mm] = 0; } } @@ -214,7 +214,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { for (int mm = 0; mm < MM; mm++) { - sum[mm][nn] = mat_mul_sg8(aData[kk][mm], bData[nn][kk], sum[mm][nn]); + sum[nn][mm] = mat_mul_sg8(aData[kk][mm], bData[nn][kk], sum[nn][mm]); } } } @@ -227,8 +227,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = activation(sum[mm][nn]); - store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + sum[nn][mm] = activation(sum[nn][mm]); + store_c_rowmajor_fp32_m8_nx(C, sum[nn][mm], m + mm * tM, n + nn * tN, N); } } } @@ -299,10 +299,10 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f prefetch_k += tK * KK; } - float8 sum[MM][NN]; + float8 sum[NN][MM]; for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = 0; + sum[nn][mm] = 0; } } @@ -324,7 +324,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { for (int mm = 0; mm < MM; mm++) { - sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[mm][nn]); + sum[nn][mm] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[nn][mm]); } } } @@ -337,8 +337,8 @@ 
kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = activation(sum[mm][nn]); - store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + sum[nn][mm] = activation(sum[nn][mm]); + store_c_rowmajor_fp32_m8_nx(C, sum[nn][mm], m + mm * tM, n + nn * tN, N); } } } @@ -361,10 +361,10 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float prefetch_k += tK * KK; } - float8 sum[MM][NN]; + float8 sum[NN][MM]; for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = 0; + sum[nn][mm] = 0; } } @@ -386,7 +386,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { for (int mm = 0; mm < MM; mm++) { - sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[mm][nn]); + sum[nn][mm] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[nn][mm]); } } } @@ -399,8 +399,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = activation(sum[mm][nn]); - store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + sum[nn][mm] = activation(sum[nn][mm]); + store_c_rowmajor_fp32_m8_nx(C, sum[nn][mm], m + mm * tM, n + nn * tN, N); } } } @@ -618,10 +618,10 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN prefetch_k += tK * KK; } - float8 sum[MM][NN]; + float8 sum[NN][MM]; for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = 0; + sum[nn][mm] = 0; } } @@ -641,7 +641,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { for (int mm = 0; mm < MM; mm++) { - sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[mm][nn]); + 
sum[nn][mm] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[nn][mm]); } } } @@ -654,8 +654,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = activation(sum[mm][nn]); - intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[mm][nn])); + sum[nn][mm] = activation(sum[nn][mm]); + intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[nn][mm])); } } } @@ -678,10 +678,10 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl prefetch_k += tK * KK; } - float8 sum[MM][NN]; + float8 sum[NN][MM]; for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = 0; + sum[nn][mm] = 0; } } @@ -702,7 +702,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { for (int mm = 0; mm < MM; mm++) { - sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[mm][nn]); + sum[nn][mm] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[nn][mm]); } } } @@ -715,8 +715,8 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = activation(sum[mm][nn]); - intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[mm][nn])); + sum[nn][mm] = activation(sum[nn][mm]); + intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[nn][mm])); } } } From 7c0c358cf1fc7f13e74f067f443d488145e97939 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Fri, 8 Mar 2024 13:55:37 -0800 Subject: [PATCH 66/99] try a cooperative prefetch for the B matrix tile --- 
.../99_matrixexperiments/matrix_kernel_tiled.cl | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 808796b1..1f8c08cc 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -555,7 +555,13 @@ void HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(global ushort* A, int tM void HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n) { - if (KK % 2 == 0 & NN % 2 == 0) { + const int NUM_SGS = SGS_PER_WG_X * SGS_PER_WG_Y; + if (KK % 2 == 0 & NN == 4 & NUM_SGS >= 2) { + const int nn = (get_sub_group_id() % 2) * 2; + for (int kk = 0; kk < KK; kk+=2) { + intel_subgroup_block_prefetch_u16_m32k16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + } + } else if (KK % 2 == 0 & NN % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn += 2) { intel_subgroup_block_prefetch_u16_m32k16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); @@ -584,7 +590,13 @@ void HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(global ushort* B, int tN void HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n) { - if (KK % 2 == 0) { + const int NUM_SGS = SGS_PER_WG_X * SGS_PER_WG_Y; + if (KK % 2 == 0 & NN == 4 & NUM_SGS >= 4) { + const int nn = get_sub_group_id() % 4; + for (int kk = 0; kk < KK; kk+=2) { + intel_subgroup_block_prefetch_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); + } + } else if (KK % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn++) { intel_subgroup_block_prefetch_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); From 9a0317a38bf3b1d83ad99a91085b0fb92cf32d37 Mon Sep 17 
00:00:00 2001 From: Ben Ashbaugh Date: Fri, 8 Mar 2024 14:51:26 -0800 Subject: [PATCH 67/99] fix the cooperative prefetching indexing calculation --- samples/99_matrixexperiments/matrix_kernel_tiled.cl | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 1f8c08cc..da5c8e44 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -555,9 +555,8 @@ void HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(global ushort* A, int tM void HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n) { - const int NUM_SGS = SGS_PER_WG_X * SGS_PER_WG_Y; - if (KK % 2 == 0 & NN == 4 & NUM_SGS >= 2) { - const int nn = (get_sub_group_id() % 2) * 2; + if (KK % 2 == 0 & NN == 4 & SGS_PER_WG_Y >= 2) { + const int nn = (get_sub_group_id() / SGS_PER_WG_X) % 2 * 2; for (int kk = 0; kk < KK; kk+=2) { intel_subgroup_block_prefetch_u16_m32k16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); } @@ -590,9 +589,8 @@ void HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(global ushort* B, int tN void HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n) { - const int NUM_SGS = SGS_PER_WG_X * SGS_PER_WG_Y; - if (KK % 2 == 0 & NN == 4 & NUM_SGS >= 4) { - const int nn = get_sub_group_id() % 4; + if (KK % 2 == 0 & NN == 4 & SGS_PER_WG_Y >= 4) { + const int nn = (get_sub_group_id() / SGS_PER_WG_X) % 4; for (int kk = 0; kk < KK; kk+=2) { intel_subgroup_block_prefetch_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); } From afca843783de435c14200f69ad22670dd71ea584 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Wed, 13 Mar 2024 11:42:04 -0700 Subject: [PATCH 68/99] a few more naming changes for consistency --- 
.../99_matrixexperiments/matrix_kernel_tiled.cl | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index da5c8e44..39d23adb 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -407,7 +407,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float #ifdef cl_intel_subgroup_extended_block_read -void HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(global ushort* A, int tM, int M, int K, int m, int k, short8 aData[KK][MM]) +void HELPER_NAME(atile_block_load_rowmajor, MM, NN)(global ushort* A, int tM, int M, int K, int m, int k, short8 aData[KK][MM]) { if (KK % 2 == 0 & MM % 4 == 0) { for (int kk = 0; kk < KK; kk+=2) { @@ -460,7 +460,7 @@ void HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(global ushort* A, int tM } } -void HELPER_NAME(btile_load_blockread_rowmajor, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n, int8 bData[NN][KK]) +void HELPER_NAME(btile_block_load_rowmajor, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n, int8 bData[NN][KK]) { if (KK % 2 == 0 & NN % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { @@ -499,7 +499,7 @@ void HELPER_NAME(btile_load_blockread_rowmajor, MM, NN)(global ushort* B, int tN } } -void HELPER_NAME(btile_load_blockread_vnni, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n, int8 bData[NN][KK]) +void HELPER_NAME(btile_block_load_vnni, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n, int8 bData[NN][KK]) { if (KK % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { @@ -611,7 +611,6 @@ void HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(global ushort* B, int tN, in __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 
16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) - { __builtin_assume(K > 0); // Always at least one K iteration. const int tM = 8; @@ -639,10 +638,10 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN for (int k = 0; k < K; k += tK * KK) { int8 bData[NN][KK]; - HELPER_NAME(btile_load_blockread_rowmajor, MM, NN)(B, tN, K, N, k, n, bData); + HELPER_NAME(btile_block_load_rowmajor, MM, NN)(B, tN, K, N, k, n, bData); short8 aData[KK][MM]; - HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); + HELPER_NAME(atile_block_load_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(B, tN, K, N, prefetch_k, n); HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); @@ -699,10 +698,10 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl for (int k = 0; k < K; k += tK * KK) { int8 bData[NN][KK]; - HELPER_NAME(btile_load_blockread_vnni, MM, NN)(B, tN, K, N, k, n, bData); + HELPER_NAME(btile_block_load_vnni, MM, NN)(B, tN, K, N, k, n, bData); short8 aData[KK][MM]; - HELPER_NAME(atile_load_blockread_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); + HELPER_NAME(atile_block_load_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); // TODO: skip prefetch on the last iterations. 
HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(B, tN, K, N, prefetch_k, n); From 7176d5368f96160947f857d091bb838dbaecc3a8 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Wed, 13 Mar 2024 17:20:59 -0700 Subject: [PATCH 69/99] try a smaller cooperative prefetch for the B matrix for the rowmajor case --- samples/99_matrixexperiments/main.cpp | 12 ++++----- .../matrix_kernel_tiled.cl | 26 ++++++++++++------- 2 files changed, 22 insertions(+), 16 deletions(-) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index fae01b3f..a0025147 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/99_matrixexperiments/main.cpp @@ -881,12 +881,12 @@ int main(int argc, char** argv) } if (mask & 0x400) { - bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + //bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + //bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + //bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + //bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + //bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + //bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, 
B, M, N, K, C_ref); bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); } diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 39d23adb..a8d6b4dc 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -465,6 +465,9 @@ void HELPER_NAME(btile_block_load_rowmajor, MM, NN)(global ushort* B, int tN, in if (KK % 2 == 0 & NN % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn+=2) { + //if (get_sub_group_local_id() == 0) { + // printf("btile block load: %d, %d, %2d: n = %3d, k = %3d, nn = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), n, k, nn, kk, n + nn * tN, k + kk * tK); + //} int8 tmp[2][2]; intel_subgroup_block_read_transform_u16_k32n16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK), tmp); for (int tnn = 0; tnn < 2; tnn++) { @@ -555,11 +558,14 @@ void HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(global ushort* A, int tM void HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n) { - if (KK % 2 == 0 & NN == 4 & SGS_PER_WG_Y >= 2) { - const int nn = (get_sub_group_id() / SGS_PER_WG_X) % 2 * 2; - for (int kk = 0; kk < KK; kk+=2) { - intel_subgroup_block_prefetch_u16_m32k16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); - } + if (KK == 2 & NN == 4 & SGS_PER_WG_Y >= 4) { + const int sg_index_y = get_sub_group_id() / SGS_PER_WG_X; // index in [0, SGS_PER_WG_Y) + const int nn = sg_index_y % 2 * 2; // nn(sg_index_y) == 0, 2, 0, 2, 0, 2, 0, 2, ... + const int kk = sg_index_y / 2 % 2; // kk(sg_index_y) == 0, 0, 1, 1, 0, 0, 1, 1, ... 
+ //if (get_sub_group_local_id() == 0) { + // printf("btile block prefetch: %d, %d, %2d: sg_y = %d, n = %3d, k = %3d, nn = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), sg_index_y, n, k, nn, kk, n + nn * tN, k + kk * tK); + //} + intel_subgroup_block_prefetch_u16_m16k16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); } else if (KK % 2 == 0 & NN % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn += 2) { @@ -589,11 +595,11 @@ void HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(global ushort* B, int tN void HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n) { - if (KK % 2 == 0 & NN == 4 & SGS_PER_WG_Y >= 4) { - const int nn = (get_sub_group_id() / SGS_PER_WG_X) % 4; - for (int kk = 0; kk < KK; kk+=2) { - intel_subgroup_block_prefetch_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); - } + if (KK == 2 & NN == 4 & SGS_PER_WG_Y >= 4) { + const int sg_index_y = get_sub_group_id() / SGS_PER_WG_X; // index in [0, SGS_PER_WG_Y) + const int nn = sg_index_y % 4; // nn(sg_index_y) == 0, 1, 2, 3, 0, 1, 2, 3 + const int kk = 0; // kk(sg_index_y) == 0, 0, 0, 0, 0, 0, 0, 0 + intel_subgroup_block_prefetch_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); } else if (KK % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn++) { From 90b23b009bb09beb4872db508c7d34a08d2a1ee7 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Thu, 14 Mar 2024 10:57:42 -0700 Subject: [PATCH 70/99] re-enable all tiled matrix scenarios --- samples/99_matrixexperiments/main.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/99_matrixexperiments/main.cpp index a0025147..fae01b3f 100644 --- a/samples/99_matrixexperiments/main.cpp +++ 
b/samples/99_matrixexperiments/main.cpp @@ -881,12 +881,12 @@ int main(int argc, char** argv) } if (mask & 0x400) { - //bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); - //bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); - //bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - //bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - //bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - //bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); bfloat16_dpas_blockread_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); } From ab01142a5052d9d0b64b8df64a2642355e6eaa44 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Thu, 14 Mar 2024 13:51:52 -0700 Subject: [PATCH 71/99] try a cooperative prefetch for the A matrix tile --- samples/99_matrixexperiments/matrix_kernel_tiled.cl | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index a8d6b4dc..2f7fdc72 100644 --- 
a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -412,6 +412,9 @@ void HELPER_NAME(atile_block_load_rowmajor, MM, NN)(global ushort* A, int tM, in if (KK % 2 == 0 & MM % 4 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm+=4) { + //if (get_sub_group_local_id() == 0) { + // printf("atile block load : %d, %d, %2d: m = %3d, k = %3d, mm = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), m, k, mm, kk, k + kk * tK, m + mm * tM); + //} ushort8 tmp[2][4]; intel_subgroup_block_read_u16_m32k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); for (int tkk = 0; tkk < 2; tkk++) { @@ -523,7 +526,15 @@ void HELPER_NAME(btile_block_load_vnni, MM, NN)(global ushort* B, int tN, int K, void HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(global ushort* A, int tM, int M, int K, int m, int k) { - if (KK % 2 == 0 & MM % 4 == 0) { + if (KK == 2 & MM == 4 & SGS_PER_WG_X >= 4) { + const int sg_index_x = get_sub_group_id() % SGS_PER_WG_X; // index in [0, SGS_PER_WG_X) + const int kk = 0; + const int mm = sg_index_x % 2 * 2; + //if (get_sub_group_local_id() == 0) { + // printf("atile block prefetch: %d, %d, %2d: sg_x = %d, m = %3d, k = %3d, mm = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), sg_index_x, m, k, mm, kk, k + kk * tK, m + mm * tM); + //} + intel_subgroup_block_prefetch_u16_m16k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } else if (KK % 2 == 0 & MM % 4 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm+=4) { intel_subgroup_block_prefetch_u16_m32k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); From d6858fa9aed4ba1e1550f763ba9e6d6ea27258f6 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Thu, 14 Mar 2024 14:45:39 -0700 Subject: [PATCH 72/99] try a 
slightly smaller A tile prefetch --- samples/99_matrixexperiments/matrix_kernel_tiled.cl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 2f7fdc72..75940825 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -529,11 +529,11 @@ void HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(global ushort* A, int tM if (KK == 2 & MM == 4 & SGS_PER_WG_X >= 4) { const int sg_index_x = get_sub_group_id() % SGS_PER_WG_X; // index in [0, SGS_PER_WG_X) const int kk = 0; - const int mm = sg_index_x % 2 * 2; + const int mm = sg_index_x % 4; //if (get_sub_group_local_id() == 0) { // printf("atile block prefetch: %d, %d, %2d: sg_x = %d, m = %3d, k = %3d, mm = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), sg_index_x, m, k, mm, kk, k + kk * tK, m + mm * tM); //} - intel_subgroup_block_prefetch_u16_m16k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + intel_subgroup_block_prefetch_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); } else if (KK % 2 == 0 & MM % 4 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm+=4) { From 8cbc8db88911862466d9a06624cc86920b49b8e8 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Tue, 9 Apr 2024 16:58:25 -0700 Subject: [PATCH 73/99] sync tf32 samples with bf16 samples --- .../99_matrixexperimentstf32/CMakeLists.txt | 2 +- samples/99_matrixexperimentstf32/main.cpp | 27 ++- .../matrix_helpers_tf32.cl | 66 +++++- .../matrix_kernel_tiled_tf32.cl | 198 ++++++++++-------- .../matrix_kernels_tf32.cl | 25 ++- 5 files changed, 222 insertions(+), 96 deletions(-) diff --git a/samples/99_matrixexperimentstf32/CMakeLists.txt b/samples/99_matrixexperimentstf32/CMakeLists.txt index de636108..5987f780 100644 --- 
a/samples/99_matrixexperimentstf32/CMakeLists.txt +++ b/samples/99_matrixexperimentstf32/CMakeLists.txt @@ -6,6 +6,6 @@ add_opencl_sample( TEST NUMBER 99 TARGET matrixexperimentstf32 - VERSION 120 + VERSION 200 # for clSetKernelExecInfo SOURCES main.cpp KERNELS matrix_helpers_tf32.cl matrix_kernels_tf32.cl matrix_kernel_tiled_tf32.cl) diff --git a/samples/99_matrixexperimentstf32/main.cpp b/samples/99_matrixexperimentstf32/main.cpp index 7c664841..de21623b 100644 --- a/samples/99_matrixexperimentstf32/main.cpp +++ b/samples/99_matrixexperimentstf32/main.cpp @@ -21,12 +21,14 @@ using test_clock = std::chrono::high_resolution_clock; +bool zeroData = false; bool identityData = false; bool fixedData = false; bool validate = false; bool emulate = false; bool wallclock = false; bool skipinit = false; +bool roundRobin = false; int testIterations = 16; float threshold = 0.01f; @@ -75,6 +77,18 @@ static size_t findMinSubGroupSize(cl::Device& device) return 0; } +static void setRoundRobin(cl::Kernel& kernel) +{ + constexpr cl_kernel_exec_info CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_INTEL = 0x10025; + constexpr cl_uint CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_ROUND_ROBIN_INTEL = 0x10023; + const cl_uint policy = CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_ROUND_ROBIN_INTEL; + clSetKernelExecInfo( + kernel(), + CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_INTEL, + sizeof(policy), + &policy); +} + float to_tf32(float f) { union { @@ -96,7 +110,10 @@ float to_tf32(float f) template static void fill_matrix(std::vector& M, size_t numRows, size_t numCols) { - if (identityData) { + if (zeroData) { + std::generate(std::begin(M), std::end(M), [&]{ return 0.0f; }); + } + else if (identityData) { std::generate(std::begin(M), std::end(M), [&]{ return to_tf32(1.0f); }); } else if (fixedData) { for (size_t r = 0; r < numRows; r++) { @@ -334,6 +351,9 @@ static void tf32_dpas_blockread_rowmajor( kernel.setArg(1, A); kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); + if 
(roundRobin) { + setRoundRobin(kernel); + } if (!skipinit) { queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); @@ -390,6 +410,9 @@ static void tf32_dpas_blockread_rowmajor_tiled( kernel.setArg(1, A); kernel.setArg(2, B); kernel.setArg(3, static_cast(K)); + if (roundRobin) { + setRoundRobin(kernel); + } if (!skipinit) { queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); @@ -440,11 +463,13 @@ int main(int argc, char** argv) op.add>("m", "matrixsize", "Matrix Size", matrixSize, &matrixSize); op.add>("i", "iterations", "Test Iterations", testIterations, &testIterations); op.add("", "validate", "Validate Results", &validate); + op.add("", "zero", "Use Zero Data", &zeroData); op.add("", "identity", "Use Identity Data", &identityData); op.add("", "fixed", "Use Fixed Data", &fixedData); op.add("", "emulate", "Unconditionally Emulate dpas", &emulate); op.add("", "wallclock", "Measure Wallclock Time", &wallclock); op.add("", "skipinit", "Do Not Initialize Buffers", &skipinit); + op.add("", "roundrobin", "Use Round Robin Scheduling", &roundRobin); op.add>("", "threshold", "Local Error Threshold", threshold, &threshold); op.add, popl::Attribute::advanced>("", "mask", "Test Mask", mask, &mask); bool printUsage = false; diff --git a/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl b/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl index 6660340c..43174515 100644 --- a/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl +++ b/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl @@ -1,12 +1,70 @@ +__attribute__((overloadable)) +float activation(float f) +{ +#if defined(ACTIVATION_RELU) + return fmax(f, 0); +#else // identity + return f; +#endif +} + +__attribute__((overloadable)) +float2 activation(float2 f) +{ + float2 res; + res.s0 = activation(f.s0); + res.s1 = activation(f.s1); + return res; +} + +__attribute__((overloadable)) +float4 activation(float4 f) +{ + float4 res; + res.s0 = activation(f.s0); + res.s1 = 
activation(f.s1); + res.s2 = activation(f.s2); + res.s3 = activation(f.s3); + return res; +} + +float8 activation(float8 f) +{ + float8 res; + res.s0 = activation(f.s0); + res.s1 = activation(f.s1); + res.s2 = activation(f.s2); + res.s3 = activation(f.s3); + res.s4 = activation(f.s4); + res.s5 = activation(f.s5); + res.s6 = activation(f.s6); + res.s7 = activation(f.s7); + return res; +} + +#ifndef __has_builtin +#define __has_builtin(x) 0 +#endif +#if __has_builtin(__builtin_expect) == 0 +#define __builtin_expect(x) +#endif + #if defined(cl_intel_subgroups) -inline int compute_m(const int num_sgs, const int tM, const int MM) +inline int compute_m(const int num_sgs_x, const int num_sgs_y, const int tM, const int MM) { - const int m_start = get_group_id(1) * num_sgs; - const int m_index = num_sgs > 1 ? m_start + get_sub_group_id() : m_start; + const int m_start = get_group_id(1) * num_sgs_y; + const int m_index = num_sgs_y > 1 ? m_start + get_sub_group_id() / num_sgs_x : m_start; return m_index * tM * MM; } +inline int compute_n(const int num_sgs_x, const int num_sgs_y, const int tN, const int NN) +{ + const int n_start = get_group_id(0) * num_sgs_x; + const int n_index = num_sgs_x > 1 ? n_start + get_sub_group_id() % num_sgs_x : n_start; + return n_index * tN * NN; +} + // Emulated dpas: __attribute__((overloadable)) float emu_sub_group_tf32_tf32_matrix_mad_k8(float a, float8 b, float acc) @@ -87,8 +145,8 @@ float load_a_rowmajor_d32_m1_k8_sg16(global float* A, int rowStart, int colStart { float ret; + // Note: only the low eight channels should be used. uint offset = rowStart * stride + colStart; - offset += (get_sub_group_local_id() < 8) ? 
0 : stride; offset += (get_sub_group_local_id() % 8); ret = A[offset]; diff --git a/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl b/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl index ff9583cb..db0bb27c 100644 --- a/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl +++ b/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl @@ -28,37 +28,79 @@ #define MM_KERNEL_NAMEX(PREFIX, tM, tN, MM, NN) PREFIX ## _m ## tM ## _n ## tN ## _ ## MM ## x ## NN #define MM_KERNEL_NAME(PREFIX, tM, tN, MM, NN) MM_KERNEL_NAMEX(PREFIX, tM, tN, MM, NN) -#if !defined(SGS_PER_WG) -// Launch four subgroups per work-group, to maximize cache reuse. -#define SGS_PER_WG 4 +#define HELPER_NAMEX(PREFIX, MM, NN) PREFIX ## _m ## MM ## _n ## NN +#define HELPER_NAME(PREFIX, MM, NN) HELPER_NAMEX(PREFIX, MM, NN) + +#if !defined(SGS_PER_WG_X) +#define SGS_PER_WG_X 1 #endif -__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1))) -kernel void MM_KERNEL_NAME(tf32_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global float* A, global float* B, int K) +#if !defined(SGS_PER_WG_Y) +#define SGS_PER_WG_Y 4 +#endif + +#if !defined(PREFETCH_DISTANCE) +#define PREFETCH_DISTANCE 1 +#endif + +void HELPER_NAME(btile_load_rowmajor, MM, NN)(global float* B, int tN, int N, int k, int n, float8 bData[NN][KK]) { - const int tM = 8; - const int tN = 16; - const int N = get_global_size(0) * NN; - const int m = compute_m(SGS_PER_WG, tM, MM); - const int n = get_group_id(0) * tN * NN; + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[nn][kk] = load_b_rowmajor_d32_k8_nx(B, k + kk * tK, n + nn * tN, N); + } + } +} - // Initial prefetch: - const int init_k = 0; +void HELPER_NAME(atile_prefetch_rowmajor_sg16, MM, NN)(global float* A, int tM, int K, int m, int prefetch_k) +{ for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm+=2) { - prefetch_a_rowmajor_d32_m8v2_k8v2_sg16(A, m + mm * tM, 
init_k + kk * tK, K); + prefetch_a_rowmajor_d32_m8v2_k8v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K); } } +} + +void HELPER_NAME(btile_prefetch_rowmajor_sg16, MM, NN)(global float* B, int tN, int N, int prefetch_k, int n) +{ for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_rowmajor_d32_k8v2_n8v2_sg16(B, init_k + kk * tK, n + nn * tN, N); + prefetch_b_rowmajor_d32_k8v2_n8v2_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); + } + } +} + +void HELPER_NAME(atile_load_rowmajor_sg16, MM, NN)(global float* A, int tM, int K, int m, int k, float4 aData[KK][MM]) +{ + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = load_a_rowmajor_d32_m8_k8_sg16(A, m + mm * tM, k + kk * tK, K); } } +} - float8 sum[MM][NN]; +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) +kernel void MM_KERNEL_NAME(tf32_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global float* A, global float* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. 
+ const int tM = 8; + const int tN = 16; + const int N = get_global_size(0) * NN; + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); + + // Initial prefetch: + int prefetch_k = 0; + for (int p = 0; p < PREFETCH_DISTANCE; p++) { + HELPER_NAME(atile_prefetch_rowmajor_sg16, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor_sg16, MM, NN)(B, tN, N, prefetch_k, n); + prefetch_k += tK * KK; + } + + float8 sum[NN][MM]; for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = 0; + sum[nn][mm] = 0; } } @@ -66,36 +108,21 @@ kernel void MM_KERNEL_NAME(tf32_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float for (int k = 0; k < K; k += tK * KK) { // Next prefetch: - const int next_k = k + tK * KK; - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm+=2) { - prefetch_a_rowmajor_d32_m8v2_k8v2_sg16(A, m + mm * tM, next_k + kk * tK, K); - } - } - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_rowmajor_d32_k8v2_n8v2_sg16(B, next_k + kk * tK, n + nn * tN, N); - } - } + // TODO: skip prefetch on the last iterations. 
+ HELPER_NAME(atile_prefetch_rowmajor_sg16, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor_sg16, MM, NN)(B, tN, N, prefetch_k, n); + prefetch_k += tK * KK; float4 aData[KK][MM]; - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = load_a_rowmajor_d32_m8_k8_sg16(A, m + mm * tM, k + kk * tK, K); - } - } + HELPER_NAME(atile_load_rowmajor_sg16, MM, NN)(A, tM, K, m, k, aData); - float8 bData[KK][NN]; - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - bData[kk][nn] = load_b_rowmajor_d32_k8_nx(B, k + kk * tK, n + nn * tN, N); - } - } + float8 bData[NN][KK]; + HELPER_NAME(btile_load_rowmajor, MM, NN)(B, tN, N, k, n, bData); for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - for (int mm = 0; mm < MM; mm++) { - sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[kk][nn], sum[mm][nn]); + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[nn][mm]); } } } @@ -106,79 +133,77 @@ kernel void MM_KERNEL_NAME(tf32_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float split_barrier_wait(); - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - store_c_rowmajor_fp32_m8_nx(C, sum[mm][nn], m + mm * tM, n + nn * tN, N); + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { + sum[nn][mm] = activation(sum[nn][mm]); + store_c_rowmajor_fp32_m8_nx(C, sum[nn][mm], m + mm * tM, n + nn * tN, N); } } } #ifdef cl_intel_subgroup_extended_block_read -__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, SGS_PER_WG, 1))) +void HELPER_NAME(atile_block_load_rowmajor, MM, NN)(global float* A, int tM, int M, int K, int m, int k, float4 aData[KK][MM]) +{ + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = as_float4(intel_subgroup_block_read_u32_m8k8(A, K * sizeof(float), M, K * sizeof(float), (int2)(k + kk * tK, m + mm * 
tM))); + } + } +} + +void HELPER_NAME(btile_block_load_rowmajor, MM, NN)(global float* B, int tN, int K, int N, int k, int n, float8 bData[NN][KK]) +{ + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[nn][kk] = as_float8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(float), K, N * sizeof(float), (int2)(n + nn * tN, k + kk * tK))); + } + } +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) kernel void MM_KERNEL_NAME(tf32_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global float* A, global float* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. const int tM = 8; const int tN = 16; const int M = get_global_size(1) * tM * MM; const int N = get_global_size(0) * NN; - const int m = compute_m(SGS_PER_WG, tM, MM); - const int n = get_group_id(0) * tN * NN; - - // Initial prefetch: - const int init_k = 0; - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm+=2) { - prefetch_a_rowmajor_d32_m8v2_k8v2_sg16(A, m + mm * tM, init_k + kk * tK, K); - } - } - for (int kk = 0; kk < KK; kk+=2) { - for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_rowmajor_d32_k8v2_n8v2_sg16(B, init_k + kk * tK, n + nn * tN, N); - } + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); + + int prefetch_k = 0; + for (int p = 0; p < PREFETCH_DISTANCE; p++) { + HELPER_NAME(atile_prefetch_rowmajor_sg16, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor_sg16, MM, NN)(B, tN, N, prefetch_k, n); + prefetch_k += tK * KK; } - float8 sum[MM][NN]; + float8 sum[NN][MM]; for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - sum[mm][nn] = 0; + sum[nn][mm] = 0; } } split_barrier_arrive(); for (int k = 0; k < K; k += tK * KK) { - // Next prefetch: - const int next_k = k + tK * KK; - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; 
mm+=2) { - prefetch_a_rowmajor_d32_m8v2_k8v2_sg16(A, m + mm * tM, next_k + kk * tK, K); - } - } - for (int kk = 0; kk < KK; kk+=2) { - for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_rowmajor_d32_k8v2_n8v2_sg16(B, next_k + kk * tK, n + nn * tN, N); - } - } + // TODO: skip prefetch on the last iterations. + HELPER_NAME(atile_prefetch_rowmajor_sg16, MM, NN)(A, tM, K, m, prefetch_k); + HELPER_NAME(btile_prefetch_rowmajor_sg16, MM, NN)(B, tN, N, prefetch_k, n); + prefetch_k += tK * KK; float4 aData[KK][MM]; - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = as_float4(intel_subgroup_block_read_u32_m8k8(A, K * sizeof(float), M, K * sizeof(float), (int2)(k + kk * tK, m + mm * tM))); - } - } + HELPER_NAME(atile_block_load_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); - float8 bData[KK][NN]; - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - bData[kk][nn] = as_float8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(float), K, N * sizeof(float), (int2)(n + nn * tN, k + kk * tK))); - } - } + float8 bData[NN][KK]; + HELPER_NAME(btile_block_load_rowmajor, MM, NN)(B, tN, K, N, k, n, bData); for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { for (int mm = 0; mm < MM; mm++) { - sum[mm][nn] = mat_mul_sg16(aData[kk][mm], bData[kk][nn], sum[mm][nn]); + sum[nn][mm] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[nn][mm]); } } } @@ -191,7 +216,8 @@ kernel void MM_KERNEL_NAME(tf32_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(gl for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { - intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[mm][nn])); + sum[nn][mm] = activation(sum[nn][mm]); + intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[nn][mm])); } } } diff --git a/samples/99_matrixexperimentstf32/matrix_kernels_tf32.cl 
b/samples/99_matrixexperimentstf32/matrix_kernels_tf32.cl index 67f0b242..a0f73eb3 100644 --- a/samples/99_matrixexperimentstf32/matrix_kernels_tf32.cl +++ b/samples/99_matrixexperimentstf32/matrix_kernels_tf32.cl @@ -17,6 +17,7 @@ kernel void tf32_naive(global float* C, global float* A, global float* B, int K) sum = fma(A[m * K + k], B[k * N + n], sum); } + sum = activation(sum); C[m * N + n] = sum; } @@ -30,11 +31,12 @@ kernel void tf32_naive(global float* C, global float* A, global float* B, int K) __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void tf32_dpas_rowmajor_m1_n16(global float* C, global float* A, global float* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. const int tM = 1; const int tN = 16; const int N = get_global_size(0); const int m = get_group_id(1) * tM; - const int n = get_group_id(0) * get_local_size(0); + const int n = get_group_id(0) * tN; float sum = 0; for (int k = 0; k < K; k += tK) { @@ -43,17 +45,19 @@ kernel void tf32_dpas_rowmajor_m1_n16(global float* C, global float* A, global f sum = mat_mul_sg16(aData, bData, sum); } + sum = activation(sum); store_c_rowmajor_fp32_m1_nx(C, sum, m, n, N); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void tf32_dpas_rowmajor_m2_n16(global float* C, global float* A, global float* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. 
const int tM = 2; const int tN = 16; const int N = get_global_size(0); const int m = get_group_id(1) * tM; - const int n = get_group_id(0) * get_local_size(0); + const int n = get_group_id(0) * tN; float2 sum = 0; for (int k = 0; k < K; k += tK) { @@ -62,17 +66,19 @@ kernel void tf32_dpas_rowmajor_m2_n16(global float* C, global float* A, global f sum = mat_mul_sg16(aData, bData, sum); } + sum = activation(sum); store_c_rowmajor_fp32_m2_nx(C, sum, m, n, N); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void tf32_dpas_rowmajor_m4_n16(global float* C, global float* A, global float* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. const int tM = 4; const int tN = 16; const int N = get_global_size(0); const int m = get_group_id(1) * tM; - const int n = get_group_id(0) * get_local_size(0); + const int n = get_group_id(0) * tN; float4 sum = 0; for (int k = 0; k < K; k += tK) { @@ -81,17 +87,19 @@ kernel void tf32_dpas_rowmajor_m4_n16(global float* C, global float* A, global f sum = mat_mul_sg16(aData, bData, sum); } + sum = activation(sum); store_c_rowmajor_fp32_m4_nx(C, sum, m, n, N); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void tf32_dpas_rowmajor_m8_n16(global float* C, global float* A, global float* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. 
const int tM = 8; const int tN = 16; const int N = get_global_size(0); const int m = get_group_id(1) * tM; - const int n = get_group_id(0) * get_local_size(0); + const int n = get_group_id(0) * tN; float8 sum = 0; for (int k = 0; k < K; k += tK) { @@ -100,6 +108,7 @@ kernel void tf32_dpas_rowmajor_m8_n16(global float* C, global float* A, global f sum = mat_mul_sg16(aData, bData, sum); } + sum = activation(sum); store_c_rowmajor_fp32_m8_nx(C, sum, m, n, N); } @@ -108,6 +117,7 @@ kernel void tf32_dpas_rowmajor_m8_n16(global float* C, global float* A, global f __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void tf32_dpas_blockread_rowmajor_m1_n16(global float* C, global float* A, global float* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. const int tM = 1; const int tN = 16; const int M = get_global_size(1); @@ -122,12 +132,14 @@ kernel void tf32_dpas_blockread_rowmajor_m1_n16(global float* C, global float* A sum = mat_mul_sg16(aData, bData, sum); } + sum = activation(sum); intel_subgroup_block_write_u32_m1k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum)); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void tf32_dpas_blockread_rowmajor_m2_n16(global float* C, global float* A, global float* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. 
const int tM = 2; const int tN = 16; const int M = get_global_size(1) * tM; @@ -142,12 +154,14 @@ kernel void tf32_dpas_blockread_rowmajor_m2_n16(global float* C, global float* A sum = mat_mul_sg16(aData, bData, sum); } + sum = activation(sum); intel_subgroup_block_write_u32_m2k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum)); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void tf32_dpas_blockread_rowmajor_m4_n16(global float* C, global float* A, global float* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. const int tM = 4; const int tN = 16; const int M = get_global_size(1) * tM; @@ -162,12 +176,14 @@ kernel void tf32_dpas_blockread_rowmajor_m4_n16(global float* C, global float* A sum = mat_mul_sg16(aData, bData, sum); } + sum = activation(sum); intel_subgroup_block_write_u32_m4k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum)); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void tf32_dpas_blockread_rowmajor_m8_n16(global float* C, global float* A, global float* B, int K) { + __builtin_assume(K > 0); // Always at least one K iteration. 
const int tM = 8; const int tN = 16; const int M = get_global_size(1) * tM; @@ -182,6 +198,7 @@ kernel void tf32_dpas_blockread_rowmajor_m8_n16(global float* C, global float* A sum = mat_mul_sg16(aData, bData, sum); } + sum = activation(sum); intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); } From bb1195313735b9a5a099c2cca8d33cacff0c8ef2 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Wed, 10 Apr 2024 11:46:59 -0700 Subject: [PATCH 74/99] initial int8 matrixperf sample --- samples/99_matrixexperimentsi8/CMakeLists.txt | 11 + samples/99_matrixexperimentsi8/main.cpp | 908 +++++++++++++++++ .../matrix_helpers_i8.cl | 914 ++++++++++++++++++ .../matrix_kernel_tiled_i8.cl | 752 ++++++++++++++ .../matrix_kernels_i8.cl | 613 ++++++++++++ samples/CMakeLists.txt | 1 + 6 files changed, 3199 insertions(+) create mode 100644 samples/99_matrixexperimentsi8/CMakeLists.txt create mode 100644 samples/99_matrixexperimentsi8/main.cpp create mode 100644 samples/99_matrixexperimentsi8/matrix_helpers_i8.cl create mode 100644 samples/99_matrixexperimentsi8/matrix_kernel_tiled_i8.cl create mode 100644 samples/99_matrixexperimentsi8/matrix_kernels_i8.cl diff --git a/samples/99_matrixexperimentsi8/CMakeLists.txt b/samples/99_matrixexperimentsi8/CMakeLists.txt new file mode 100644 index 00000000..b97f9c74 --- /dev/null +++ b/samples/99_matrixexperimentsi8/CMakeLists.txt @@ -0,0 +1,11 @@ +# Copyright (c) 2019-2024 Ben Ashbaugh +# +# SPDX-License-Identifier: MIT + +add_opencl_sample( + TEST + NUMBER 99 + TARGET matrixexperimentsi8 + VERSION 200 # for clSetKernelExecInfo + SOURCES main.cpp + KERNELS matrix_helpers_i8.cl matrix_kernels_i8.cl matrix_kernel_tiled_i8.cl) diff --git a/samples/99_matrixexperimentsi8/main.cpp b/samples/99_matrixexperimentsi8/main.cpp new file mode 100644 index 00000000..ff99d3c4 --- /dev/null +++ b/samples/99_matrixexperimentsi8/main.cpp @@ -0,0 +1,908 @@ +/* +// Copyright (c) 2019-2024 Ben Ashbaugh +// 
+// SPDX-License-Identifier: MIT +*/ + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "util.hpp" + +using test_clock = std::chrono::high_resolution_clock; + +bool zeroData = false; +bool identityData = false; +bool fixedData = false; +bool validate = false; +bool emulate = false; +bool wallclock = false; +bool skipinit = false; +bool roundRobin = false; +int testIterations = 16; +float threshold = 0.01f; + +std::string makeTestName( + const std::string &func, + size_t M, size_t N, size_t K) +{ + std::ostringstream ret; + ret << func; + ret << " (M=" << M << ", N=" << N << ", K=" << K << ")"; + return ret.str(); +} + +std::string makeTestName( + const std::string &func, + int tM, int tN, + size_t M, size_t N, size_t K) +{ + std::ostringstream ret; + ret << func; + ret << ""; + ret << " (M=" << M << ", N=" << N << ", K=" << K << ")"; + return ret.str(); +} + +std::string makeTestName( + const std::string &func, + int tM, int tN, + int MM, int NN, + size_t M, size_t N, size_t K) +{ + std::ostringstream ret; + ret << func; + ret << ""; + ret << " (M=" << M << ", N=" << N << ", K=" << K << ")"; + return ret.str(); +} + +static size_t findMinSubGroupSize(cl::Device& device) +{ + auto s = device.getInfo(); + auto it = std::min_element(std::begin(s), std::end(s)); + if (it != std::end(s)) { + return *it; + } + return 0; +} + +static void setRoundRobin(cl::Kernel& kernel) +{ + constexpr cl_kernel_exec_info CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_INTEL = 0x10025; + constexpr cl_uint CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_ROUND_ROBIN_INTEL = 0x10023; + const cl_uint policy = CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_ROUND_ROBIN_INTEL; + clSetKernelExecInfo( + kernel(), + CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_INTEL, + sizeof(policy), + &policy); +} + +template +static void fill_matrix(std::vector& M, size_t numRows, size_t numCols) +{ + if (zeroData) { + std::generate(std::begin(M), 
std::end(M), [&]{ return 0; }); + } + else if (identityData) { + std::generate(std::begin(M), std::end(M), [&]{ return 1; }); + } else if (fixedData) { + for (size_t r = 0; r < numRows; r++) { + for (size_t c = 0; c < numCols; c++) { + M[r * numCols + c] = static_cast(r + c); + } + } + } else { + std::random_device dev; + std::mt19937 rng(dev()); + std::uniform_int_distribution dist(-64, 64); + std::generate(std::begin(M), std::end(M), [&]{ return dist(rng); }); + } +} + +template +static void vnni_matrix( + std::vector &dst, const std::vector &src, + size_t numRows, size_t numCols, size_t factor) +{ + for (size_t r = 0; r < numRows / factor; r++) { + for (size_t c = 0; c < numCols; c++) { + for (size_t k = 0; k < factor; k++) { + dst[r * numCols * factor + c * factor + k] = + src[(r * factor + k) * numCols + c]; + } + } + } +} + +template +static void compute_reference( + std::vector& C, + const std::vector& A, const std::vector& B, + size_t M, size_t N, size_t K) +{ + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + DstT sum = 0; + for (size_t k = 0; k < K; k++) { + sum = A[m * K + k] * B[k * N + n] + sum; + } + C[m * N + n] = sum; + } + } +} + +template +void check_results( + size_t M, + size_t N, + const std::vector& C, + const std::vector& C_ref) +{ + float err = 0.f; + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + auto index = m * N + n; + if (C[index] != C_ref[index]) { + std::cerr << "Error at m = " << m << ", n = " << n + << ": Wanted " + << C_ref[index] << ", got " << C[index] << std::endl; + return; + } + } + } +} + +static float hw_time(cl::Event& event) +{ + auto ns = event.getProfilingInfo() - + event.getProfilingInfo(); + return ns / 1e9f; +} + +static void i8_naive( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, M, N, 
K).c_str()); fflush(stdout); + + cl::Kernel kernel{program, "i8_naive"}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M}, cl::NullRange, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void i8_dpas_rowmajor( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "i8_dpas_rowmajor"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + 
cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M/tM}, cl::NullRange, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void i8_dpas_rowmajor_tiled( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "i8_dpas_rowmajor_tiled"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + kernelName += "_" + std::to_string(MM); + kernelName += "x" + std::to_string(NN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else if (tM * MM > M) { + printf("M is too small.\n"); + } else if (tN * NN > N) { + printf("N is too small.\n"); + } else { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, nullptr, &event); + queue.finish(); + auto end = 
test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... "); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void i8_dpas_vnni( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "i8_dpas_vnni"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M/tM}, cl::NullRange, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... 
"); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void i8_dpas_vnni_tiled( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "i8_dpas_vnni_tiled"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + kernelName += "_" + std::to_string(MM); + kernelName += "x" + std::to_string(NN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else if (tM * MM > M) { + printf("M is too small.\n"); + } else if (tN * NN > N) { + printf("N is too small.\n"); + } else { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... 
"); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void i8_dpas_blockread_rowmajor( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "i8_dpas_blockread_rowmajor"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + if (roundRobin) { + setRoundRobin(kernel); + } + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M/tM}, cl::NullRange, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... 
"); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void i8_dpas_blockread_rowmajor_tiled( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "i8_dpas_blockread_rowmajor_tiled"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + kernelName += "_" + std::to_string(MM); + kernelName += "x" + std::to_string(NN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else if (tM * MM > M) { + printf("M is too small.\n"); + } else if (tN * NN > N) { + printf("N is too small.\n"); + } else { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + if (roundRobin) { + setRoundRobin(kernel); + } + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... 
"); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void i8_dpas_blockread_vnni( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "i8_dpas_blockread_vnni"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + if (roundRobin) { + setRoundRobin(kernel); + } + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M/tM}, cl::NullRange, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... 
"); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void i8_dpas_blockread_vnni_tiled( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "i8_dpas_blockread_vnni_tiled"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + kernelName += "_" + std::to_string(MM); + kernelName += "x" + std::to_string(NN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else if (tM * MM > M) { + printf("M is too small.\n"); + } else if (tN * NN > N) { + printf("N is too small.\n"); + } else { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + if (roundRobin) { + setRoundRobin(kernel); + } + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... 
"); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +int main(int argc, char** argv) +{ + int platformIndex = 0; + int deviceIndex = 0; + + std::string fileName("matrix_kernels_i8.cl"); + std::string buildOptions; + size_t matrixSize = 512; + + size_t mask = ~0; + + { + popl::OptionParser op("Supported Options"); + op.add>("p", "platform", "Platform Index", platformIndex, &platformIndex); + op.add>("d", "device", "Device Index", deviceIndex, &deviceIndex); + op.add>("", "file", "Kernel File Name", fileName, &fileName); + op.add>("", "options", "Program Build Options", buildOptions, &buildOptions); + op.add>("m", "matrixsize", "Matrix Size", matrixSize, &matrixSize); + op.add>("i", "iterations", "Test Iterations", testIterations, &testIterations); + op.add("", "validate", "Validate Results", &validate); + op.add("", "zero", "Use Zero Data", &zeroData); + op.add("", "identity", "Use Identity Data", &identityData); + op.add("", "fixed", "Use Fixed Data", &fixedData); + op.add("", "emulate", "Unconditionally Emulate dpas", &emulate); + op.add("", "wallclock", "Measure Wallclock Time", &wallclock); + op.add("", "skipinit", "Do Not Initialize Buffers", &skipinit); + op.add("", "roundrobin", "Use Round Robin Scheduling", &roundRobin); + op.add>("", "threshold", "Local Error Threshold", threshold, &threshold); + op.add, popl::Attribute::advanced>("", "mask", "Test Mask", mask, &mask); + bool printUsage = false; + try { + op.parse(argc, argv); + } catch (std::exception& e) { + fprintf(stderr, "Error: %s\n\n", e.what()); + printUsage = true; + } + + if (printUsage || !op.unknown_options().empty() || !op.non_option_args().empty()) { + fprintf(stderr, + "Usage: matrixexperimentsi8 [options]\n" + "%s", op.help().c_str()); + return -1; + } + } + + std::vector platforms; + cl::Platform::get(&platforms); + if 
(platformIndex >= platforms.size()) { + printf("Requested platform index is %d, but only %zu platforms were found.\n", + platformIndex, platforms.size()); + return -1; + } + + printf("Running on platform: %s\n", + platforms[platformIndex].getInfo().c_str() ); + + std::vector devices; + platforms[platformIndex].getDevices(CL_DEVICE_TYPE_ALL, &devices); + if (deviceIndex >= devices.size()) { + printf("Requested device index is %d, but only %zu devices were found.\n", + deviceIndex, devices.size()); + } + + cl::Device& device = devices[deviceIndex]; + printf("Running on device: %s (%uCUs, %uMHz)\n", + device.getInfo().c_str(), + device.getInfo(), + device.getInfo()); + printf("Running on drivers: %s\n", + device.getInfo().c_str()); + + auto minSubGroupSize = findMinSubGroupSize(device); + + bool has_simd8 = minSubGroupSize == 8; + bool emulate_tN8 = true; + bool emulate_tN16 = true; + if (!emulate && checkDeviceForExtension(device, "cl_intel_subgroup_matrix_multiply_accumulate")) { + printf("Found support for cl_intel_subgroup_matrix_multiply_accumulate, min sub-group size is: %zu\n", minSubGroupSize); + switch(minSubGroupSize) { + case 8: emulate_tN8 = false; break; + case 16: emulate_tN16 = false; break; + default: break; + } + } + + buildOptions += " -DHAS_SIMD8=" + std::to_string(has_simd8); + buildOptions += " -DEMULATE_tN8=" + std::to_string(emulate_tN8); + buildOptions += " -DEMULATE_tN16=" + std::to_string(emulate_tN16); + + printf("Config:\n"); + printf("\tTest Iterations: %d\n", testIterations); + printf("\tValidating data?: %s\n", validate ? "true" : "false"); + printf("\tFixed data?: %s\n", fixedData ? "true" : "false"); + printf("\tWallclock time?: %s\n", wallclock ? "true" : "false"); + printf("\tEmulate dpas for tN=8?: %s\n", emulate_tN8 ? "true" : "false"); + printf("\tEmulate dpas for tN=16?: %s\n", emulate_tN16 ? 
"true" : "false"); + + cl::Context context{device}; + cl::CommandQueue queue{context, device, CL_QUEUE_PROFILING_ENABLE}; + + printf("Reading program source from file: %s\n", fileName.c_str() ); + std::string kernelString = readStringFromFile(fileName.c_str()); + + printf("Building program with build options: %s\n", + buildOptions.empty() ? "(none)" : buildOptions.c_str() ); + cl::Program program{ context, kernelString }; + program.build(buildOptions.c_str()); + for( auto& device : program.getInfo() ) + { + printf("Program build log for device %s:\n", + device.getInfo().c_str() ); + printf("%s\n", + program.getBuildInfo(device).c_str() ); + } + + const auto M = matrixSize; + const auto N = matrixSize; + const auto K = matrixSize; + + std::vector A_vec(M * K); + std::vector B_vec(K * N); + std::vector Bvnni_vec(K * N); + + std::vector C_ref(M * N); + + printf("Initializing source matrices...\n"); + fill_matrix(A_vec, M, K); + fill_matrix(B_vec, K, N); + + vnni_matrix(Bvnni_vec, B_vec, K, N, 4); + + if (validate) { + printf("Computing reference...\n"); + compute_reference(C_ref, A_vec, B_vec, M, N, K); + } + + printf("Creating source buffers...\n"); + cl::Buffer A{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, A_vec.size() * sizeof(A_vec[0]), A_vec.data()}; + cl::Buffer B{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, B_vec.size() * sizeof(B_vec[0]), B_vec.data()}; + cl::Buffer Bvnni{context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, Bvnni_vec.size() * sizeof(Bvnni_vec[0]), Bvnni_vec.data()}; + cl::Buffer C{context, CL_MEM_WRITE_ONLY, C_ref.size() * sizeof(C_ref[0])}; + + printf("Running tests...\n"); + + if (mask & 0x1) { + i8_naive(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x2) { + i8_dpas_rowmajor<1, 8>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor<2, 8>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor<4, 8>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor<8, 
8>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x4) { + i8_dpas_rowmajor_tiled<8, 8, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor_tiled<8, 8, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor_tiled<8, 8, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor_tiled<8, 8, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor_tiled<8, 8, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor_tiled<8, 8, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor_tiled<8, 8, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x8) { + i8_dpas_vnni<1, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni<2, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni<4, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni<8, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + } + + if (mask & 0x10) { + i8_dpas_vnni_tiled<8, 8, 1, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni_tiled<8, 8, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni_tiled<8, 8, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni_tiled<8, 8, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni_tiled<8, 8, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni_tiled<8, 8, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni_tiled<8, 8, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + } + + if (mask & 0x20) { + i8_dpas_rowmajor<1, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor<2, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor<4, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor<8, 16>(context, program, queue, C, A, B, M, N, K, 
C_ref); + } + + if (mask & 0x40) { + i8_dpas_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x80) { + i8_dpas_vnni<1, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni<2, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni<4, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni<8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + } + + if (mask & 0x100) { + i8_dpas_vnni_tiled<8, 16, 1, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni_tiled<8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni_tiled<8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni_tiled<8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni_tiled<8, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni_tiled<8, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_vnni_tiled<8, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + } + + if (mask & 0x200) { + i8_dpas_blockread_rowmajor<1, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_blockread_rowmajor<2, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_blockread_rowmajor<4, 16>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_blockread_rowmajor<8, 16>(context, program, queue, C, A, 
B, M, N, K, C_ref); + } + + if (mask & 0x400) { + i8_dpas_blockread_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_blockread_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_blockread_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_blockread_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_blockread_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_blockread_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + i8_dpas_blockread_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); + } + + if (mask & 0x800) { + i8_dpas_blockread_vnni<1, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_blockread_vnni<2, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_blockread_vnni<4, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_blockread_vnni<8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + } + + if (mask & 0x1000) { + i8_dpas_blockread_vnni_tiled<8, 16, 1, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_blockread_vnni_tiled<8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_blockread_vnni_tiled<8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_blockread_vnni_tiled<8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_blockread_vnni_tiled<8, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_blockread_vnni_tiled<8, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + i8_dpas_blockread_vnni_tiled<8, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); + } + + printf("Done.\n"); + + return 0; +} diff --git a/samples/99_matrixexperimentsi8/matrix_helpers_i8.cl 
b/samples/99_matrixexperimentsi8/matrix_helpers_i8.cl
new file mode 100644
index 00000000..1d591217
--- /dev/null
+++ b/samples/99_matrixexperimentsi8/matrix_helpers_i8.cl
@@ -0,0 +1,914 @@
+// Applies the activation selected at build time to one accumulator value:
+// ReLU (max(i, 0)) when ACTIVATION_RELU is defined, identity otherwise.
+__attribute__((overloadable))
+int activation(int i)
+{
+#if defined(ACTIVATION_RELU)
+    return max(i, 0);
+#else // identity
+    return i;
+#endif
+}
+
+// Component-wise activation for two accumulator values.
+__attribute__((overloadable))
+int2 activation(int2 i)
+{
+    int2 res;
+    res.s0 = activation(i.s0);
+    res.s1 = activation(i.s1);
+    return res;
+}
+
+// Component-wise activation for four accumulator values.
+__attribute__((overloadable))
+int4 activation(int4 i)
+{
+    int4 res;
+    res.s0 = activation(i.s0);
+    res.s1 = activation(i.s1);
+    res.s2 = activation(i.s2);
+    res.s3 = activation(i.s3);
+    return res;
+}
+
+// Component-wise activation for eight accumulator values.
+// Marked overloadable like the other three overloads of this name.
+__attribute__((overloadable))
+int8 activation(int8 i)
+{
+    int8 res;
+    res.s0 = activation(i.s0);
+    res.s1 = activation(i.s1);
+    res.s2 = activation(i.s2);
+    res.s3 = activation(i.s3);
+    res.s4 = activation(i.s4);
+    res.s5 = activation(i.s5);
+    res.s6 = activation(i.s6);
+    res.s7 = activation(i.s7);
+    return res;
+}
+
+#ifndef __has_builtin
+#define __has_builtin(x) 0
+#endif
+// NOTE(review): this fallback takes a single argument and expands to nothing,
+// whereas the compiler builtin takes two (expression, expected value) and
+// yields the expression - confirm how callers use it before relying on it.
+#if __has_builtin(__builtin_expect) == 0
+#define __builtin_expect(x)
+#endif
+
+#if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_char)
+
+typedef global char* global_aligned_char_ptr __attribute__((align_value(4)));
+
+// Returns the first matrix row handled by the calling sub-group.
+// The work-group index in dimension 1 selects a band of num_sgs_y sub-group
+// rows; when there is more than one, the sub-group id divided by num_sgs_x
+// selects the row inside the band. The result is scaled by tM * MM.
+inline int compute_m(const int num_sgs_x, const int num_sgs_y, const int tM, const int MM)
+{
+    const int m_start = get_group_id(1) * num_sgs_y;
+    const int m_index = num_sgs_y > 1 ? m_start + get_sub_group_id() / num_sgs_x : m_start;
+    return m_index * tM * MM;
+}
+
+// Returns the first matrix column handled by the calling sub-group.
+// Mirrors compute_m for dimension 0, using the sub-group id modulo num_sgs_x
+// and scaling by tN * NN.
+inline int compute_n(const int num_sgs_x, const int num_sgs_y, const int tN, const int NN)
+{
+    const int n_start = get_group_id(0) * num_sgs_x;
+    const int n_index = num_sgs_x > 1 ? n_start + get_sub_group_id() % num_sgs_x : n_start;
+    return n_index * tN * NN;
+}
+
+// Emulated SIMD8 dpas:
+// a holds four packed 8-bit values per work-item; the values of work-items
+// 0..7 are broadcast in turn and multiplied with the four 8-bit values packed
+// in the matching component of b. All 32 products are added to acc.
+__attribute__((overloadable))
+int emu_sub_group_i8_i8_matrix_mad_k32(int a, int8 b, int acc)
+{
+    int res = acc;
+
+    // TODO: this could use integer dot products instead?
+
+    res = as_char4(sub_group_broadcast(a, 0)).x * as_char4(b.s0).x + res;
+    res = as_char4(sub_group_broadcast(a, 0)).y * as_char4(b.s0).y + res;
+    res = as_char4(sub_group_broadcast(a, 0)).z * as_char4(b.s0).z + res;
+    res = as_char4(sub_group_broadcast(a, 0)).w * as_char4(b.s0).w + res;
+
+    res = as_char4(sub_group_broadcast(a, 1)).x * as_char4(b.s1).x + res;
+    res = as_char4(sub_group_broadcast(a, 1)).y * as_char4(b.s1).y + res;
+    res = as_char4(sub_group_broadcast(a, 1)).z * as_char4(b.s1).z + res;
+    res = as_char4(sub_group_broadcast(a, 1)).w * as_char4(b.s1).w + res;
+
+    res = as_char4(sub_group_broadcast(a, 2)).x * as_char4(b.s2).x + res;
+    res = as_char4(sub_group_broadcast(a, 2)).y * as_char4(b.s2).y + res;
+    res = as_char4(sub_group_broadcast(a, 2)).z * as_char4(b.s2).z + res;
+    res = as_char4(sub_group_broadcast(a, 2)).w * as_char4(b.s2).w + res;
+
+    res = as_char4(sub_group_broadcast(a, 3)).x * as_char4(b.s3).x + res;
+    res = as_char4(sub_group_broadcast(a, 3)).y * as_char4(b.s3).y + res;
+    res = as_char4(sub_group_broadcast(a, 3)).z * as_char4(b.s3).z + res;
+    res = as_char4(sub_group_broadcast(a, 3)).w * as_char4(b.s3).w + res;
+
+    res = as_char4(sub_group_broadcast(a, 4)).x * as_char4(b.s4).x + res;
+    res = as_char4(sub_group_broadcast(a, 4)).y * as_char4(b.s4).y + res;
+    res = as_char4(sub_group_broadcast(a, 4)).z * as_char4(b.s4).z + res;
+    res = as_char4(sub_group_broadcast(a, 4)).w * as_char4(b.s4).w + res;
+
+    res = as_char4(sub_group_broadcast(a, 5)).x * as_char4(b.s5).x + res;
+    res = as_char4(sub_group_broadcast(a, 5)).y * as_char4(b.s5).y + res;
+    res = as_char4(sub_group_broadcast(a, 5)).z * as_char4(b.s5).z + res;
+    res = as_char4(sub_group_broadcast(a, 5)).w * as_char4(b.s5).w + res;
+
+    res = as_char4(sub_group_broadcast(a, 6)).x * as_char4(b.s6).x + res;
+    res = as_char4(sub_group_broadcast(a, 6)).y * as_char4(b.s6).y + res;
+    res = as_char4(sub_group_broadcast(a, 6)).z * as_char4(b.s6).z + res;
+    res = as_char4(sub_group_broadcast(a, 6)).w * as_char4(b.s6).w + res;
+
+    res = as_char4(sub_group_broadcast(a, 7)).x * as_char4(b.s7).x + res;
+    res = as_char4(sub_group_broadcast(a, 7)).y * as_char4(b.s7).y + res;
+    res = as_char4(sub_group_broadcast(a, 7)).z * as_char4(b.s7).z + res;
+    res = as_char4(sub_group_broadcast(a, 7)).w * as_char4(b.s7).w + res;
+
+    return res;
+}
+
+// Emulated SIMD8 dpas for two rows: applies the scalar version per component.
+__attribute__((overloadable))
+int2 emu_sub_group_i8_i8_matrix_mad_k32(int2 a, int8 b, int2 acc)
+{
+    int2 res;
+
+    res.s0 = emu_sub_group_i8_i8_matrix_mad_k32(a.s0, b, acc.s0);
+    res.s1 = emu_sub_group_i8_i8_matrix_mad_k32(a.s1, b, acc.s1);
+
+    return res;
+}
+
+// Emulated SIMD8 dpas for four rows: applies the scalar version per component.
+__attribute__((overloadable))
+int4 emu_sub_group_i8_i8_matrix_mad_k32(int4 a, int8 b, int4 acc)
+{
+    int4 res;
+
+    res.s0 = emu_sub_group_i8_i8_matrix_mad_k32(a.s0, b, acc.s0);
+    res.s1 = emu_sub_group_i8_i8_matrix_mad_k32(a.s1, b, acc.s1);
+    res.s2 = emu_sub_group_i8_i8_matrix_mad_k32(a.s2, b, acc.s2);
+    res.s3 = emu_sub_group_i8_i8_matrix_mad_k32(a.s3, b, acc.s3);
+
+    return res;
+}
+
+// Emulated SIMD8 dpas for eight rows: applies the scalar version per component.
+__attribute__((overloadable))
+int8 emu_sub_group_i8_i8_matrix_mad_k32(int8 a, int8 b, int8 acc)
+{
+    int8 res;
+
+    res.s0 = emu_sub_group_i8_i8_matrix_mad_k32(a.s0, b, acc.s0);
+    res.s1 = emu_sub_group_i8_i8_matrix_mad_k32(a.s1, b, acc.s1);
+    res.s2 = emu_sub_group_i8_i8_matrix_mad_k32(a.s2, b, acc.s2);
+    res.s3 = emu_sub_group_i8_i8_matrix_mad_k32(a.s3, b, acc.s3);
+    res.s4 = emu_sub_group_i8_i8_matrix_mad_k32(a.s4, b, acc.s4);
+    res.s5 = emu_sub_group_i8_i8_matrix_mad_k32(a.s5, b, acc.s5);
+    res.s6 = emu_sub_group_i8_i8_matrix_mad_k32(a.s6, b, acc.s6);
+    res.s7 = emu_sub_group_i8_i8_matrix_mad_k32(a.s7, b, acc.s7);
+
+    return res;
+}
+
+// Emulated SIMD16 dpas:
+// a holds two packed 8-bit values per work-item; the values of work-items
+// 0..15 are broadcast in turn, so each component of b (four packed 8-bit
+// values) is paired with two consecutive work-items of a.
+__attribute__((overloadable))
+int emu_sub_group_i8_i8_matrix_mad_k32(short a, int8 b, int acc)
+{
+    // Accumulate in int, like the SIMD8 version: the products, the accumulator
+    // and the return value are all integers, and a float accumulator cannot
+    // represent every 32-bit sum exactly.
+    int res = acc;
+
+    res = as_char2(sub_group_broadcast(a, 0)).x * as_char4(b.s0).x + res;
+    res = as_char2(sub_group_broadcast(a, 0)).y * as_char4(b.s0).y + res;
+    res = as_char2(sub_group_broadcast(a, 1)).x * as_char4(b.s0).z + res;
+    res = as_char2(sub_group_broadcast(a, 1)).y * as_char4(b.s0).w + res;
+
+    res = as_char2(sub_group_broadcast(a, 2)).x * as_char4(b.s1).x + res;
+    res = as_char2(sub_group_broadcast(a, 2)).y * as_char4(b.s1).y + res;
+    res = as_char2(sub_group_broadcast(a, 3)).x * as_char4(b.s1).z + res;
+    res = as_char2(sub_group_broadcast(a, 3)).y * as_char4(b.s1).w + res;
+
+    res = as_char2(sub_group_broadcast(a, 4)).x * as_char4(b.s2).x + res;
+    res = as_char2(sub_group_broadcast(a, 4)).y * as_char4(b.s2).y + res;
+    res = as_char2(sub_group_broadcast(a, 5)).x * as_char4(b.s2).z + res;
+    res = as_char2(sub_group_broadcast(a, 5)).y * as_char4(b.s2).w + res;
+
+    res = as_char2(sub_group_broadcast(a, 6)).x * as_char4(b.s3).x + res;
+    res = as_char2(sub_group_broadcast(a, 6)).y * as_char4(b.s3).y + res;
+    res = as_char2(sub_group_broadcast(a, 7)).x * as_char4(b.s3).z + res;
+    res = as_char2(sub_group_broadcast(a, 7)).y * as_char4(b.s3).w + res;
+
+    res = as_char2(sub_group_broadcast(a, 8)).x * as_char4(b.s4).x + res;
+    res = as_char2(sub_group_broadcast(a, 8)).y * as_char4(b.s4).y + res;
+    res = as_char2(sub_group_broadcast(a, 9)).x * as_char4(b.s4).z + res;
+    res = as_char2(sub_group_broadcast(a, 9)).y * as_char4(b.s4).w + res;
+
+    res = as_char2(sub_group_broadcast(a, 10)).x * as_char4(b.s5).x + res;
+    res = as_char2(sub_group_broadcast(a, 10)).y * as_char4(b.s5).y + res;
+    res = as_char2(sub_group_broadcast(a, 11)).x * as_char4(b.s5).z + res;
+    res = as_char2(sub_group_broadcast(a, 11)).y * as_char4(b.s5).w + res;
+
+    res = as_char2(sub_group_broadcast(a, 12)).x * as_char4(b.s6).x + res;
+    res = as_char2(sub_group_broadcast(a, 12)).y * as_char4(b.s6).y + res;
+    res = as_char2(sub_group_broadcast(a, 13)).x * as_char4(b.s6).z + res;
+    res = as_char2(sub_group_broadcast(a, 13)).y * as_char4(b.s6).w + res;
+
+    res = as_char2(sub_group_broadcast(a, 14)).x * as_char4(b.s7).x + res;
+    res = as_char2(sub_group_broadcast(a, 14)).y * as_char4(b.s7).y + res;
+    res = as_char2(sub_group_broadcast(a, 15)).x * as_char4(b.s7).z + res;
+    res = as_char2(sub_group_broadcast(a, 15)).y * as_char4(b.s7).w + res;
+
+    return res;
+}
+
+// Emulated SIMD16 dpas for two rows: applies the scalar version per component.
+__attribute__((overloadable))
+int2 emu_sub_group_i8_i8_matrix_mad_k32(short2 a, int8 b, int2 acc)
+{
+    int2 res;
+
+    res.s0 = emu_sub_group_i8_i8_matrix_mad_k32(a.s0, b, acc.s0);
+    res.s1 = emu_sub_group_i8_i8_matrix_mad_k32(a.s1, b, acc.s1);
+
+    return res;
+}
+
+// Emulated SIMD16 dpas for four rows: applies the scalar version per component.
+__attribute__((overloadable))
+int4 emu_sub_group_i8_i8_matrix_mad_k32(short4 a, int8 b, int4 acc)
+{
+    int4 res;
+
+    res.s0 = emu_sub_group_i8_i8_matrix_mad_k32(a.s0, b, acc.s0);
+    res.s1 = emu_sub_group_i8_i8_matrix_mad_k32(a.s1, b, acc.s1);
+    res.s2 = emu_sub_group_i8_i8_matrix_mad_k32(a.s2, b, acc.s2);
+    res.s3 = emu_sub_group_i8_i8_matrix_mad_k32(a.s3, b, acc.s3);
+
+    return res;
+}
+
+// Emulated SIMD16 dpas for eight rows: applies the scalar version per component.
+__attribute__((overloadable))
+int8 emu_sub_group_i8_i8_matrix_mad_k32(short8 a, int8 b, int8 acc)
+{
+    int8 res;
+
+    res.s0 = emu_sub_group_i8_i8_matrix_mad_k32(a.s0, b, acc.s0);
+    res.s1 = emu_sub_group_i8_i8_matrix_mad_k32(a.s1, b, acc.s1);
+    res.s2 = emu_sub_group_i8_i8_matrix_mad_k32(a.s2, b, acc.s2);
+    res.s3 = emu_sub_group_i8_i8_matrix_mad_k32(a.s3, b, acc.s3);
+    res.s4 = emu_sub_group_i8_i8_matrix_mad_k32(a.s4, b, acc.s4);
+    res.s5 = emu_sub_group_i8_i8_matrix_mad_k32(a.s5, b, acc.s5);
+    res.s6 = emu_sub_group_i8_i8_matrix_mad_k32(a.s6, b, acc.s6);
+    res.s7 = emu_sub_group_i8_i8_matrix_mad_k32(a.s7, b, acc.s7);
+
+    return res;
+}
+
+// M rows x K columns
+// This is the SIMD8 version, where each work-item loads four values.
+int load_a_rowmajor_d8_m1_k32_sg8(global char* A, int rowStart, int colStart, int stride)
+{
+    // The 8-bit matrix is read as packed 32-bit words, so the element-based
+    // row and column offsets are divided by four to index words.
+    global uint* words = (global uint*)A;
+    const uint index = rowStart * stride / 4 + colStart / 4;
+
+    // The uint produced by the block read converts implicitly to the int result.
+    return intel_sub_group_block_read(words + index);
+}
+
+// M rows x K columns
+// SIMD8 version for two rows: each work-item receives one packed 32-bit word
+// (four 8-bit values) per row.
+int2 load_a_rowmajor_d8_m2_k32_sg8(global char* A, int rowStart, int colStart, int stride)
+{
+    global uint* words = (global uint*)A;
+    const uint rowStep = stride / 4;    // distance between rows, in 32-bit words
+    uint index = rowStart * stride / 4 + colStart / 4;
+
+    int2 tile;
+    tile.s0 = intel_sub_group_block_read(words + index); index += rowStep;
+    tile.s1 = intel_sub_group_block_read(words + index);
+
+    return tile;
+}
+
+// M rows x K columns
+// SIMD8 version for four rows: each work-item receives one packed 32-bit word
+// (four 8-bit values) per row.
+int4 load_a_rowmajor_d8_m4_k32_sg8(global char* A, int rowStart, int colStart, int stride)
+{
+    global uint* words = (global uint*)A;
+    const uint rowStep = stride / 4;    // distance between rows, in 32-bit words
+    uint index = rowStart * stride / 4 + colStart / 4;
+
+    int4 tile;
+    tile.s0 = intel_sub_group_block_read(words + index); index += rowStep;
+    tile.s1 = intel_sub_group_block_read(words + index); index += rowStep;
+    tile.s2 = intel_sub_group_block_read(words + index); index += rowStep;
+    tile.s3 = intel_sub_group_block_read(words + index);
+
+    return tile;
+}
+
+// M rows x K columns
+// This is the SIMD8 version, where each work-item loads four values.
+int8 load_a_rowmajor_d8_m8_k32_sg8(global char* A, int rowStart, int colStart, int stride)
+{
+    // Eight rows, one packed 32-bit word (four 8-bit values) per work-item per
+    // row; element offsets are divided by four to index words.
+    global uint* words = (global uint*)A;
+    const uint rowStep = stride / 4;    // distance between rows, in 32-bit words
+    uint index = rowStart * stride / 4 + colStart / 4;
+
+    int8 tile;
+    tile.s0 = intel_sub_group_block_read(words + index); index += rowStep;
+    tile.s1 = intel_sub_group_block_read(words + index); index += rowStep;
+    tile.s2 = intel_sub_group_block_read(words + index); index += rowStep;
+    tile.s3 = intel_sub_group_block_read(words + index); index += rowStep;
+    tile.s4 = intel_sub_group_block_read(words + index); index += rowStep;
+    tile.s5 = intel_sub_group_block_read(words + index); index += rowStep;
+    tile.s6 = intel_sub_group_block_read(words + index); index += rowStep;
+    tile.s7 = intel_sub_group_block_read(words + index);
+
+    return tile;
+}
+
+#if 0
+
+// M rows x K columns x V tiles (in the K dimension)
+// This is the SIMD8 version, where each work-item loads two values.
+// The first tile is returned in the first components of the return value, then the next tile, etc.
+int16 load_a_rowmajor_d16_m8_k16v2_sg8(global ushort* A, int rowStart, int colStart, int stride)
+{
+    // NOTE(review): this function and the prefetch helper below sit inside an
+    // `#if 0` region and are not compiled. They take 16-bit data - presumably
+    // left over from a 16-bit variant of this file; confirm before re-enabling.
+    uint16 ret;
+
+    // The ushort matrix is read as 32-bit words, so element offsets are halved.
+    global uint* A_ui = (global uint*)A;
+    uint offset_ui = rowStart * stride / 2 + colStart / 2;
+
+    // Each two-word read fills component n and component n+8 for row n.
+    ret.s08 = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2;
+    ret.s19 = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2;
+    ret.s2a = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2;
+    ret.s3b = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2;
+    ret.s4c = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2;
+    ret.s5d = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2;
+    ret.s6e = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2;
+    ret.s7f = intel_sub_group_block_read2(A_ui + offset_ui); offset_ui += stride / 2;
+
+    return as_int16(ret);
+}
+
+// M rows x K columns x V tiles (in the K dimension)
+// Does nothing unless PREFETCH_DEFAULT is defined; otherwise each work-item
+// prefetches two elements from the row selected by its sub-group local id.
+void prefetch_a_rowmajor_d16_m8_k16v2_sg8(global ushort* A, int rowStart, int colStart, int stride)
+{
+#if defined(PREFETCH_DEFAULT)
+    uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride;
+    __builtin_assume((ulong)(A + offset) % 4 == 0);
+    prefetch(A + offset, 2);
+#endif // defined(PREFETCH_DEFAULT)
+}
+
+#endif
+
+// M rows x K columns
+// This is the SIMD16 version, where each work-item loads two values.
+short load_a_rowmajor_d8_m1_k32_sg16(global char* A, int rowStart, int colStart, int stride)
+{
+    ushort ret;
+
+    // The 8-bit matrix is read as 16-bit words, so the element-based row and
+    // column offsets are divided by two to index words.
+    global ushort* A_us = (global ushort*)A;
+    uint offset_us = rowStart * stride / 2 + colStart / 2;
+
+    ret = intel_sub_group_block_read_us(A_us + offset_us);
+
+    // Reinterpret the bits as a signed short; no value conversion takes place.
+    return as_short(ret);
+}
+
+// M rows x K columns
+// This is the SIMD16 version, where each work-item loads two values.
+short2 load_a_rowmajor_d8_m2_k32_sg16(global char* A, int rowStart, int colStart, int stride)
+{
+    // The 8-bit matrix is read as packed 16-bit words (two 8-bit values each),
+    // so element offsets are divided by two to index words.
+    global ushort* halves = (global ushort*)A;
+    const uint rowStep = stride / 2;    // distance between rows, in 16-bit words
+    uint index = rowStart * stride / 2 + colStart / 2;
+
+    ushort2 tile;
+    tile.s0 = intel_sub_group_block_read_us(halves + index); index += rowStep;
+    tile.s1 = intel_sub_group_block_read_us(halves + index);
+
+    return as_short2(tile);
+}
+
+// M rows x K columns
+// SIMD16 version for four rows: each work-item receives one packed 16-bit word
+// (two 8-bit values) per row.
+short4 load_a_rowmajor_d8_m4_k32_sg16(global char* A, int rowStart, int colStart, int stride)
+{
+    global ushort* halves = (global ushort*)A;
+    const uint rowStep = stride / 2;    // distance between rows, in 16-bit words
+    uint index = rowStart * stride / 2 + colStart / 2;
+
+    ushort4 tile;
+    tile.s0 = intel_sub_group_block_read_us(halves + index); index += rowStep;
+    tile.s1 = intel_sub_group_block_read_us(halves + index); index += rowStep;
+    tile.s2 = intel_sub_group_block_read_us(halves + index); index += rowStep;
+    tile.s3 = intel_sub_group_block_read_us(halves + index);
+
+    return as_short4(tile);
+}
+
+// M rows x K columns
+// This is the SIMD16 version, where each work-item loads two values.
+short8 load_a_rowmajor_d8_m8_k32_sg16(global char* A, int rowStart, int colStart, int stride)
+{
+    // Eight rows, one packed 16-bit word (two 8-bit values) per work-item per
+    // row; element offsets are divided by two to index words.
+    global ushort* halves = (global ushort*)A;
+    const uint rowStep = stride / 2;    // distance between rows, in 16-bit words
+    uint index = rowStart * stride / 2 + colStart / 2;
+
+    ushort8 tile;
+    tile.s0 = intel_sub_group_block_read_us(halves + index); index += rowStep;
+    tile.s1 = intel_sub_group_block_read_us(halves + index); index += rowStep;
+    tile.s2 = intel_sub_group_block_read_us(halves + index); index += rowStep;
+    tile.s3 = intel_sub_group_block_read_us(halves + index); index += rowStep;
+    tile.s4 = intel_sub_group_block_read_us(halves + index); index += rowStep;
+    tile.s5 = intel_sub_group_block_read_us(halves + index); index += rowStep;
+    tile.s6 = intel_sub_group_block_read_us(halves + index); index += rowStep;
+    tile.s7 = intel_sub_group_block_read_us(halves + index);
+
+    return as_short8(tile);
+}
+
+#if 0
+
+// M rows x K columns x V tiles (in the K dimension)
+// This is the SIMD16 version, where each work-item loads one value.
+// The first tile is returned in the first components of the return value, then the next tile, etc.
+short16 load_a_rowmajor_d16_m8_k16v2_sg16(global ushort* A, int rowStart, int colStart, int stride)
+{
+    // NOTE(review): this function and the prefetch helper below sit inside an
+    // `#if 0` region and are not compiled. They take 16-bit data - presumably
+    // left over from a 16-bit variant of this file; confirm before re-enabling.
+    ushort16 ret;
+
+    uint offset = rowStart * stride + colStart;
+    // Each two-element read fills component n and component n+8 for row n.
+    ret.s08 = intel_sub_group_block_read_us2(A + offset); offset += stride;
+    ret.s19 = intel_sub_group_block_read_us2(A + offset); offset += stride;
+    ret.s2a = intel_sub_group_block_read_us2(A + offset); offset += stride;
+    ret.s3b = intel_sub_group_block_read_us2(A + offset); offset += stride;
+    ret.s4c = intel_sub_group_block_read_us2(A + offset); offset += stride;
+    ret.s5d = intel_sub_group_block_read_us2(A + offset); offset += stride;
+    ret.s6e = intel_sub_group_block_read_us2(A + offset); offset += stride;
+    ret.s7f = intel_sub_group_block_read_us2(A + offset); offset += stride;
+
+    return as_short16(ret);
+}
+
+// M rows x K columns x V tiles (in the M and K dimensions)
+// Does nothing unless PREFETCH_DEFAULT is defined; otherwise each work-item
+// prefetches two elements from the row selected by its sub-group local id.
+void prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(global ushort* A, int rowStart, int colStart, int stride)
+{
+#if defined(PREFETCH_DEFAULT)
+    uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride;
+    __builtin_assume((ulong)(A + offset) % 4 == 0);
+    prefetch(A + offset, 2);
+#endif // defined(PREFETCH_DEFAULT)
+}
+
+#endif
+
+// K rows x N columns:
+// Each work-item loads K values and converts to VNNI.
+// Stride is in units of elements.
+int8 load_b_rowmajor_d8_k32_nx(global char* B, int rowStart, int colStart, int stride)
+{
+    global uchar* bytes = (global uchar*)B;
+    uint index = rowStart * stride + colStart;
+
+    // 32 rows are read one byte per work-item at a time. Every four
+    // consecutive rows are packed into one 32-bit component of the result,
+    // the earliest row in the lowest vector component.
+    uchar4 quad;
+    int8 packed;
+
+    quad.s0 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    quad.s1 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    quad.s2 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    quad.s3 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    packed.s0 = as_int(quad);
+
+    quad.s0 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    quad.s1 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    quad.s2 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    quad.s3 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    packed.s1 = as_int(quad);
+
+    quad.s0 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    quad.s1 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    quad.s2 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    quad.s3 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    packed.s2 = as_int(quad);
+
+    quad.s0 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    quad.s1 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    quad.s2 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    quad.s3 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    packed.s3 = as_int(quad);
+
+    quad.s0 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    quad.s1 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    quad.s2 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    quad.s3 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    packed.s4 = as_int(quad);
+
+    quad.s0 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    quad.s1 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    quad.s2 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    quad.s3 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    packed.s5 = as_int(quad);
+
+    quad.s0 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    quad.s1 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    quad.s2 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    quad.s3 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    packed.s6 = as_int(quad);
+
+    quad.s0 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    quad.s1 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    quad.s2 = intel_sub_group_block_read_uc(bytes + index); index += stride;
+    quad.s3 = intel_sub_group_block_read_uc(bytes + index);
+    packed.s7 = as_int(quad);
+
+    return packed;
+}
+
+// K rows x N columns:
+// Each work-item loads K values that have already been converted to VNNI.
+// Stride is in units of elements.
+int8 load_b_vnni_d8_k32_nx(global char* B, int rowStart, int colStart, int stride)
+{
+    // The packed matrix is read as 32-bit words. The starting row is divided
+    // by four to get the word row, and stride is applied in units of words.
+    global uint* words = (global uint*)B;
+    uint index = rowStart / 4 * stride + colStart;
+
+    int8 tile;
+    tile.s0 = intel_sub_group_block_read(words + index); index += stride;
+    tile.s1 = intel_sub_group_block_read(words + index); index += stride;
+    tile.s2 = intel_sub_group_block_read(words + index); index += stride;
+    tile.s3 = intel_sub_group_block_read(words + index); index += stride;
+    tile.s4 = intel_sub_group_block_read(words + index); index += stride;
+    tile.s5 = intel_sub_group_block_read(words + index); index += stride;
+    tile.s6 = intel_sub_group_block_read(words + index); index += stride;
+    tile.s7 = intel_sub_group_block_read(words + index);
+
+    return tile;
+}
+
+#if 0
+
+// K rows x N columns x V tiles (in the N dimension)
+void prefetch_b_rowmajor_d16_k16_n8v4_sg8(global ushort* B, int rowStart, int colStart, int stride)
+{
+#if defined(PREFETCH_DEFAULT)
+    uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride;
+    __builtin_assume((ulong)(B + offset) % 4 == 0);
+    prefetch(B + offset, 2); offset += 8 * stride;
+    __builtin_assume((ulong)(B + offset) % 4 == 0);
+    prefetch(B + offset, 2); offset += 8 * stride;
+#endif // defined(PREFETCH_DEFAULT)
+}
+
+// K rows x N columns x V tiles (in the N dimension)
+void prefetch_b_rowmajor_d16_k16_n16v2_sg16(global ushort* B, int rowStart, int colStart, int stride)
+{
+#if defined(PREFETCH_DEFAULT)
+    uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride;
+    __builtin_assume((ulong)(B + offset) % 4 == 0);
+    prefetch(B + offset, 2);
+#endif // defined(PREFETCH_DEFAULT)
+}
+
+// K rows x N columns x V tiles (in the N dimension)
+void prefetch_b_vnni_d16_k16_n8v2_sg8(global ushort* B, int rowStart, int colStart, int stride)
+{
+#if defined(PREFETCH_DEFAULT)
+    global uint* B_ui = (global uint*)B;
+    uint offset_ui = colStart + (rowStart / 2 + get_sub_group_local_id()) * stride;
+    __builtin_assume((ulong)(B_ui + offset_ui) % 4 == 0);
+    prefetch(B_ui + offset_ui, 1);
+#endif // defined(PREFETCH_DEFAULT)
+}
+
+// K rows x N columns x V tiles (in the K dimension)
+void prefetch_b_vnni_d16_k16v2_n16_sg16(global ushort* B, int rowStart, int colStart, int stride)
+{
+#if defined(PREFETCH_DEFAULT)
+    global uint* B_ui = (global uint*)B;
+    uint offset_ui = colStart + (rowStart / 2 + get_sub_group_local_id()) * stride;
+    __builtin_assume((ulong)(B_ui + offset_ui) % 4 == 0);
+    prefetch(B_ui + offset_ui, 1);
+#endif // defined(PREFETCH_DEFAULT)
+}
+
+#endif
+
+// Writes one row of 32-bit results with a sub-group block write.
+// rowStart, colStart and stride are in units of 32-bit elements.
+void store_c_rowmajor_int32_m1_nx(global int* C, int v, int rowStart, int colStart, int stride)
+{
+    global uint* out = (global uint*)C;
+    const uint index = rowStart * stride + colStart;
+
+    intel_sub_group_block_write(out + index, as_uint(v));
+}
+
+// Writes two consecutive rows of 32-bit results, one block write per row.
+void store_c_rowmajor_int32_m2_nx(global int* C, int2 v, int rowStart, int colStart, int stride)
+{
+    global uint* out = (global uint*)C;
+    uint index = rowStart * stride + colStart;
+
+    intel_sub_group_block_write(out + index, as_uint(v.s0)); index += stride;
+    intel_sub_group_block_write(out + index, as_uint(v.s1));
+}
+
+// Writes four consecutive rows of 32-bit results, one block write per row.
+void store_c_rowmajor_int32_m4_nx(global int* C, int4 v, int rowStart, int colStart, int stride)
+{
+    global uint* out = (global uint*)C;
+    uint index = rowStart * stride + colStart;
+
+    intel_sub_group_block_write(out + index, as_uint(v.s0)); index += stride;
+    intel_sub_group_block_write(out + index, as_uint(v.s1)); index += stride;
+    intel_sub_group_block_write(out + index, as_uint(v.s2)); index += stride;
+    intel_sub_group_block_write(out + index, as_uint(v.s3));
+}
+
+// Writes eight consecutive rows of 32-bit results, one block write per row.
+void store_c_rowmajor_int32_m8_nx(global int* C, int8 v, int rowStart, int colStart, int stride)
+{
+    global uint* out = (global uint*)C;
+    uint index = rowStart * stride + colStart;
+
+    intel_sub_group_block_write(out + index, as_uint(v.s0)); index += stride;
+    intel_sub_group_block_write(out + index, as_uint(v.s1)); index += stride;
+    intel_sub_group_block_write(out + index, as_uint(v.s2)); index += stride;
+    intel_sub_group_block_write(out + index, as_uint(v.s3)); index += stride;
+    intel_sub_group_block_write(out + index, as_uint(v.s4)); index += stride;
+    intel_sub_group_block_write(out + index, as_uint(v.s5)); index += stride;
+    intel_sub_group_block_write(out + index, as_uint(v.s6)); index += stride;
+    intel_sub_group_block_write(out + index, as_uint(v.s7));
+}
+
+#endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_char)
+
+#if 0
+#ifdef cl_intel_subgroup_extended_block_read
+
+// Note for 2D block reads:
+// - the tile width and height are encoded into the function name.
+// - base_address is the byte address. Must be 64B aligned.
+// - width is the width of the entire matrix, in bytes. Must be >= 64B. Must be 4B aligned.
+// - height is the height of the entire matrix, or equivalently the number of rows.
+// - pitch is the number of bytes between rows of the entire matrix. Must be >= 64B. Must be a multiple of 8 bytes.
+// - coord is the number of elements (x coord) and row (y coord) to read from. X coord must be a multiple of 4 for 1B data and 2 for 2B data.
+
+// Built-in functions are:
+
+// #ifdef cl_intel_subgroup_extended_block_read
+// ushort2 intel_subgroup_block_read_u8_m1k32v2(__global void *base_address, int width, int height, int pitch, int2 coord);
+// ushort4 intel_subgroup_block_read_u8_m2k32v2(__global void *base_address, int width, int height, int pitch, int2 coord);
+// ushort8 intel_subgroup_block_read_u8_m4k32v2(__global void *base_address, int width, int height, int pitch, int2 coord);
+// ushort16 intel_subgroup_block_read_u8_m8k32v2(__global void *base_address, int width, int height, int pitch, int2 coord);
+// ushort2 intel_subgroup_block_read_u16_m1k16v2(__global void *base_address, int width, int height, int pitch, int2 coord);
+// ushort4 intel_subgroup_block_read_u16_m2k16v2(__global void *base_address, int width, int height, int pitch, int2 coord);
+// ushort8 intel_subgroup_block_read_u16_m4k16v2(__global void *base_address, int width, int height, int pitch, int2 coord);
+// ushort16 intel_subgroup_block_read_u16_m8k16v2(__global void *base_address, int width, int height, int pitch, int2 coord);
+// uint8 intel_subgroup_block_read_transform_u8_k32(__global void *base_address, int width, int height, int pitch, int2 coord);
+// uint8 intel_subgroup_block_read_transform_u16_k16(__global void *base_address, int width, int height, int pitch, int2 coord);
+// uint8 intel_subgroup_block_read_transpose_u32_k8(__global void *base_address, int width, int height, int pitch, int2 coord);
+// ulong4 intel_subgroup_block_read_transpose_u64_k4(__global void *base_address, int width, int height, int pitch, int2 coord);
+// #endif //defined(cl_intel_subgroup_extended_block_read)
+
+
+// For intrinsics, the pattern is:
+//  - prefix: __builtin_IB_subgroup_block_read_flat or __builtin_IB_subgroup_block_write_flat
+//  - operation (optional): _transpose or _transform
+//  - for no transpose or transform:
+//      - type / elements size: _u8 or _u16 or _u32 or _u64
+//      - number of tile rows: _m32 or _m16 or _m8 or _m4 or _m2 or _m1
+//      - tile width: _k64 or _k32 or _k16 or _k8
+//      - number of tiles: _v2 or _v1
+//  - for transpose:
+//      - type / element size: _u64 or _u32
+//      - number of tile rows: subgroup size (16)
+//      - tile width: _k4 (for _u64) or _k8 (for _u32)
+//      - number of tiles: 1
+//  - for transform:
+//      - type / element size: _u16 or _u8
+//      - number of tile rows: _k32 (for _u8) or _k16 (for _u16)
+//      - tile width: subgroup size (16)
+//      - number of tiles: 1
+
+// Cache-control values. They are passed as the cache_control argument of the
+// prefetch built-ins declared further down.
+enum LSC_LDCC {
+    LSC_LDCC_DEFAULT      = 0,
+    LSC_LDCC_L1UC_L3UC    = 1,   // Override to L1 uncached and L3 uncached
+    LSC_LDCC_L1UC_L3C     = 2,   // Override to L1 uncached and L3 cached
+    LSC_LDCC_L1C_L3UC     = 3,   // Override to L1 cached and L3 uncached
+    LSC_LDCC_L1C_L3C      = 4,   // Override to L1 cached and L3 cached
+    LSC_LDCC_L1S_L3UC     = 5,   // Override to L1 streaming load and L3 uncached
+    LSC_LDCC_L1S_L3C      = 6,   // Override to L1 streaming load and L3 cached
+    LSC_LDCC_L1IAR_L3C    = 7,   // Override to L1 invalidate-after-read, and L3 cached
+};
+
+// 32- and 64-element vector types, declared with ext_vector_type. The larger
+// block-read built-ins below use them as return types.
+typedef ushort __attribute__((ext_vector_type(32))) ushort32;
+typedef ushort __attribute__((ext_vector_type(64))) ushort64;
+
+typedef uint __attribute__((ext_vector_type(32))) uint32;
+
+// Define block reads, prefetches, and writes. These are supported by the hardware but are not in the headers:
+// Note that every built-in takes the base address as a long and takes
+// width, height and pitch already reduced by one (the *_minus_one parameters).
+
+ushort   __builtin_IB_subgroup_block_read_flat_u16_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
+ushort2  __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
+ushort4  __builtin_IB_subgroup_block_read_flat_u16_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
+ushort8  __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
+ushort16 __builtin_IB_subgroup_block_read_flat_u16_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
+ushort32 __builtin_IB_subgroup_block_read_flat_u16_m32k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
+
+ushort32 __builtin_IB_subgroup_block_read_flat_u16_m16k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
+ushort64 __builtin_IB_subgroup_block_read_flat_u16_m32k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
+
+uint8   __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
+uint16  __builtin_IB_subgroup_block_read_flat_u32_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
+
+uint16  __builtin_IB_subgroup_block_read_flat_transform_u16_k32(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
+
+uint16  __builtin_IB_subgroup_block_read_flat_transform_u16_k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
+uint32  __builtin_IB_subgroup_block_read_flat_transform_u16_k32v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord);
+
+
+void __builtin_IB_subgroup_block_read_prefetch_u16_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control);
+void __builtin_IB_subgroup_block_read_prefetch_u16_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control);
+void __builtin_IB_subgroup_block_read_prefetch_u16_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control);
+void __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control);
+void __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control);
+void __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control);
+void __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control);
+
+void __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control);
+void __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control);
+
+void __builtin_IB_subgroup_block_read_prefetch_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control);
+void __builtin_IB_subgroup_block_read_prefetch_u32_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control);
+
+
+void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint data);
+void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint2 data);
+void __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint4 data);
+void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint8 data);
+void __builtin_IB_subgroup_block_write_flat_u32_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint16 data);
+
+// Read wrappers. Each one converts base_address with as_long() and passes
+// width - 1, height - 1 and pitch - 1 to the matching built-in, so callers
+// supply the plain width, height and pitch.
+ushort  intel_subgroup_block_read_u16_m1k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
+{
+    return __builtin_IB_subgroup_block_read_flat_u16_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
+}
+ushort2 intel_subgroup_block_read_u16_m2k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
+{
+    return __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
+}
+ushort4 intel_subgroup_block_read_u16_m4k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
+{
+    return __builtin_IB_subgroup_block_read_flat_u16_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
+}
+ushort8 intel_subgroup_block_read_u16_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
+{
+    return __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
+}
+ushort16 intel_subgroup_block_read_u16_m16k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
+{
+    return __builtin_IB_subgroup_block_read_flat_u16_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
+}
+// The 32-element result is handed back through dst as four ushort8 quarters,
+// lowest elements first.
+void intel_subgroup_block_read_u16_m32k16(const __global void *base_address, int width, int height, int pitch, int2 coord, ushort8 dst[4])
+{
+    ushort32 tmp = __builtin_IB_subgroup_block_read_flat_u16_m32k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
+    dst[0] = tmp.lo.lo;
+    dst[1] = tmp.lo.hi;
+    dst[2] = tmp.hi.lo;
+    dst[3] = tmp.hi.hi;
+}
+
+// dst[0][*] receives the low half of the 32-element result and dst[1][*] the
+// high half, each split into two ushort8 pieces.
+void intel_subgroup_block_read_u16_m16k16v2(const __global void *base_address, int width, int height, int pitch, int2 coord, ushort8 dst[2][2])
+{
+    ushort32 tmp = __builtin_IB_subgroup_block_read_flat_u16_m16k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
+    dst[0][0] = tmp.lo.lo;
+    dst[0][1] = tmp.lo.hi;
+    dst[1][0] = tmp.hi.lo;
+    dst[1][1] = tmp.hi.hi;
+}
+// dst[0][*] receives the low half of the 64-element result and dst[1][*] the
+// high half, each split into four ushort8 pieces.
+void intel_subgroup_block_read_u16_m32k16v2(const __global void *base_address, int width, int height, int pitch, int2 coord, ushort8 dst[2][4])
+{
+    ushort64 tmp = __builtin_IB_subgroup_block_read_flat_u16_m32k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
+    dst[0][0] = tmp.lo.lo.lo;
+    dst[0][1] = tmp.lo.lo.hi;
+    dst[0][2] = tmp.lo.hi.lo;
+    dst[0][3] = tmp.lo.hi.hi;
+    dst[1][0] = tmp.hi.lo.lo;
+    dst[1][1] = tmp.hi.lo.hi;
+    dst[1][2] = tmp.hi.hi.lo;
+    dst[1][3] = tmp.hi.hi.hi;
+}
+
+uint8 intel_subgroup_block_read_u32_m8k16(const __global void* base_address, int width, int height, int pitch, int2 coord)
+{
+    return __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
+}
+uint16 intel_subgroup_block_read_u32_m16k16(const __global void* base_address, int width, int height, int pitch, int2 coord)
+{
+    return __builtin_IB_subgroup_block_read_flat_u32_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
+}
+
+// Each block is K rows x N columns, where the K rows have been VNNI transformed.
+// The transform wrappers below reinterpret the unsigned result as signed
+// integers with as_int8() / as_int16().
+int8 intel_subgroup_block_read_transform_u16_k16n16(__global void *base_address, int width, int height, int pitch, int2 coord)
+{
+    // Note: this function is in the headers, but is named confusingly and returns unsigned integers rather than signed integers:
+    return as_int8(intel_subgroup_block_read_transform_u16_k16(base_address, width, height, pitch, coord));
+}
+int16 intel_subgroup_block_read_transform_u16_k32n16(__global void *base_address, int width, int height, int pitch, int2 coord)
+{
+    return as_int16(__builtin_IB_subgroup_block_read_flat_transform_u16_k32(as_long(base_address), width - 1, height - 1, pitch - 1, coord));
+}
+int16 intel_subgroup_block_read_transform_u16_k16n16v2(__global void *base_address, int width, int height, int pitch, int2 coord)
+{
+    return as_int16(__builtin_IB_subgroup_block_read_flat_transform_u16_k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord));
+}
+// The 32-element result is handed back through dst as four int8 quarters:
+// dst[0][*] holds the low half and dst[1][*] the high half.
+void intel_subgroup_block_read_transform_u16_k32n16v2(__global void *base_address, int width, int height, int pitch, int2 coord, int8 dst[2][2])
+{
+    uint32 tmp = __builtin_IB_subgroup_block_read_flat_transform_u16_k32v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord);
+    dst[0][0] = as_int8(tmp.lo.lo);
+    dst[0][1] = as_int8(tmp.lo.hi);
+    dst[1][0] = as_int8(tmp.hi.lo);
+    dst[1][1] = as_int8(tmp.hi.hi);
+}
+
+
+// Cache-control value that every prefetch wrapper below passes to its built-in.
+#define BLOCK_PREFETCH_CACHE_TYPE LSC_LDCC_L1C_L3C
+
+// Prefetch wrappers. When PREFETCH_DEFAULT is defined they forward to the
+// matching prefetch built-in, using the same as_long() / minus-one argument
+// conversion as the read wrappers. Without it their bodies are empty.
+void intel_subgroup_block_prefetch_u16_m1k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
+{
+#if defined(PREFETCH_DEFAULT)
+    __builtin_IB_subgroup_block_read_prefetch_u16_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE);
+#endif // defined(PREFETCH_DEFAULT)
+}
+void intel_subgroup_block_prefetch_u16_m2k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
+{
+#if defined(PREFETCH_DEFAULT)
+    __builtin_IB_subgroup_block_read_prefetch_u16_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE);
+#endif // defined(PREFETCH_DEFAULT)
+}
+void intel_subgroup_block_prefetch_u16_m4k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
+{
+#if defined(PREFETCH_DEFAULT)
+    __builtin_IB_subgroup_block_read_prefetch_u16_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE);
+#endif // defined(PREFETCH_DEFAULT)
+}
+void intel_subgroup_block_prefetch_u16_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
+{
+#if defined(PREFETCH_DEFAULT)
+    __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE);
+#endif // defined(PREFETCH_DEFAULT)
+}
+// NOTE(review): unlike its siblings, this wrapper takes a non-const base_address.
+void intel_subgroup_block_prefetch_u16_m8k16v2(__global void *base_address, int width, int height, int pitch, int2 coord)
+{
+#if defined(PREFETCH_DEFAULT)
+    __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE);
+#endif // defined(PREFETCH_DEFAULT)
+}
+void intel_subgroup_block_prefetch_u16_m16k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
+{
+#if defined(PREFETCH_DEFAULT)
+    __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE);
+#endif // defined(PREFETCH_DEFAULT)
+}
+void intel_subgroup_block_prefetch_u16_m32k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
+{
+#if defined(PREFETCH_DEFAULT)
+    __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE);
+#endif // defined(PREFETCH_DEFAULT)
+}
+void intel_subgroup_block_prefetch_u16_m16k16v2(const __global void *base_address, int width, int height, int pitch, int2 coord)
+{
+#if defined(PREFETCH_DEFAULT)
+    __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE);
+#endif // defined(PREFETCH_DEFAULT)
+}
+void intel_subgroup_block_prefetch_u16_m32k16v2(const __global void *base_address, int width, int height, int pitch, int2 coord)
+{
+#if defined(PREFETCH_DEFAULT)
+    __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE);
+#endif // defined(PREFETCH_DEFAULT)
+}
+void intel_subgroup_block_prefetch_u32_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
+{
+#if defined(PREFETCH_DEFAULT)
+    __builtin_IB_subgroup_block_read_prefetch_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE);
+#endif // defined(PREFETCH_DEFAULT)
+}
+void intel_subgroup_block_prefetch_u32_m16k16(const __global void *base_address, int width, int height, int pitch, int2 coord)
+{
+#if defined(PREFETCH_DEFAULT)
+    __builtin_IB_subgroup_block_read_prefetch_u32_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE);
+#endif // defined(PREFETCH_DEFAULT)
+}
+
+
+// Write wrappers. They apply the same as_long() / minus-one argument
+// conversion and pass data through unchanged.
+void intel_subgroup_block_write_u32_m1k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint data)
+{
+    __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data);
+}
+void intel_subgroup_block_write_u32_m2k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint2 data)
+{
+    __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data);
+}
+void intel_subgroup_block_write_u32_m4k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint4 data)
+{
+    __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data);
+}
+void intel_subgroup_block_write_u32_m8k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint8 data)
+{
+    __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data);
+}
+void intel_subgroup_block_write_u32_m16k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint16 data)
+{
+    __builtin_IB_subgroup_block_write_flat_u32_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data);
+}
+
+#endif // cl_intel_subgroup_extended_block_read
+#endif
diff --git a/samples/99_matrixexperimentsi8/matrix_kernel_tiled_i8.cl b/samples/99_matrixexperimentsi8/matrix_kernel_tiled_i8.cl
new file mode 100644
index 00000000..a862744f
--- /dev/null
+++ b/samples/99_matrixexperimentsi8/matrix_kernel_tiled_i8.cl
@@ -0,0 +1,752 @@
+// This unconditional #error stops compilation as soon as the file is
+// preprocessed.
+#error "Needs to be updated!"
+
+// tK, MM and NN have no defaults and must be supplied by the build.
+#if !defined(tK)
+#error "tK is undefined! This should be defined as the K dimension of the matrix tiles, which is dependent on the elemement type, likely 16 or 32."
+#endif
+
+#if !defined(MM)
+#error "MM is undefined! This should be defined as the number of matrix tiles in the M dimension."
+#endif
+
+#if !defined(NN)
+#error "NN is undefined! This should be defined as the number of matrix tiles in the N dimension."
+#endif
+
+// KK defaults to 1 when the build does not define it.
+#if !defined(KK)
+#define KK 1
+#endif
+
+// This branch is taken when the split work-group barrier extension is missing
+// or NO_SPLIT_BARRIERS is defined. The warning fires only when the extension
+// is missing.
+#if !defined(cl_intel_split_work_group_barrier) || defined(NO_SPLIT_BARRIERS)
+#if !defined(cl_intel_split_work_group_barrier)
+#warning "Unexpected: cl_intel_split_work_group_barrier is not supported?"
+#endif
+// Fallback: the split-barrier helpers expand to nothing.
+#define split_barrier_arrive()
+#define split_barrier_wait()
+#else
+// Split barriers available: map the helpers onto the Intel arrive/wait pair.
+#define split_barrier_arrive() intel_work_group_barrier_arrive(0)
+#define split_barrier_wait() intel_work_group_barrier_wait(0)
+#endif
+
+// MM_KERNEL_NAME(PREFIX, tM, tN, MM, NN) builds PREFIX_m<tM>_n<tN>_<MM>x<NN>.
+// The extra ...X level lets macro arguments such as MM and NN expand before
+// they are pasted.
+#define MM_KERNEL_NAMEX(PREFIX, tM, tN, MM, NN) PREFIX ## _m ## tM ## _n ## tN ## _ ## MM ## x ## NN
+#define MM_KERNEL_NAME(PREFIX, tM, tN, MM, NN)  MM_KERNEL_NAMEX(PREFIX, tM, tN, MM, NN)
+
+// HELPER_NAME(PREFIX, MM, NN) builds PREFIX_m<MM>_n<NN> the same way.
+#define HELPER_NAMEX(PREFIX, MM, NN) PREFIX ## _m ## MM ## _n ## NN
+#define HELPER_NAME(PREFIX, MM, NN)  HELPER_NAMEX(PREFIX, MM, NN)
+
+// Defaults for the work-group shape (in sub-groups) and for the number of
+// K steps prefetched ahead of the main loop.
+#if !defined(SGS_PER_WG_X)
+#define SGS_PER_WG_X 1
+#endif
+
+#if !defined(SGS_PER_WG_Y)
+#define SGS_PER_WG_Y 4
+#endif
+
+#if !defined(PREFETCH_DISTANCE)
+#define PREFETCH_DISTANCE 1
+#endif
+
+// Fills bData[nn][kk] for all NN x KK tiles with load_b_rowmajor_d16_k16_nx,
+// starting at row k / column n and stepping tK rows and tN columns per tile.
+void HELPER_NAME(btile_load_rowmajor, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[NN][KK])
+{
+    for (int kk = 0; kk < KK; kk++) {
+        for (int nn = 0; nn < NN; nn++) {
+            bData[nn][kk] = load_b_rowmajor_d16_k16_nx(B, k + kk * tK, n + nn * tN, N);
+        }
+    }
+}
+
+// Same tiling as btile_load_rowmajor, but loads with load_b_vnni_d16_k16_nx.
+void HELPER_NAME(btile_load_vnni, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[NN][KK])
+{
+    for (int kk = 0; kk < KK; kk++) {
+        for (int nn = 0; nn < NN; nn++) {
+            bData[nn][kk] = load_b_vnni_d16_k16_nx(B, k + kk * tK, n + nn * tN, N);
+        }
+    }
+}
+
+#if HAS_SIMD8
+
+// Prefetches the A tiles for one K step. Each call covers two K tiles, so kk
+// advances by 2.
+void HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(global ushort* A, int tM, int K, int m, int prefetch_k)
+{
+    for (int kk = 0; kk < KK; kk+=2) {
+        for (int mm = 0; mm < MM; mm++) {
+            prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K);
+        }
+    }
+}
+
+// Prefetches row-major B tiles for one K step. Each call covers four N tiles,
+// so nn advances by 4.
+void HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(global ushort* B, int tN, int N, int prefetch_k, int n)
+{
+    for (int kk = 0; kk < KK; kk++) {
+        for (int nn = 0; nn < NN; nn+=4) {
+            prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, prefetch_k + kk * tK, n + nn * tN, N);
+        }
+    }
+}
+
+// Prefetches VNNI B tiles for one K step. Each call covers two N tiles, so nn
+// advances by 2.
+void HELPER_NAME(btile_prefetch_vnni_sg8, MM, NN)(global ushort* B, int tN, int N, int prefetch_k, int n)
+{
+    for (int kk = 0; kk < KK; kk++) {
+        for (int nn = 0; nn < NN; nn+=2) {
+            prefetch_b_vnni_d16_k16_n8v2_sg8(B, prefetch_k + kk * tK, n + nn * tN, N);
+        }
+    }
+}
+
+// Fills aData[kk][mm] for all KK x MM tiles. When KK is even, two K tiles are
+// loaded per call and split into the .lo and .hi halves of the result;
+// otherwise tiles are loaded one at a time.
+void HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(global ushort* A, int tM, int K, int m, int k, int8 aData[KK][MM])
+{
+    if (KK % 2 == 0) {
+        for (int kk = 0; kk < KK; kk+=2) {
+            for (int mm = 0; mm < MM; mm++) {
+                int16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, k + kk * tK, K);
+                aData[kk + 0][mm] = aTemp.lo;
+                aData[kk + 1][mm] = aTemp.hi;
+            }
+        }
+    } else {
+        for (int kk = 0; kk < KK; kk++) {
+            for (int mm = 0; mm < MM; mm++) {
+                aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K);
+            }
+        }
+    }
+}
+
+// Tiled kernel for sub-group size 8 with row-major B. Each work-item
+// accumulates MM x NN float8 tiles over K in steps of tK * KK using
+// mat_mul_sg8, applies activation(), and stores the tiles to C. N is derived
+// from the global size in dimension 0. Prefetches run PREFETCH_DISTANCE steps
+// ahead of the loads.
+__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8 * SGS_PER_WG_X, SGS_PER_WG_Y, 1)))
+kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K)
+{
+    __builtin_assume(K > 0);    // Always at least one K iteration.
+    const int tM = 8;
+    const int tN = 8;
+    const int N = get_global_size(0) * NN;
+    const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM);
+    const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN);
+
+    // Initial prefetch:
+    int prefetch_k = 0;
+    for (int p = 0; p < PREFETCH_DISTANCE; p++) {
+        HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(A, tM, K, m, prefetch_k);
+        HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(B, tN, N, prefetch_k, n);
+        prefetch_k += tK * KK;
+    }
+
+    float8 sum[NN][MM];
+    for (int mm = 0; mm < MM; mm++) {
+        for (int nn = 0; nn < NN; nn++) {
+            sum[nn][mm] = 0;
+        }
+    }
+
+    split_barrier_arrive();
+
+    for (int k = 0; k < K; k += tK * KK) {
+        // Next prefetch:
+        // TODO: skip prefetch on the last iterations.
+        HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(A, tM, K, m, prefetch_k);
+        HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(B, tN, N, prefetch_k, n);
+        prefetch_k += tK * KK;
+
+        int8 aData[KK][MM];
+        HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(A, tM, K, m, k, aData);
+
+        int8 bData[NN][KK];
+        HELPER_NAME(btile_load_rowmajor, MM, NN)(B, tN, N, k, n, bData);
+
+        for (int kk = 0; kk < KK; kk++) {
+            for (int mm = 0; mm < MM; mm++) {
+                for (int nn = 0; nn < NN; nn++) {
+                    sum[nn][mm] = mat_mul_sg8(aData[kk][mm], bData[nn][kk], sum[nn][mm]);
+                }
+            }
+        }
+
+        // Complete this iteration's barrier, then arrive for the next one.
+        split_barrier_wait();
+        split_barrier_arrive();
+    }
+
+    // Matches the final arrive issued by the last loop iteration.
+    split_barrier_wait();
+
+    for (int nn = 0; nn < NN; nn++) {
+        for (int mm = 0; mm < MM; mm++) {
+            sum[nn][mm] = activation(sum[nn][mm]);
+            store_c_rowmajor_fp32_m8_nx(C, sum[nn][mm], m + mm * tM, n + nn * tN, N);
+        }
+    }
+}
+
+// Tiled kernel for sub-group size 8 with VNNI-layout B. Same structure as the
+// row-major kernel above, but B is prefetched and loaded with the VNNI helpers.
+__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8 * SGS_PER_WG_X, SGS_PER_WG_Y, 1)))
+kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K)
+{
+    __builtin_assume(K > 0);    // Always at least one K iteration.
+    const int tM = 8;
+    const int tN = 8;
+    const int N = get_global_size(0) * NN;
+    const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM);
+    const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN);
+
+    // Initial prefetch:
+    int prefetch_k = 0;
+    for (int p = 0; p < PREFETCH_DISTANCE; p++) {
+        HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(A, tM, K, m, prefetch_k);
+        HELPER_NAME(btile_prefetch_vnni_sg8, MM, NN)(B, tN, N, prefetch_k, n);
+        prefetch_k += tK * KK;
+    }
+
+    float8 sum[NN][MM];
+    for (int mm = 0; mm < MM; mm++) {
+        for (int nn = 0; nn < NN; nn++) {
+            sum[nn][mm] = 0;
+        }
+    }
+
+    split_barrier_arrive();
+
+    for (int k = 0; k < K; k += tK * KK) {
+        // Next prefetch:
+        // TODO: skip prefetch on the last iterations.
+        // B is in VNNI layout here, so prefetch it with the VNNI helper. This
+        // matches the initial prefetch above and the VNNI load below.
+        HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(A, tM, K, m, prefetch_k);
+        HELPER_NAME(btile_prefetch_vnni_sg8, MM, NN)(B, tN, N, prefetch_k, n);
+        prefetch_k += tK * KK;
+
+        int8 aData[KK][MM];
+        HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(A, tM, K, m, k, aData);
+
+        int8 bData[NN][KK];
+        HELPER_NAME(btile_load_vnni, MM, NN)(B, tN, N, k, n, bData);
+
+        for (int kk = 0; kk < KK; kk++) {
+            for (int nn = 0; nn < NN; nn++) {
+                for (int mm = 0; mm < MM; mm++) {
+                    sum[nn][mm] = mat_mul_sg8(aData[kk][mm], bData[nn][kk], sum[nn][mm]);
+                }
+            }
+        }
+
+        // Complete this iteration's barrier, then arrive for the next one.
+        split_barrier_wait();
+        split_barrier_arrive();
+    }
+
+    // Matches the final arrive issued by the last loop iteration.
+    split_barrier_wait();
+
+    for (int mm = 0; mm < MM; mm++) {
+        for (int nn = 0; nn < NN; nn++) {
+            sum[nn][mm] = activation(sum[nn][mm]);
+            store_c_rowmajor_fp32_m8_nx(C, sum[nn][mm], m + mm * tM, n + nn * tN, N);
+        }
+    }
+}
+
+#endif // HAS_SIMD8
+
+// Prefetches the A tiles for one K step. Each call covers two M tiles and two
+// K tiles, so both mm and kk advance by 2.
+void HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(global ushort* A, int tM, int K, int m, int prefetch_k)
+{
+    for (int kk = 0; kk < KK; kk+=2) {
+        for (int mm = 0; mm < MM; mm+=2) {
+            prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K);
+        }
+    }
+}
+
+// Prefetches row-major B tiles for one K step. Each call covers two N tiles,
+// so nn advances by 2.
+void HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(global ushort* B, int tN, int N, int prefetch_k, int n)
+{
+    for (int kk = 0; kk < KK; kk++) {
+        for (int nn = 0; nn < NN; nn+=2) {
+            prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, prefetch_k + kk * tK, n + nn * tN, N);
+        }
+    }
+}
+
+// Prefetches VNNI B tiles for one K step. Each call covers two K tiles, so kk
+// advances by 2.
+void HELPER_NAME(btile_prefetch_vnni, MM, NN)(global ushort* B, int tN, int N, int prefetch_k, int n)
+{
+    for (int kk = 0; kk < KK; kk+=2) {
+        for (int nn = 0; nn < NN; nn++) {
+            prefetch_b_vnni_d16_k16v2_n16_sg16(B, prefetch_k + kk * tK, n + nn * tN, N);
+        }
+    }
+}
+
+// Fills aData[kk][mm] for all KK x MM tiles. When KK is even, two K tiles are
+// loaded per call and split into the .lo and .hi halves of the result;
+// otherwise tiles are loaded one at a time.
+void HELPER_NAME(atile_load_rowmajor, MM, NN)(global ushort* A, int tM, int K, int m, int k, short8 aData[KK][MM])
+{
+    if (KK % 2 == 0) {
+        for (int kk = 0; kk < KK; kk+=2) {
+            for (int mm = 0; mm < MM; mm++) {
+                short16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg16(A, m + mm * tM, k + kk * tK, K);
+                aData[kk + 0][mm] = aTemp.lo;
+                aData[kk + 1][mm] = aTemp.hi;
+            }
+        }
+    } else {
+        for (int kk = 0; kk < KK; kk++) {
+            for (int mm = 0; mm < MM; mm++) {
+                aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k + kk * tK, K);
+            }
+        }
+    }
+}
+
+// Tiled kernel for sub-group size 16 with row-major B. Each work-item
+// accumulates MM x NN float8 tiles over K in steps of tK * KK using
+// mat_mul_sg16, applies activation(), and stores the tiles to C.
+__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1)))
+kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K)
+{
+    __builtin_assume(K > 0);    // Always at least one K iteration.
+    const int tM = 8;
+    const int tN = 16;
+    const int N = get_global_size(0) * NN;
+    const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM);
+    const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN);
+
+    // Initial prefetch:
+    int prefetch_k = 0;
+    for (int p = 0; p < PREFETCH_DISTANCE; p++) {
+        HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k);
+        HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n);
+        prefetch_k += tK * KK;
+    }
+
+    float8 sum[NN][MM];
+    for (int mm = 0; mm < MM; mm++) {
+        for (int nn = 0; nn < NN; nn++) {
+            sum[nn][mm] = 0;
+        }
+    }
+
+    split_barrier_arrive();
+
+    for (int k = 0; k < K; k += tK * KK) {
+        // Next prefetch:
+        // TODO: skip prefetch on the last iterations.
+        HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k);
+        HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n);
+        prefetch_k += tK * KK;
+
+        short8 aData[KK][MM];
+        HELPER_NAME(atile_load_rowmajor, MM, NN)(A, tM, K, m, k, aData);
+
+        int8 bData[NN][KK];
+        HELPER_NAME(btile_load_rowmajor, MM, NN)(B, tN, N, k, n, bData);
+
+        for (int kk = 0; kk < KK; kk++) {
+            for (int nn = 0; nn < NN; nn++) {
+                for (int mm = 0; mm < MM; mm++) {
+                    sum[nn][mm] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[nn][mm]);
+                }
+            }
+        }
+
+        // Complete this iteration's barrier, then arrive for the next one.
+        split_barrier_wait();
+        split_barrier_arrive();
+    }
+
+    // Matches the final arrive issued by the last loop iteration.
+    split_barrier_wait();
+
+    for (int mm = 0; mm < MM; mm++) {
+        for (int nn = 0; nn < NN; nn++) {
+            sum[nn][mm] = activation(sum[nn][mm]);
+            store_c_rowmajor_fp32_m8_nx(C, sum[nn][mm], m + mm * tM, n + nn * tN, N);
+        }
+    }
+}
+
+// Tiled kernel for sub-group size 16 with VNNI-layout B. Same structure as
+// the row-major kernel above, but B is prefetched and loaded with the VNNI
+// helpers.
+__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1)))
+kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K)
+{
+    __builtin_assume(K > 0);    // Always at least one K iteration.
+    const int tM = 8;
+    const int tN = 16;
+    const int N = get_global_size(0) * NN;
+    const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM);
+    const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN);
+
+    // Initial prefetch:
+    int prefetch_k = 0;
+    for (int p = 0; p < PREFETCH_DISTANCE; p++) {
+        HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k);
+        HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n);
+        prefetch_k += tK * KK;
+    }
+
+    float8 sum[NN][MM];
+    for (int mm = 0; mm < MM; mm++) {
+        for (int nn = 0; nn < NN; nn++) {
+            sum[nn][mm] = 0;
+        }
+    }
+
+    split_barrier_arrive();
+
+    for (int k = 0; k < K; k += tK * KK) {
+        // Next prefetch:
+        // TODO: skip prefetch on the last iterations.
+        HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k);
+        HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n);
+        prefetch_k += tK * KK;
+
+        short8 aData[KK][MM];
+        HELPER_NAME(atile_load_rowmajor, MM, NN)(A, tM, K, m, k, aData);
+
+        int8 bData[NN][KK];
+        HELPER_NAME(btile_load_vnni, MM, NN)(B, tN, N, k, n, bData);
+
+        for (int kk = 0; kk < KK; kk++) {
+            for (int nn = 0; nn < NN; nn++) {
+                for (int mm = 0; mm < MM; mm++) {
+                    sum[nn][mm] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[nn][mm]);
+                }
+            }
+        }
+
+        // Complete this iteration's barrier, then arrive for the next one.
+        split_barrier_wait();
+        split_barrier_arrive();
+    }
+
+    // Matches the final arrive issued by the last loop iteration.
+    split_barrier_wait();
+
+    for (int mm = 0; mm < MM; mm++) {
+        for (int nn = 0; nn < NN; nn++) {
+            sum[nn][mm] = activation(sum[nn][mm]);
+            store_c_rowmajor_fp32_m8_nx(C, sum[nn][mm], m + mm * tM, n + nn * tN, N);
+        }
+    }
+}
+
+#ifdef cl_intel_subgroup_extended_block_read
+
+// Fills aData[kk][mm] with 2D block reads, choosing the largest read that
+// divides the tile counts: 2 K tiles x 4 M tiles, 2 x 2, 2 x 1, 1 x 4, or one
+// tile at a time. Width and pitch are passed in bytes (K * sizeof(ushort)) and
+// the coordinate is (k, m), i.e. x first.
+void HELPER_NAME(atile_block_load_rowmajor, MM, NN)(global ushort* A, int tM, int M, int K, int m, int k, short8 aData[KK][MM])
+{
+    if (KK % 2 == 0 & MM % 4 == 0) {
+        for (int kk = 0; kk < KK; kk+=2) {
+            for (int mm = 0; mm < MM; mm+=4) {
+                //if (get_sub_group_local_id() == 0) {
+                //    printf("atile block load : %d, %d, %2d: m = %3d, k = %3d, mm = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), m, k, mm, kk, k + kk * tK, m + mm * tM);
+                //}
+                ushort8 tmp[2][4];
+                intel_subgroup_block_read_u16_m32k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp);
+                for (int tkk = 0; tkk < 2; tkk++) {
+                    for (int tmm = 0; tmm < 4; tmm++) {
+                        aData[kk + tkk][mm + tmm] = as_short8(tmp[tkk][tmm]);
+                    }
+                }
+            }
+        }
+    } else if (KK % 2 == 0 & MM % 2 == 0) {
+        for (int kk = 0; kk < KK; kk+=2) {
+            for (int mm = 0; mm < MM; mm+=2) {
+                ushort8 tmp[2][2];
+                intel_subgroup_block_read_u16_m16k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp);
+                for (int tkk = 0; tkk < 2; tkk++) {
+                    for (int tmm = 0; tmm < 2; tmm++) {
+                        aData[kk + tkk][mm + tmm] = as_short8(tmp[tkk][tmm]);
+                    }
+                }
+            }
+        }
+    } else if (KK % 2 == 0) {
+        for (int kk = 0; kk < KK; kk+=2) {
+            for (int mm = 0; mm < MM; mm++) {
+                short16 aTemp = as_short16(intel_subgroup_block_read_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)));
+                aData[kk + 0][mm] = aTemp.lo;
+                aData[kk + 1][mm] = aTemp.hi;
+            }
+        }
+    } else if (MM % 4 == 0) {
+        for (int kk = 0; kk < KK; kk++) {
+            for (int mm = 0; mm < MM; mm+=4) {
+                ushort8 tmp[4];
+                intel_subgroup_block_read_u16_m32k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp);
+                for (int tmm = 0; tmm < 4; tmm++) {
+                    aData[kk][mm + tmm] = as_short8(tmp[tmm]);
+                }
+            }
+        }
+    } else {
+        for (int kk = 0; kk < KK; kk++) {
+            for (int mm = 0; mm < MM; mm++) {
+                aData[kk][mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)));
+            }
+        }
+    }
+}
+
+void HELPER_NAME(btile_block_load_rowmajor, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n, int8 bData[NN][KK])
+{
+    if (KK % 2 == 0 & NN % 2 == 0) {
+        for (int kk = 0; kk < KK; kk+=2) {
+            for (int nn = 0; nn < NN; nn+=2) {
+                //if (get_sub_group_local_id() == 0) {
+                //    printf("btile block load: %d, %d, %2d: n = %3d, k = %3d, nn = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), n, k, nn, kk, n + nn * tN, k + kk * tK);
+                //}
+                int8 tmp[2][2];
+                intel_subgroup_block_read_transform_u16_k32n16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK), tmp);
+                for (int tnn = 0; tnn < 2; tnn++) {
+                    for (int tkk = 0; tkk < 2; tkk++) {
+                        bData[nn + tnn][kk + tkk] = tmp[tnn][tkk];
+                    }
+                }
+            }
+        }
+    } else if (NN % 2 == 0) {
+        for (int kk = 0; kk < KK; kk++) {
+            for (int nn = 0; nn < NN; nn+=2) {
+                int16 bTemp = intel_subgroup_block_read_transform_u16_k16n16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn 
* tN, k + kk * tK)); + bData[nn + 0][kk] = bTemp.lo; + bData[nn + 1][kk] = bTemp.hi; + } + } + } else if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn++) { + int16 bTemp = intel_subgroup_block_read_transform_u16_k32n16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + bData[nn][kk + 0] = bTemp.lo; + bData[nn][kk + 1] = bTemp.hi; + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[nn][kk] = intel_subgroup_block_read_transform_u16_k16n16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + } + } + } +} + +void HELPER_NAME(btile_block_load_vnni, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n, int8 bData[NN][KK]) +{ + if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn++) { + int16 bTemp = as_int16(intel_subgroup_block_read_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2))); + bData[nn][kk + 0] = bTemp.lo; + bData[nn][kk + 1] = bTemp.hi; + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + bData[nn][kk] = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2))); + } + } + } +} + +void HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(global ushort* A, int tM, int M, int K, int m, int k) +{ + if (KK == 2 & MM == 4 & SGS_PER_WG_X >= 4) { + const int sg_index_x = get_sub_group_id() % SGS_PER_WG_X; // index in [0, SGS_PER_WG_X) + const int kk = 0; + const int mm = sg_index_x % 4; + //if (get_sub_group_local_id() == 0) { + // printf("atile block prefetch: %d, %d, %2d: sg_x = %d, m = %3d, k = %3d, mm = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), sg_index_x, m, k, mm, kk, k + kk * tK, m + mm * tM); + //} + intel_subgroup_block_prefetch_u16_m8k16v2(A, K * sizeof(ushort), M, K * 
sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } else if (KK % 2 == 0 & MM % 4 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=4) { + intel_subgroup_block_prefetch_u16_m32k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } + } + } else if (KK % 2 == 0 & MM % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm+=2) { + intel_subgroup_block_prefetch_u16_m16k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } + } + } else if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + intel_subgroup_block_prefetch_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } + } + } else if (MM % 4 == 0) { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm+=4) { + intel_subgroup_block_prefetch_u16_m32k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + intel_subgroup_block_prefetch_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + } + } + } +} + +void HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n) +{ + if (KK == 2 & NN == 4 & SGS_PER_WG_Y >= 4) { + const int sg_index_y = get_sub_group_id() / SGS_PER_WG_X; // index in [0, SGS_PER_WG_Y) + const int nn = sg_index_y % 2 * 2; // nn(sg_index_y) == 0, 2, 0, 2, 0, 2, 0, 2, ... + const int kk = sg_index_y / 2 % 2; // kk(sg_index_y) == 0, 0, 1, 1, 0, 0, 1, 1, ... 
+ //if (get_sub_group_local_id() == 0) { + // printf("btile block prefetch: %d, %d, %2d: sg_y = %d, n = %3d, k = %3d, nn = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), sg_index_y, n, k, nn, kk, n + nn * tN, k + kk * tK); + //} + intel_subgroup_block_prefetch_u16_m16k16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + } else if (KK % 2 == 0 & NN % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn += 2) { + intel_subgroup_block_prefetch_u16_m32k16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + } + } + } else if (NN % 2 == 0) { + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn+=2) { + intel_subgroup_block_prefetch_u16_m16k16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + } + } + } else if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn++) { + intel_subgroup_block_prefetch_u16_m32k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + intel_subgroup_block_prefetch_u16_m16k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + } + } + } +} + +void HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n) +{ + if (KK == 2 & NN == 4 & SGS_PER_WG_Y >= 4) { + const int sg_index_y = get_sub_group_id() / SGS_PER_WG_X; // index in [0, SGS_PER_WG_Y) + const int nn = sg_index_y % 4; // nn(sg_index_y) == 0, 1, 2, 3, 0, 1, 2, 3 + const int kk = 0; // kk(sg_index_y) == 0, 0, 0, 0, 0, 0, 0, 0 + intel_subgroup_block_prefetch_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); + } else if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int nn = 0; nn < NN; nn++) { + 
intel_subgroup_block_prefetch_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + intel_subgroup_block_prefetch_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); + } + } + } +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) +kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 16; + const int M = get_global_size(1) * tM * MM; + const int N = get_global_size(0) * NN; + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); + + int prefetch_k = 0; + for (int p = 0; p < PREFETCH_DISTANCE; p++) { + HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(B, tN, K, N, prefetch_k, n); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); + prefetch_k += tK * KK; + } + + float8 sum[NN][MM]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = 0; + } + } + + split_barrier_arrive(); + + for (int k = 0; k < K; k += tK * KK) { + int8 bData[NN][KK]; + HELPER_NAME(btile_block_load_rowmajor, MM, NN)(B, tN, K, N, k, n, bData); + + short8 aData[KK][MM]; + HELPER_NAME(atile_block_load_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); + + HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(B, tN, K, N, prefetch_k, n); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); + prefetch_k += tK * KK; + + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { + sum[nn][mm] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[nn][mm]); + } + } + } + + 
split_barrier_wait(); + split_barrier_arrive(); + } + + split_barrier_wait(); + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = activation(sum[nn][mm]); + intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[nn][mm])); + } + } +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) +kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 16; + const int M = get_global_size(1) * tM * MM; + const int N = get_global_size(0) * NN; + const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); + const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); + + int prefetch_k = 0; + for (int p = 0; p < PREFETCH_DISTANCE; p++) { + HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(B, tN, K, N, prefetch_k, n); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); + prefetch_k += tK * KK; + } + + float8 sum[NN][MM]; + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = 0; + } + } + + split_barrier_arrive(); + + for (int k = 0; k < K; k += tK * KK) { + int8 bData[NN][KK]; + HELPER_NAME(btile_block_load_vnni, MM, NN)(B, tN, K, N, k, n, bData); + + short8 aData[KK][MM]; + HELPER_NAME(atile_block_load_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); + + // TODO: skip prefetch on the last iterations. 
+ HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(B, tN, K, N, prefetch_k, n); + HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); + prefetch_k += tK * KK; + + for (int kk = 0; kk < KK; kk++) { + for (int nn = 0; nn < NN; nn++) { + for (int mm = 0; mm < MM; mm++) { + sum[nn][mm] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[nn][mm]); + } + } + } + + split_barrier_wait(); + split_barrier_arrive(); + } + + split_barrier_wait(); + + for (int mm = 0; mm < MM; mm++) { + for (int nn = 0; nn < NN; nn++) { + sum[nn][mm] = activation(sum[nn][mm]); + intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[nn][mm])); + } + } +} + +#endif // cl_intel_subgroup_extended_block_read diff --git a/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl b/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl new file mode 100644 index 00000000..6e27d8d1 --- /dev/null +++ b/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl @@ -0,0 +1,613 @@ +#include "matrix_helpers_i8.cl" + +#if EMULATE_tN8 +#define mat_mul_sg8 emu_sub_group_i8_i8_matrix_mad_k32 +#else +#define mat_mul_sg8 intel_sub_group_i8_i8_matrix_mad_k32 +#endif + +#if EMULATE_tN16 +#define mat_mul_sg16 emu_sub_group_i8_i8_matrix_mad_k32 +#else +#define mat_mul_sg16 intel_sub_group_i8_i8_matrix_mad_k32 +#endif + +kernel void i8_naive(global int* C, global char* A, global char* B, int K) +{ + const int N = get_global_size(0); + const int m = get_global_id(1); + const int n = get_global_id(0); + + int sum = 0; + for (int k = 0; k < K; k++) { + sum = A[m * K + k] * B[k * N + n] + sum; + } + + sum = activation(sum); + C[m * N + n] = sum; +} + +// For all i8 kernels tK == 32: +#define tK 32 + +#if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_char) && defined(cl_intel_required_subgroup_size) + +#if HAS_SIMD8 + +// rowmajor kernels: + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 
1))) +kernel void i8_dpas_rowmajor_m1_n8(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 1; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int sum = 0; + for (int k = 0; k < K; k += tK) { + int aData = load_a_rowmajor_d8_m1_k32_sg8(A, m, k, K); + int8 bData = load_b_rowmajor_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m1_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void i8_dpas_rowmajor_m2_n8(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 2; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int2 sum = 0; + for (int k = 0; k < K; k += tK) { + int2 aData = load_a_rowmajor_d8_m2_k32_sg8(A, m, k, K); + int8 bData = load_b_rowmajor_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m2_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void i8_dpas_rowmajor_m4_n8(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. 
+ const int tM = 4; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int4 sum = 0; + for (int k = 0; k < K; k += tK) { + int4 aData = load_a_rowmajor_d8_m4_k32_sg8(A, m, k, K); + int8 bData = load_b_rowmajor_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m4_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void i8_dpas_rowmajor_m8_n8(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int8 sum = 0; + for (int k = 0; k < K; k += tK) { + int8 aData = load_a_rowmajor_d8_m8_k32_sg8(A, m, k, K); + int8 bData = load_b_rowmajor_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m8_nx(C, sum, m, n, N); +} + +// vnni kernels: + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void i8_dpas_vnni_m1_n8(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. 
+ const int tM = 1; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int sum = 0; + for (int k = 0; k < K; k += tK) { + int aData = load_a_rowmajor_d8_m1_k32_sg8(A, m, k, K); + int8 bData = load_b_vnni_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m1_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void i8_dpas_vnni_m2_n8(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 2; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int2 sum = 0; + for (int k = 0; k < K; k += tK) { + int2 aData = load_a_rowmajor_d8_m2_k32_sg8(A, m, k, K); + int8 bData = load_b_vnni_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m2_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void i8_dpas_vnni_m4_n8(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. 
+ const int tM = 4; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int4 sum = 0; + for (int k = 0; k < K; k += tK) { + int4 aData = load_a_rowmajor_d8_m4_k32_sg8(A, m, k, K); + int8 bData = load_b_vnni_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m4_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) +kernel void i8_dpas_vnni_m8_n8(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 8; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int8 sum = 0; + for (int k = 0; k < K; k += tK) { + int8 aData = load_a_rowmajor_d8_m8_k32_sg8(A, m, k, K); + int8 bData = load_b_vnni_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg8(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m8_nx(C, sum, m, n, N); +} + +#endif // HAS_SIMD8 + +// rowmajor kernels: + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_rowmajor_m1_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration.
+ const int tM = 1; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * get_local_size(0); + + int sum = 0; + for (int k = 0; k < K; k += tK) { + short aData = load_a_rowmajor_d8_m1_k32_sg16(A, m, k, K); + int8 bData = load_b_rowmajor_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m1_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_rowmajor_m2_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 2; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * get_local_size(0); + + int2 sum = 0; + for (int k = 0; k < K; k += tK) { + short2 aData = load_a_rowmajor_d8_m2_k32_sg16(A, m, k, K); + int8 bData = load_b_rowmajor_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m2_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_rowmajor_m4_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. 
+ const int tM = 4; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * get_local_size(0); + + int4 sum = 0; + for (int k = 0; k < K; k += tK) { + short4 aData = load_a_rowmajor_d8_m4_k32_sg16(A, m, k, K); + int8 bData = load_b_rowmajor_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m4_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_rowmajor_m8_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * get_local_size(0); + + int8 sum = 0; + for (int k = 0; k < K; k += tK) { + short8 aData = load_a_rowmajor_d8_m8_k32_sg16(A, m, k, K); + int8 bData = load_b_rowmajor_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m8_nx(C, sum, m, n, N); +} + +// vnni kernels: + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_vnni_m1_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. 
+ const int tM = 1; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int sum = 0; + for (int k = 0; k < K; k += tK) { + short aData = load_a_rowmajor_d8_m1_k32_sg16(A, m, k, K); + int8 bData = load_b_vnni_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m1_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_vnni_m2_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 2; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int2 sum = 0; + for (int k = 0; k < K; k += tK) { + short2 aData = load_a_rowmajor_d8_m2_k32_sg16(A, m, k, K); + int8 bData = load_b_vnni_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m2_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_vnni_m4_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. 
+ const int tM = 4; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int4 sum = 0; + for (int k = 0; k < K; k += tK) { + short4 aData = load_a_rowmajor_d8_m4_k32_sg16(A, m, k, K); + int8 bData = load_b_vnni_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m4_nx(C, sum, m, n, N); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_vnni_m8_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 16; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int8 sum = 0; + for (int k = 0; k < K; k += tK) { + short8 aData = load_a_rowmajor_d8_m8_k32_sg16(A, m, k, K); + int8 bData = load_b_vnni_d8_k32_nx(B, k, n, N); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + store_c_rowmajor_int32_m8_nx(C, sum, m, n, N); +} + +#if 0 + +#ifdef cl_intel_subgroup_extended_block_read + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_blockread_rowmajor_m1_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. 
+ const int tM = 1; + const int tN = 16; + const int M = get_global_size(1); + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int sum = 0; + for (int k = 0; k < K; k += tK) { + short aData = as_short(intel_subgroup_block_read_u16_m1k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_subgroup_block_write_u32_m1k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum)); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_blockread_rowmajor_m2_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 2; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int2 sum = 0; + for (int k = 0; k < K; k += tK) { + short2 aData = as_short2(intel_subgroup_block_read_u16_m2k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_subgroup_block_write_u32_m2k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum)); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_blockread_rowmajor_m4_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. 
+ const int tM = 4; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int4 sum = 0; + for (int k = 0; k < K; k += tK) { + short4 aData = as_short4(intel_subgroup_block_read_u16_m4k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_subgroup_block_write_u32_m4k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum)); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_blockread_rowmajor_m8_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int8 sum = 0; + for (int k = 0; k < K; k += tK) { + short8 aData = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_blockread_vnni_m1_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. 
+ const int tM = 1; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int sum = 0; + for (int k = 0; k < K; k += tK) { + short aData = as_short(intel_subgroup_block_read_u16_m1k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_subgroup_block_write_u32_m1k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum)); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_blockread_vnni_m2_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 2; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int2 sum = 0; + for (int k = 0; k < K; k += tK) { + short2 aData = as_short2(intel_subgroup_block_read_u16_m2k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_subgroup_block_write_u32_m2k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum)); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_blockread_vnni_m4_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. 
+ const int tM = 4; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int4 sum = 0; + for (int k = 0; k < K; k += tK) { + short4 aData = as_short4(intel_subgroup_block_read_u16_m4k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_subgroup_block_write_u32_m4k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum)); +} + +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_blockread_vnni_m8_n16(global int* C, global char* A, global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 8; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + + int8 sum = 0; + for (int k = 0; k < K; k += tK) { + short8 aData = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); +} + +#endif // cl_intel_subgroup_extended_block_read + +// Tiled matrix multiplication kernels, generated from a template: + +#define MM 1 +#define NN 1 +#include "matrix_kernel_tiled_i8.cl" +#undef MM +#undef NN + +#define MM 2 +#define NN 1 +#include "matrix_kernel_tiled_i8.cl" +#undef MM +#undef NN + +#define MM 1 +#define NN 2 +#include "matrix_kernel_tiled_i8.cl" 
+#undef MM +#undef NN + +#define MM 2 +#define NN 2 +#include "matrix_kernel_tiled_i8.cl" +#undef MM +#undef NN + +#define MM 4 +#define NN 2 +#include "matrix_kernel_tiled_i8.cl" +#undef MM +#undef NN + +#define MM 2 +#define NN 4 +#include "matrix_kernel_tiled_i8.cl" +#undef MM +#undef NN + +#define MM 4 +#define NN 4 +#include "matrix_kernel_tiled_i8.cl" +#undef MM +#undef NN + +#endif // disabling these cases for now + +#endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size) + +#undef tK diff --git a/samples/CMakeLists.txt b/samples/CMakeLists.txt index 30f877b4..af509b28 100644 --- a/samples/CMakeLists.txt +++ b/samples/CMakeLists.txt @@ -90,4 +90,5 @@ if(BUILD_EXTENSION_SAMPLES) endif() add_subdirectory( 99_matrixexperiments ) +add_subdirectory( 99_matrixexperimentsi8 ) add_subdirectory( 99_matrixexperimentstf32 ) From bb48fd488abf8cdf42fbe1330fa943ff347f694b Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Fri, 12 Apr 2024 08:31:43 -0700 Subject: [PATCH 75/99] enable more int8 samples --- .../matrix_helpers_i8.cl | 203 ++---------------- .../matrix_kernel_tiled_i8.cl | 2 - .../matrix_kernels_i8.cl | 38 ++-- 3 files changed, 36 insertions(+), 207 deletions(-) diff --git a/samples/99_matrixexperimentsi8/matrix_helpers_i8.cl b/samples/99_matrixexperimentsi8/matrix_helpers_i8.cl index 1d591217..26c916c4 100644 --- a/samples/99_matrixexperimentsi8/matrix_helpers_i8.cl +++ b/samples/99_matrixexperimentsi8/matrix_helpers_i8.cl @@ -630,7 +630,6 @@ void store_c_rowmajor_int32_m8_nx(global int* C, int8 v, int rowStart, int colSt #endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) -#if 0 #ifdef cl_intel_subgroup_extended_block_read // Note for 2D block reads: @@ -689,205 +688,42 @@ enum LSC_LDCC { LSC_LDCC_L1IAR_L3C = 7, // Override to L1 invalidate-after-read, and L3 cached }; -typedef ushort __attribute__((ext_vector_type(32))) ushort32; -typedef ushort 
__attribute__((ext_vector_type(64))) ushort64; - -typedef uint __attribute__((ext_vector_type(32))) uint32; - // Define block reads, prefetches, and writes. These are supported by the hardware but are not in the headers: -ushort __builtin_IB_subgroup_block_read_flat_u16_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -ushort2 __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -ushort4 __builtin_IB_subgroup_block_read_flat_u16_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -ushort8 __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -ushort16 __builtin_IB_subgroup_block_read_flat_u16_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -ushort32 __builtin_IB_subgroup_block_read_flat_u16_m32k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); - -ushort32 __builtin_IB_subgroup_block_read_flat_u16_m16k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -ushort64 __builtin_IB_subgroup_block_read_flat_u16_m32k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort __builtin_IB_subgroup_block_read_flat_u8_m1k32v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort2 __builtin_IB_subgroup_block_read_flat_u8_m2k32v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort4 __builtin_IB_subgroup_block_read_flat_u8_m4k32v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort8 __builtin_IB_subgroup_block_read_flat_u8_m8k32v1(long baseoffset, int 
width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -uint16 __builtin_IB_subgroup_block_read_flat_u32_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); - -uint16 __builtin_IB_subgroup_block_read_flat_transform_u16_k32(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); - -uint16 __builtin_IB_subgroup_block_read_flat_transform_u16_k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -uint32 __builtin_IB_subgroup_block_read_flat_transform_u16_k32v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); - -void __builtin_IB_subgroup_block_read_prefetch_u16_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); -void __builtin_IB_subgroup_block_read_prefetch_u16_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); -void __builtin_IB_subgroup_block_read_prefetch_u16_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); -void __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); -void __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); -void __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); -void 
__builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); - -void __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); -void __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); - -void __builtin_IB_subgroup_block_read_prefetch_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); -void __builtin_IB_subgroup_block_read_prefetch_u32_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); - - -void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint data); -void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint2 data); -void __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint4 data); -void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint8 data); -void __builtin_IB_subgroup_block_write_flat_u32_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint16 data); - -ushort intel_subgroup_block_read_u16_m1k16(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ - return __builtin_IB_subgroup_block_read_flat_u16_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); -} 
-ushort2 intel_subgroup_block_read_u16_m2k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +ushort intel_subgroup_block_read_u8_m1k32(const __global void *base_address, int width, int height, int pitch, int2 coord) { - return __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); + return __builtin_IB_subgroup_block_read_flat_u8_m1k32v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } -ushort4 intel_subgroup_block_read_u16_m4k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +ushort2 intel_subgroup_block_read_u8_m2k32(const __global void *base_address, int width, int height, int pitch, int2 coord) { - return __builtin_IB_subgroup_block_read_flat_u16_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); + return __builtin_IB_subgroup_block_read_flat_u8_m2k32v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } -ushort8 intel_subgroup_block_read_u16_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +ushort4 intel_subgroup_block_read_u8_m4k32(const __global void *base_address, int width, int height, int pitch, int2 coord) { - return __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); + return __builtin_IB_subgroup_block_read_flat_u8_m4k32v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } -ushort16 intel_subgroup_block_read_u16_m16k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +ushort8 intel_subgroup_block_read_u8_m8k32(const __global void *base_address, int width, int height, int pitch, int2 coord) { - return __builtin_IB_subgroup_block_read_flat_u16_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); -} -void intel_subgroup_block_read_u16_m32k16(const __global void *base_address, int width, int height, int pitch, 
int2 coord, ushort8 dst[4]) -{ - ushort32 tmp = __builtin_IB_subgroup_block_read_flat_u16_m32k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); - dst[0] = tmp.lo.lo; - dst[1] = tmp.lo.hi; - dst[2] = tmp.hi.lo; - dst[3] = tmp.hi.hi; + return __builtin_IB_subgroup_block_read_flat_u8_m8k32v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } -void intel_subgroup_block_read_u16_m16k16v2(const __global void *base_address, int width, int height, int pitch, int2 coord, ushort8 dst[2][2]) -{ - ushort32 tmp = __builtin_IB_subgroup_block_read_flat_u16_m16k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord); - dst[0][0] = tmp.lo.lo; - dst[0][1] = tmp.lo.hi; - dst[1][0] = tmp.hi.lo; - dst[1][1] = tmp.hi.hi; -} -void intel_subgroup_block_read_u16_m32k16v2(const __global void *base_address, int width, int height, int pitch, int2 coord, ushort8 dst[2][4]) -{ - ushort64 tmp = __builtin_IB_subgroup_block_read_flat_u16_m32k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord); - dst[0][0] = tmp.lo.lo.lo; - dst[0][1] = tmp.lo.lo.hi; - dst[0][2] = tmp.lo.hi.lo; - dst[0][3] = tmp.lo.hi.hi; - dst[1][0] = tmp.hi.lo.lo; - dst[1][1] = tmp.hi.lo.hi; - dst[1][2] = tmp.hi.hi.lo; - dst[1][3] = tmp.hi.hi.hi; -} - -uint8 intel_subgroup_block_read_u32_m8k16(const __global void* base_address, int width, int height, int pitch, int2 coord) +uint8 intel_subgroup_block_read_u32_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord) { return __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } -uint16 intel_subgroup_block_read_u32_m16k16(const __global void* base_address, int width, int height, int pitch, int2 coord) -{ - return __builtin_IB_subgroup_block_read_flat_u32_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); -} - -// Each block is K rows x N columns, where the K rows have been VNNI transformed. 
-int8 intel_subgroup_block_read_transform_u16_k16n16(__global void *base_address, int width, int height, int pitch, int2 coord) -{ - // Note: this function is in the headers, but is named confusingly and returns unsigned integers rather than signed integers: - return as_int8(intel_subgroup_block_read_transform_u16_k16(base_address, width, height, pitch, coord)); -} -int16 intel_subgroup_block_read_transform_u16_k32n16(__global void *base_address, int width, int height, int pitch, int2 coord) -{ - return as_int16(__builtin_IB_subgroup_block_read_flat_transform_u16_k32(as_long(base_address), width - 1, height - 1, pitch - 1, coord)); -} -int16 intel_subgroup_block_read_transform_u16_k16n16v2(__global void *base_address, int width, int height, int pitch, int2 coord) -{ - return as_int16(__builtin_IB_subgroup_block_read_flat_transform_u16_k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord)); -} -void intel_subgroup_block_read_transform_u16_k32n16v2(__global void *base_address, int width, int height, int pitch, int2 coord, int8 dst[2][2]) -{ - uint32 tmp = __builtin_IB_subgroup_block_read_flat_transform_u16_k32v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord); - dst[0][0] = as_int8(tmp.lo.lo); - dst[0][1] = as_int8(tmp.lo.hi); - dst[1][0] = as_int8(tmp.hi.lo); - dst[1][1] = as_int8(tmp.hi.hi); -} - -#define BLOCK_PREFETCH_CACHE_TYPE LSC_LDCC_L1C_L3C - -void intel_subgroup_block_prefetch_u16_m1k16(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ -#if defined(PREFETCH_DEFAULT) - __builtin_IB_subgroup_block_read_prefetch_u16_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); -#endif // defined(PREFETCH_DEFAULT) -} -void intel_subgroup_block_prefetch_u16_m2k16(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ -#if defined(PREFETCH_DEFAULT) - __builtin_IB_subgroup_block_read_prefetch_u16_m2k16v1(as_long(base_address), width 
- 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); -#endif // defined(PREFETCH_DEFAULT) -} -void intel_subgroup_block_prefetch_u16_m4k16(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ -#if defined(PREFETCH_DEFAULT) - __builtin_IB_subgroup_block_read_prefetch_u16_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); -#endif // defined(PREFETCH_DEFAULT) -} -void intel_subgroup_block_prefetch_u16_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ -#if defined(PREFETCH_DEFAULT) - __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); -#endif // defined(PREFETCH_DEFAULT) -} -void intel_subgroup_block_prefetch_u16_m8k16v2(__global void *base_address, int width, int height, int pitch, int2 coord) -{ -#if defined(PREFETCH_DEFAULT) - __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); -#endif // defined(PREFETCH_DEFAULT) -} -void intel_subgroup_block_prefetch_u16_m16k16(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ -#if defined(PREFETCH_DEFAULT) - __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); -#endif // defined(PREFETCH_DEFAULT) -} -void intel_subgroup_block_prefetch_u16_m32k16(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ -#if defined(PREFETCH_DEFAULT) - __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); -#endif // defined(PREFETCH_DEFAULT) -} -void intel_subgroup_block_prefetch_u16_m16k16v2(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ -#if 
defined(PREFETCH_DEFAULT) - __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); -#endif // defined(PREFETCH_DEFAULT) -} -void intel_subgroup_block_prefetch_u16_m32k16v2(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ -#if defined(PREFETCH_DEFAULT) - __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); -#endif // defined(PREFETCH_DEFAULT) -} -void intel_subgroup_block_prefetch_u32_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ -#if defined(PREFETCH_DEFAULT) - __builtin_IB_subgroup_block_read_prefetch_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); -#endif // defined(PREFETCH_DEFAULT) -} -void intel_subgroup_block_prefetch_u32_m16k16(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ -#if defined(PREFETCH_DEFAULT) - __builtin_IB_subgroup_block_read_prefetch_u32_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); -#endif // defined(PREFETCH_DEFAULT) -} +void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint data); +void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint2 data); +void __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint4 data); +void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint8 data); void intel_subgroup_block_write_u32_m1k16(__global void* base_address, int 
width, int height, int pitch, int2 coord, uint data) { @@ -905,10 +741,5 @@ void intel_subgroup_block_write_u32_m8k16(__global void* base_address, int width { __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); } -void intel_subgroup_block_write_u32_m16k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint16 data) -{ - __builtin_IB_subgroup_block_write_flat_u32_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); -} #endif // cl_intel_subgroup_extended_block_read -#endif diff --git a/samples/99_matrixexperimentsi8/matrix_kernel_tiled_i8.cl b/samples/99_matrixexperimentsi8/matrix_kernel_tiled_i8.cl index a862744f..75940825 100644 --- a/samples/99_matrixexperimentsi8/matrix_kernel_tiled_i8.cl +++ b/samples/99_matrixexperimentsi8/matrix_kernel_tiled_i8.cl @@ -1,5 +1,3 @@ -#error "Needs to be updated!" - #if !defined(tK) #error "tK is undefined! This should be defined as the K dimension of the matrix tiles, which is dependent on the elemement type, likely 16 or 32." 
#endif diff --git a/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl b/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl index 6e27d8d1..5e6f5735 100644 --- a/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl +++ b/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl @@ -380,8 +380,6 @@ kernel void i8_dpas_vnni_m8_n16(global int* C, global char* A, global char* B, i store_c_rowmajor_int32_m8_nx(C, sum, m, n, N); } -#if 0 - #ifdef cl_intel_subgroup_extended_block_read __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -397,8 +395,8 @@ kernel void i8_dpas_blockread_rowmajor_m1_n16(global int* C, global char* A, glo int sum = 0; for (int k = 0; k < K; k += tK) { - short aData = as_short(intel_subgroup_block_read_u16_m1k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); - int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); + short aData = as_short(intel_subgroup_block_read_u8_m1k32(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_transform_u8_k32(B, N * sizeof(char), K, N * sizeof(char), (int2)(n, k))); sum = mat_mul_sg16(aData, bData, sum); } @@ -419,8 +417,8 @@ kernel void i8_dpas_blockread_rowmajor_m2_n16(global int* C, global char* A, glo int2 sum = 0; for (int k = 0; k < K; k += tK) { - short2 aData = as_short2(intel_subgroup_block_read_u16_m2k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); - int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); + short2 aData = as_short2(intel_subgroup_block_read_u8_m2k32(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_transform_u8_k32(B, N * sizeof(char), K, N * sizeof(char), (int2)(n, k))); sum = mat_mul_sg16(aData, bData, sum); } @@ -441,8 +439,8 @@ kernel void 
i8_dpas_blockread_rowmajor_m4_n16(global int* C, global char* A, glo int4 sum = 0; for (int k = 0; k < K; k += tK) { - short4 aData = as_short4(intel_subgroup_block_read_u16_m4k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); - int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); + short4 aData = as_short4(intel_subgroup_block_read_u8_m4k32(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_transform_u8_k32(B, N * sizeof(char), K, N * sizeof(char), (int2)(n, k))); sum = mat_mul_sg16(aData, bData, sum); } @@ -463,8 +461,8 @@ kernel void i8_dpas_blockread_rowmajor_m8_n16(global int* C, global char* A, glo int8 sum = 0; for (int k = 0; k < K; k += tK) { - short8 aData = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); - int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); + short8 aData = as_short8(intel_subgroup_block_read_u8_m8k32(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_transform_u8_k32(B, N * sizeof(char), K, N * sizeof(char), (int2)(n, k))); sum = mat_mul_sg16(aData, bData, sum); } @@ -485,8 +483,8 @@ kernel void i8_dpas_blockread_vnni_m1_n16(global int* C, global char* A, global int sum = 0; for (int k = 0; k < K; k += tK) { - short aData = as_short(intel_subgroup_block_read_u16_m1k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); - int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); + short aData = as_short(intel_subgroup_block_read_u8_m1k32(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 4))); sum = mat_mul_sg16(aData, 
bData, sum); } @@ -507,8 +505,8 @@ kernel void i8_dpas_blockread_vnni_m2_n16(global int* C, global char* A, global int2 sum = 0; for (int k = 0; k < K; k += tK) { - short2 aData = as_short2(intel_subgroup_block_read_u16_m2k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); - int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); + short2 aData = as_short2(intel_subgroup_block_read_u8_m2k32(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 4))); sum = mat_mul_sg16(aData, bData, sum); } @@ -529,8 +527,8 @@ kernel void i8_dpas_blockread_vnni_m4_n16(global int* C, global char* A, global int4 sum = 0; for (int k = 0; k < K; k += tK) { - short4 aData = as_short4(intel_subgroup_block_read_u16_m4k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); - int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); + short4 aData = as_short4(intel_subgroup_block_read_u8_m4k32(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 4))); sum = mat_mul_sg16(aData, bData, sum); } @@ -551,8 +549,8 @@ kernel void i8_dpas_blockread_vnni_m8_n16(global int* C, global char* A, global int8 sum = 0; for (int k = 0; k < K; k += tK) { - short8 aData = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); - int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); + short8 aData = as_short8(intel_subgroup_block_read_u8_m8k32(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m))); + int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 
4))); sum = mat_mul_sg16(aData, bData, sum); } @@ -562,6 +560,8 @@ kernel void i8_dpas_blockread_vnni_m8_n16(global int* C, global char* A, global #endif // cl_intel_subgroup_extended_block_read +#if 0 // disable the tiled cases for now + // Tiled matrix multiplication kernels, generated from a template: #define MM 1 @@ -606,7 +606,7 @@ kernel void i8_dpas_blockread_vnni_m8_n16(global int* C, global char* A, global #undef MM #undef NN -#endif // disabling these cases for now +#endif // disable the tiled cases for now #endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size) From 1b8fda8b08ca58325f9ee5e776724d6b05c46bb4 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Fri, 12 Apr 2024 10:15:21 -0700 Subject: [PATCH 76/99] slight tf32 diversion --- .../matrix_kernel_tiled_tf32.cl | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl b/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl index db0bb27c..fed6a52b 100644 --- a/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl +++ b/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl @@ -145,9 +145,22 @@ kernel void MM_KERNEL_NAME(tf32_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float void HELPER_NAME(atile_block_load_rowmajor, MM, NN)(global float* A, int tM, int M, int K, int m, int k, float4 aData[KK][MM]) { - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = as_float4(intel_subgroup_block_read_u32_m8k8(A, K * sizeof(float), M, K * sizeof(float), (int2)(k + kk * tK, m + mm * tM))); + if (KK % 2 == 0) { + for (int kk = 0; kk < KK; kk+=2) { + for (int mm = 0; mm < MM; mm++) { + //if (get_sub_group_local_id() == 0) { + // printf("atile block load : %d, %d, %2d: m = %3d, k = %3d, mm = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), m, k, mm, kk, k + kk * 
tK, m + mm * tM); + //} + float8 aTemp = as_float8(intel_subgroup_block_read_u32_m8k8v2(A, K * sizeof(float), M, K * sizeof(float), (int2)(k + kk * tK, m + mm * tM))); + aData[kk + 0][mm] = aTemp.lo; + aData[kk + 1][mm] = aTemp.hi; + } + } + } else { + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = as_float4(intel_subgroup_block_read_u32_m8k8(A, K * sizeof(float), M, K * sizeof(float), (int2)(k + kk * tK, m + mm * tM))); + } } } } From 5c3a6d06c8ffe537ca8d6aeab0f5db3f1dc95dd4 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Mon, 20 May 2024 16:05:28 -0700 Subject: [PATCH 77/99] update function names to align closer with final proposal Note: this still returns the loaded data vs. loading through a pointer. --- .../99_matrixexperiments/matrix_helpers.cl | 155 +++++++++--------- .../matrix_kernel_tiled.cl | 112 ++++++------- .../99_matrixexperiments/matrix_kernels.cl | 140 ++++++++-------- 3 files changed, 206 insertions(+), 201 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_helpers.cl b/samples/99_matrixexperiments/matrix_helpers.cl index acf1219a..f1f5563b 100644 --- a/samples/99_matrixexperiments/matrix_helpers.cl +++ b/samples/99_matrixexperiments/matrix_helpers.cl @@ -212,7 +212,7 @@ float8 emu_sub_group_bf16_bf16_matrix_mad_k16(short8 a, int8 b, float8 acc) // M rows x K columns // This is the SIMD8 version, where each work-item loads two values. -int load_a_rowmajor_d16_m1_k16_sg8(global ushort* A, int rowStart, int colStart, int stride) +int load_a_rowmajor_16b_1r16c_sg8(global ushort* A, int rowStart, int colStart, int stride) { int ret; @@ -225,7 +225,7 @@ int load_a_rowmajor_d16_m1_k16_sg8(global ushort* A, int rowStart, int colStart // M rows x K columns // This is the SIMD8 version, where each work-item loads two values. 
-int2 load_a_rowmajor_d16_m2_k16_sg8(global ushort* A, int rowStart, int colStart, int stride) +int2 load_a_rowmajor_16b_2r16c_sg8(global ushort* A, int rowStart, int colStart, int stride) { int2 ret; @@ -240,7 +240,7 @@ int2 load_a_rowmajor_d16_m2_k16_sg8(global ushort* A, int rowStart, int colStart // M rows x K columns // This is the SIMD8 version, where each work-item loads two values. -int4 load_a_rowmajor_d16_m4_k16_sg8(global ushort* A, int rowStart, int colStart, int stride) +int4 load_a_rowmajor_16b_4r16c_sg8(global ushort* A, int rowStart, int colStart, int stride) { int4 ret; @@ -257,7 +257,7 @@ int4 load_a_rowmajor_d16_m4_k16_sg8(global ushort* A, int rowStart, int colStart // M rows x K columns // This is the SIMD8 version, where each work-item loads two values. -int8 load_a_rowmajor_d16_m8_k16_sg8(global ushort* A, int rowStart, int colStart, int stride) +int8 load_a_rowmajor_16b_8r16c_sg8(global ushort* A, int rowStart, int colStart, int stride) { int8 ret; @@ -279,7 +279,7 @@ int8 load_a_rowmajor_d16_m8_k16_sg8(global ushort* A, int rowStart, int colStart // M rows x K columns x V tiles (in the K dimension) // This is the SIMD8 version, where each work-item loads two values. // The first tile is returned the first components of the return value, the the next tile, etc. 
-int16 load_a_rowmajor_d16_m8_k16v2_sg8(global ushort* A, int rowStart, int colStart, int stride) +int16 load_a_rowmajor_16b_8r16x2c_sg8(global ushort* A, int rowStart, int colStart, int stride) { uint16 ret; @@ -299,7 +299,7 @@ int16 load_a_rowmajor_d16_m8_k16v2_sg8(global ushort* A, int rowStart, int colSt } // M rows x K columns x V tiles (in the K dimension) -void prefetch_a_rowmajor_d16_m8_k16v2_sg8(global ushort* A, int rowStart, int colStart, int stride) +void prefetch_a_rowmajor_16b_8r16x2c_sg8(global ushort* A, int rowStart, int colStart, int stride) { #if defined(PREFETCH_DEFAULT) uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; @@ -310,7 +310,7 @@ void prefetch_a_rowmajor_d16_m8_k16v2_sg8(global ushort* A, int rowStart, int co // M rows x K columns // This is the SIMD16 version, where each work-item loads one value. -short load_a_rowmajor_d16_m1_k16_sg16(global ushort* A, int rowStart, int colStart, int stride) +short load_a_rowmajor_16b_1r16c_sg16(global ushort* A, int rowStart, int colStart, int stride) { ushort ret; @@ -322,7 +322,7 @@ short load_a_rowmajor_d16_m1_k16_sg16(global ushort* A, int rowStart, int colSta // M rows x K columns // This is the SIMD16 version, where each work-item loads one value. -short2 load_a_rowmajor_d16_m2_k16_sg16(global ushort* A, int rowStart, int colStart, int stride) +short2 load_a_rowmajor_16b_2r16c_sg16(global ushort* A, int rowStart, int colStart, int stride) { ushort2 ret; @@ -335,7 +335,7 @@ short2 load_a_rowmajor_d16_m2_k16_sg16(global ushort* A, int rowStart, int colSt // M rows x K columns // This is the SIMD16 version, where each work-item loads one value. 
-short4 load_a_rowmajor_d16_m4_k16_sg16(global ushort* A, int rowStart, int colStart, int stride) +short4 load_a_rowmajor_16b_4r16c_sg16(global ushort* A, int rowStart, int colStart, int stride) { ushort4 ret; @@ -350,7 +350,7 @@ short4 load_a_rowmajor_d16_m4_k16_sg16(global ushort* A, int rowStart, int colSt // M rows x K columns // This is the SIMD16 version, where each work-item loads one value. -short8 load_a_rowmajor_d16_m8_k16_sg16(global ushort* A, int rowStart, int colStart, int stride) +short8 load_a_rowmajor_16b_8r16c_sg16(global ushort* A, int rowStart, int colStart, int stride) { ushort8 ret; @@ -370,7 +370,7 @@ short8 load_a_rowmajor_d16_m8_k16_sg16(global ushort* A, int rowStart, int colSt // M rows x K columns x V tiles (in the K dimension) // This is the SIMD16 version, where each work-item loads one value. // The first tile is returned the first components of the return value, the the next tile, etc. -short16 load_a_rowmajor_d16_m8_k16v2_sg16(global ushort* A, int rowStart, int colStart, int stride) +short16 load_a_rowmajor_16b_8r16x2c_sg16(global ushort* A, int rowStart, int colStart, int stride) { ushort16 ret; @@ -388,7 +388,7 @@ short16 load_a_rowmajor_d16_m8_k16v2_sg16(global ushort* A, int rowStart, int co } // M rows x K columns x V tiles (in the M and K dimensions) -void prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(global ushort* A, int rowStart, int colStart, int stride) +void prefetch_a_rowmajor_16b_8x2r16x2c_sg16(global ushort* A, int rowStart, int colStart, int stride) { #if defined(PREFETCH_DEFAULT) uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; @@ -398,9 +398,9 @@ void prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(global ushort* A, int rowStart, int } // K rows x N columns: -// Each work-item loads K values and converts to VNNI. +// Each work-item loads K values and packs into 32-bits. // Stride is in units of elements. 
-int8 load_b_rowmajor_d16_k16_nx(global ushort* B, int rowStart, int colStart, int stride) +int8 load_b_rowmajor_16b_16rNc(global ushort* B, int rowStart, int colStart, int stride) { int8 ret; @@ -436,9 +436,9 @@ int8 load_b_rowmajor_d16_k16_nx(global ushort* B, int rowStart, int colStart, in } // K rows x N columns: -// Each work-item loads K values that has already been converted to VNNI. +// Each work-item loads K values that have already been packed into 32-bits. // Stride is in units of elements. -int8 load_b_vnni_d16_k16_nx(global ushort* B, int rowStart, int colStart, int stride) +int8 load_b_packed_16b_16rNc(global ushort* B, int rowStart, int colStart, int stride) { int8 ret; @@ -458,7 +458,7 @@ int8 load_b_vnni_d16_k16_nx(global ushort* B, int rowStart, int colStart, int st } // K rows x N columns x V tiles (in the N dimension) -void prefetch_b_rowmajor_d16_k16_n8v4_sg8(global ushort* B, int rowStart, int colStart, int stride) +void prefetch_b_rowmajor_16b_16r8x4c_sg8(global ushort* B, int rowStart, int colStart, int stride) { #if defined(PREFETCH_DEFAULT) uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; @@ -470,7 +470,7 @@ void prefetch_b_rowmajor_d16_k16_n8v4_sg8(global ushort* B, int rowStart, int co } // K rows x N columns x V tiles (in the N dimension) -void prefetch_b_rowmajor_d16_k16_n16v2_sg16(global ushort* B, int rowStart, int colStart, int stride) +void prefetch_b_rowmajor_16b_16r16x2c_sg16(global ushort* B, int rowStart, int colStart, int stride) { #if defined(PREFETCH_DEFAULT) uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; @@ -480,7 +480,7 @@ void prefetch_b_rowmajor_d16_k16_n16v2_sg16(global ushort* B, int rowStart, int } // K rows x N columns x V tiles (in the N dimension) -void prefetch_b_vnni_d16_k16_n8v2_sg8(global ushort* B, int rowStart, int colStart, int stride) +void prefetch_b_packed_16b_16r8x2c_sg8(global ushort* B, int rowStart, int colStart, int stride) { #if 
defined(PREFETCH_DEFAULT) global uint* B_ui = (global uint*)B; @@ -491,7 +491,7 @@ void prefetch_b_vnni_d16_k16_n8v2_sg8(global ushort* B, int rowStart, int colSta } // K rows x N columns x V tiles (in the K dimension) -void prefetch_b_vnni_d16_k16v2_n16_sg16(global ushort* B, int rowStart, int colStart, int stride) +void prefetch_b_packed_16b_16x2r16c_sg16(global ushort* B, int rowStart, int colStart, int stride) { #if defined(PREFETCH_DEFAULT) global uint* B_ui = (global uint*)B; @@ -501,7 +501,7 @@ void prefetch_b_vnni_d16_k16v2_n16_sg16(global ushort* B, int rowStart, int colS #endif // defined(PREFETCH_DEFAULT) } -void store_c_rowmajor_fp32_m1_nx(global float* C, float v, int rowStart, int colStart, int stride) +void store_c_rowmajor_fp32_1rNc(global float* C, float v, int rowStart, int colStart, int stride) { global uint* C_ui = (global uint*)C; uint v_ui = as_uint(v); @@ -511,7 +511,7 @@ void store_c_rowmajor_fp32_m1_nx(global float* C, float v, int rowStart, int col intel_sub_group_block_write(C_ui + offset, v_ui); offset += stride; } -void store_c_rowmajor_fp32_m2_nx(global float* C, float2 v, int rowStart, int colStart, int stride) +void store_c_rowmajor_fp32_2rNc(global float* C, float2 v, int rowStart, int colStart, int stride) { global uint* C_ui = (global uint*)C; uint2 v_ui = as_uint2(v); @@ -522,7 +522,7 @@ void store_c_rowmajor_fp32_m2_nx(global float* C, float2 v, int rowStart, int co intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride; } -void store_c_rowmajor_fp32_m4_nx(global float* C, float4 v, int rowStart, int colStart, int stride) +void store_c_rowmajor_fp32_4rNc(global float* C, float4 v, int rowStart, int colStart, int stride) { global uint* C_ui = (global uint*)C; uint4 v_ui = as_uint4(v); @@ -535,7 +535,7 @@ void store_c_rowmajor_fp32_m4_nx(global float* C, float4 v, int rowStart, int co intel_sub_group_block_write(C_ui + offset, v_ui.s3); offset += stride; } -void store_c_rowmajor_fp32_m8_nx(global float* C, float8 v, 
int rowStart, int colStart, int stride) +void store_c_rowmajor_fp32_8rNc(global float* C, float8 v, int rowStart, int colStart, int stride) { global uint* C_ui = (global uint*)C; uint8 v_ui = as_uint8(v); @@ -564,24 +564,6 @@ void store_c_rowmajor_fp32_m8_nx(global float* C, float8 v, int rowStart, int co // - pitch is the number of bytes between rows of the entire matrix. Must be >= 64B. Must be a multiple of 8 bytes. // - coord is the number of elements (x coord) and row (y coord) to read from. X coord must be multiple 4 for for 1B data and 2 for 2B data. -// Built-in functions are: - -// #ifdef cl_intel_subgroup_extended_block_read -// ushort2 intel_subgroup_block_read_u8_m1k32v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// ushort4 intel_subgroup_block_read_u8_m2k32v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// ushort8 intel_subgroup_block_read_u8_m4k32v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// ushort16 intel_subgroup_block_read_u8_m8k32v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// ushort2 intel_subgroup_block_read_u16_m1k16v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// ushort4 intel_subgroup_block_read_u16_m2k16v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// ushort8 intel_subgroup_block_read_u16_m4k16v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// ushort16 intel_subgroup_block_read_u16_m8k16v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// uint8 intel_subgroup_block_read_transform_u8_k32(__global void *base_address, int width, int height, int pitch, int2 coord); -// uint8 intel_subgroup_block_read_transform_u16_k16(__global void *base_address, int width, int height, int pitch, int2 coord); -// uint8 intel_subgroup_block_read_transpose_u32_k8(__global void *base_address, int 
width, int height, int pitch, int2 coord); -// ulong4 intel_subgroup_block_read_transpose_u64_k4(__global void *base_address, int width, int height, int pitch, int2 coord); -// #endif //defined(cl_intel_subgroup_extended_block_read) - - // For intrinsics, the pattern is: // - prefix: __builtin_IB_subgroup_block_read_flat or __builtin_IB_subgroup_block_write_flat // - operation (optional): _transpose or _transform @@ -626,12 +608,17 @@ ushort8 __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(long baseoffset, int ushort16 __builtin_IB_subgroup_block_read_flat_u16_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); ushort32 __builtin_IB_subgroup_block_read_flat_u16_m32k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort2 __builtin_IB_subgroup_block_read_flat_u16_m1k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort4 __builtin_IB_subgroup_block_read_flat_u16_m2k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort8 __builtin_IB_subgroup_block_read_flat_u16_m4k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +ushort16 __builtin_IB_subgroup_block_read_flat_u16_m8k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); ushort32 __builtin_IB_subgroup_block_read_flat_u16_m16k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); ushort64 __builtin_IB_subgroup_block_read_flat_u16_m32k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); uint16 __builtin_IB_subgroup_block_read_flat_u32_m16k16v1(long baseoffset, int width_minus_one, int 
height_minus_one, int pitch_minus_one, int2 coord); +uint8 __builtin_IB_subgroup_block_read_flat_transform_u16_k16(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); uint16 __builtin_IB_subgroup_block_read_flat_transform_u16_k32(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); uint16 __builtin_IB_subgroup_block_read_flat_transform_u16_k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); @@ -659,27 +646,27 @@ void __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(long baseoffset, int wid void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint8 data); void __builtin_IB_subgroup_block_write_flat_u32_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint16 data); -ushort intel_subgroup_block_read_u16_m1k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +ushort intel_sub_group_block_read_16b_1r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) { return __builtin_IB_subgroup_block_read_flat_u16_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } -ushort2 intel_subgroup_block_read_u16_m2k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +ushort2 intel_sub_group_block_read_16b_2r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) { return __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } -ushort4 intel_subgroup_block_read_u16_m4k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +ushort4 intel_sub_group_block_read_16b_4r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) { return 
__builtin_IB_subgroup_block_read_flat_u16_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } -ushort8 intel_subgroup_block_read_u16_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +ushort8 intel_sub_group_block_read_16b_8r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) { return __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } -ushort16 intel_subgroup_block_read_u16_m16k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +ushort16 intel_sub_group_block_read_16b_16r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) { return __builtin_IB_subgroup_block_read_flat_u16_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } -void intel_subgroup_block_read_u16_m32k16(const __global void *base_address, int width, int height, int pitch, int2 coord, ushort8 dst[4]) +void intel_sub_group_block_read_16b_32r16c(const __global void *base_address, int width, int height, int pitch, int2 coord, ushort8 dst[4]) { ushort32 tmp = __builtin_IB_subgroup_block_read_flat_u16_m32k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); dst[0] = tmp.lo.lo; @@ -688,7 +675,24 @@ void intel_subgroup_block_read_u16_m32k16(const __global void *base_address, int dst[3] = tmp.hi.hi; } -void intel_subgroup_block_read_u16_m16k16v2(const __global void *base_address, int width, int height, int pitch, int2 coord, ushort8 dst[2][2]) +ushort2 intel_sub_group_block_read_16b_1r16x2c(__global void *base_address, int width, int height, int pitch, int2 coord) +{ + return __builtin_IB_subgroup_block_read_flat_u16_m1k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord); +} +ushort4 intel_sub_group_block_read_16b_2r16x2c(__global void *base_address, int width, int height, int pitch, int2 coord) +{ + return 
__builtin_IB_subgroup_block_read_flat_u16_m2k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord); +} +ushort8 intel_sub_group_block_read_16b_4r16x2c(__global void *base_address, int width, int height, int pitch, int2 coord) +{ + return __builtin_IB_subgroup_block_read_flat_u16_m4k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord); +} +ushort16 intel_sub_group_block_read_16b_8r16x2c(__global void *base_address, int width, int height, int pitch, int2 coord) +{ + return __builtin_IB_subgroup_block_read_flat_u16_m8k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord); +} + +void intel_sub_group_block_read_16b_16r16x2c(const __global void *base_address, int width, int height, int pitch, int2 coord, ushort8 dst[2][2]) { ushort32 tmp = __builtin_IB_subgroup_block_read_flat_u16_m16k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord); dst[0][0] = tmp.lo.lo; @@ -696,7 +700,7 @@ void intel_subgroup_block_read_u16_m16k16v2(const __global void *base_address, i dst[1][0] = tmp.hi.lo; dst[1][1] = tmp.hi.hi; } -void intel_subgroup_block_read_u16_m32k16v2(const __global void *base_address, int width, int height, int pitch, int2 coord, ushort8 dst[2][4]) +void intel_sub_group_block_read_16b_32r16x2c(const __global void *base_address, int width, int height, int pitch, int2 coord, ushort8 dst[2][4]) { ushort64 tmp = __builtin_IB_subgroup_block_read_flat_u16_m32k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord); dst[0][0] = tmp.lo.lo.lo; @@ -709,30 +713,29 @@ void intel_subgroup_block_read_u16_m32k16v2(const __global void *base_address, i dst[1][3] = tmp.hi.hi.hi; } -uint8 intel_subgroup_block_read_u32_m8k16(const __global void* base_address, int width, int height, int pitch, int2 coord) +uint8 intel_sub_group_block_read_32b_8r16c(const __global void* base_address, int width, int height, int pitch, int2 coord) { return __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(as_long(base_address), width - 
1, height - 1, pitch - 1, coord); } -uint16 intel_subgroup_block_read_u32_m16k16(const __global void* base_address, int width, int height, int pitch, int2 coord) +uint16 intel_sub_group_block_read_32b_16r16c(const __global void* base_address, int width, int height, int pitch, int2 coord) { return __builtin_IB_subgroup_block_read_flat_u32_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } -// Each block is K rows x N columns, where the K rows have been VNNI transformed. -int8 intel_subgroup_block_read_transform_u16_k16n16(__global void *base_address, int width, int height, int pitch, int2 coord) +// Each block is K rows x N columns, where the K rows are returned packed into 32-bits. +int8 intel_sub_group_block_read_transform_16b_16r16c(__global void *base_address, int width, int height, int pitch, int2 coord) { - // Note: this function is in the headers, but is named confusingly and returns unsigned integers rather than signed integers: - return as_int8(intel_subgroup_block_read_transform_u16_k16(base_address, width, height, pitch, coord)); + return as_int8(__builtin_IB_subgroup_block_read_flat_transform_u16_k16(as_long(base_address), width - 1, height - 1, pitch - 1, coord)); } -int16 intel_subgroup_block_read_transform_u16_k32n16(__global void *base_address, int width, int height, int pitch, int2 coord) +int16 intel_sub_group_block_read_transform_16b_32r16c(__global void *base_address, int width, int height, int pitch, int2 coord) { return as_int16(__builtin_IB_subgroup_block_read_flat_transform_u16_k32(as_long(base_address), width - 1, height - 1, pitch - 1, coord)); } -int16 intel_subgroup_block_read_transform_u16_k16n16v2(__global void *base_address, int width, int height, int pitch, int2 coord) +int16 intel_sub_group_block_read_transform_16b_16r16x2c(__global void *base_address, int width, int height, int pitch, int2 coord) { return as_int16(__builtin_IB_subgroup_block_read_flat_transform_u16_k16v2(as_long(base_address), width - 1, 
height - 1, pitch - 1, coord)); } -void intel_subgroup_block_read_transform_u16_k32n16v2(__global void *base_address, int width, int height, int pitch, int2 coord, int8 dst[2][2]) +void intel_sub_group_block_read_transform_16b_32r16x2c(__global void *base_address, int width, int height, int pitch, int2 coord, int8 dst[2][2]) { uint32 tmp = __builtin_IB_subgroup_block_read_flat_transform_u16_k32v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord); dst[0][0] = as_int8(tmp.lo.lo); @@ -742,69 +745,71 @@ void intel_subgroup_block_read_transform_u16_k32n16v2(__global void *base_addres } +#if !defined(BLOCK_PREFETCH_CACHE_TYPE) #define BLOCK_PREFETCH_CACHE_TYPE LSC_LDCC_L1C_L3C +#endif -void intel_subgroup_block_prefetch_u16_m1k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +void intel_sub_group_block_prefetch_16b_1r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) { #if defined(PREFETCH_DEFAULT) __builtin_IB_subgroup_block_read_prefetch_u16_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); #endif // defined(PREFETCH_DEFAULT) } -void intel_subgroup_block_prefetch_u16_m2k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +void intel_sub_group_block_prefetch_16b_2r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) { #if defined(PREFETCH_DEFAULT) __builtin_IB_subgroup_block_read_prefetch_u16_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); #endif // defined(PREFETCH_DEFAULT) } -void intel_subgroup_block_prefetch_u16_m4k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +void intel_sub_group_block_prefetch_16b_4r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) { #if defined(PREFETCH_DEFAULT) 
__builtin_IB_subgroup_block_read_prefetch_u16_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); #endif // defined(PREFETCH_DEFAULT) } -void intel_subgroup_block_prefetch_u16_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +void intel_sub_group_block_prefetch_16b_8r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) { #if defined(PREFETCH_DEFAULT) __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); #endif // defined(PREFETCH_DEFAULT) } -void intel_subgroup_block_prefetch_u16_m8k16v2(__global void *base_address, int width, int height, int pitch, int2 coord) +void intel_sub_group_block_prefetch_16b_8r16x2c(__global void *base_address, int width, int height, int pitch, int2 coord) { #if defined(PREFETCH_DEFAULT) __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); #endif // defined(PREFETCH_DEFAULT) } -void intel_subgroup_block_prefetch_u16_m16k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +void intel_sub_group_block_prefetch_16b_16r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) { #if defined(PREFETCH_DEFAULT) __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); #endif // defined(PREFETCH_DEFAULT) } -void intel_subgroup_block_prefetch_u16_m32k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +void intel_sub_group_block_prefetch_16b_32r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) { #if defined(PREFETCH_DEFAULT) __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, 
BLOCK_PREFETCH_CACHE_TYPE); #endif // defined(PREFETCH_DEFAULT) } -void intel_subgroup_block_prefetch_u16_m16k16v2(const __global void *base_address, int width, int height, int pitch, int2 coord) +void intel_sub_group_block_prefetch_16b_16r16x2c(const __global void *base_address, int width, int height, int pitch, int2 coord) { #if defined(PREFETCH_DEFAULT) __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); #endif // defined(PREFETCH_DEFAULT) } -void intel_subgroup_block_prefetch_u16_m32k16v2(const __global void *base_address, int width, int height, int pitch, int2 coord) +void intel_sub_group_block_prefetch_16b_32r16x2c(const __global void *base_address, int width, int height, int pitch, int2 coord) { #if defined(PREFETCH_DEFAULT) __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); #endif // defined(PREFETCH_DEFAULT) } -void intel_subgroup_block_prefetch_u32_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +void intel_sub_group_block_prefetch_32b_8r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) { #if defined(PREFETCH_DEFAULT) __builtin_IB_subgroup_block_read_prefetch_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); #endif // defined(PREFETCH_DEFAULT) } -void intel_subgroup_block_prefetch_u32_m16k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +void intel_sub_group_block_prefetch_32b_16r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) { #if defined(PREFETCH_DEFAULT) __builtin_IB_subgroup_block_read_prefetch_u32_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); @@ -812,23 +817,23 @@ void 
intel_subgroup_block_prefetch_u32_m16k16(const __global void *base_address, } -void intel_subgroup_block_write_u32_m1k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint data) +void intel_sub_group_block_write_32b_1r16c(__global void* base_address, int width, int height, int pitch, int2 coord, uint data) { __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); } -void intel_subgroup_block_write_u32_m2k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint2 data) +void intel_sub_group_block_write_32b_2r16c(__global void* base_address, int width, int height, int pitch, int2 coord, uint2 data) { __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); } -void intel_subgroup_block_write_u32_m4k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint4 data) +void intel_sub_group_block_write_32b_4r16c(__global void* base_address, int width, int height, int pitch, int2 coord, uint4 data) { __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); } -void intel_subgroup_block_write_u32_m8k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint8 data) +void intel_sub_group_block_write_32b_8r16c(__global void* base_address, int width, int height, int pitch, int2 coord, uint8 data) { __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); } -void intel_subgroup_block_write_u32_m16k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint16 data) +void intel_sub_group_block_write_32b_16r16c(__global void* base_address, int width, int height, int pitch, int2 coord, uint16 data) { __builtin_IB_subgroup_block_write_flat_u32_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, 
coord, data); } diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 75940825..823a03e1 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -47,16 +47,16 @@ void HELPER_NAME(btile_load_rowmajor, MM, NN)(global ushort* B, int tN, int N, i { for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { - bData[nn][kk] = load_b_rowmajor_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); + bData[nn][kk] = load_b_rowmajor_16b_16rNc(B, k + kk * tK, n + nn * tN, N); } } } -void HELPER_NAME(btile_load_vnni, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[NN][KK]) +void HELPER_NAME(btile_load_packed, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[NN][KK]) { for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { - bData[nn][kk] = load_b_vnni_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); + bData[nn][kk] = load_b_packed_16b_16rNc(B, k + kk * tK, n + nn * tN, N); } } } @@ -67,7 +67,7 @@ void HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(global ushort* A, int tM, { for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm++) { - prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K); + prefetch_a_rowmajor_16b_8r16x2c_sg8(A, m + mm * tM, prefetch_k + kk * tK, K); } } } @@ -76,16 +76,16 @@ void HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(global ushort* B, int tN, { for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn+=4) { - prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, prefetch_k + kk * tK, n + nn * tN, N); + prefetch_b_rowmajor_16b_16r8x4c_sg8(B, prefetch_k + kk * tK, n + nn * tN, N); } } } -void HELPER_NAME(btile_prefetch_vnni_sg8, MM, NN)(global ushort* B, int tN, int N, int prefetch_k, int n) +void HELPER_NAME(btile_prefetch_packed_sg8, MM, NN)(global ushort* B, int tN, int N, int prefetch_k, int n) { for (int kk = 0; kk < KK; kk++) { for (int nn 
= 0; nn < NN; nn+=2) { - prefetch_b_vnni_d16_k16_n8v2_sg8(B, prefetch_k + kk * tK, n + nn * tN, N); + prefetch_b_packed_16b_16r8x2c_sg8(B, prefetch_k + kk * tK, n + nn * tN, N); } } } @@ -95,7 +95,7 @@ void HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(global ushort* A, int tM, int if (KK % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm++) { - int16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, k + kk * tK, K); + int16 aTemp = load_a_rowmajor_16b_8r16x2c_sg8(A, m + mm * tM, k + kk * tK, K); aData[kk + 0][mm] = aTemp.lo; aData[kk + 1][mm] = aTemp.hi; } @@ -103,7 +103,7 @@ void HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(global ushort* A, int tM, int } else { for (int kk = 0; kk < KK; kk++) { for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K); + aData[kk][mm] = load_a_rowmajor_16b_8r16c_sg8(A, m + mm * tM, k + kk * tK, K); } } } @@ -166,7 +166,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global fl for (int nn = 0; nn < NN; nn++) { for (int mm = 0; mm < MM; mm++) { sum[nn][mm] = activation(sum[nn][mm]); - store_c_rowmajor_fp32_m8_nx(C, sum[nn][mm], m + mm * tM, n + nn * tN, N); + store_c_rowmajor_fp32_8rNc(C, sum[nn][mm], m + mm * tM, n + nn * tN, N); } } } @@ -185,7 +185,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* int prefetch_k = 0; for (int p = 0; p < PREFETCH_DISTANCE; p++) { HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(A, tM, K, m, prefetch_k); - HELPER_NAME(btile_prefetch_vnni_sg8, MM, NN)(B, tN, N, prefetch_k, n); + HELPER_NAME(btile_prefetch_packed_sg8, MM, NN)(B, tN, N, prefetch_k, n); prefetch_k += tK * KK; } @@ -209,7 +209,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(A, tM, K, m, k, aData); int8 bData[NN][KK]; - HELPER_NAME(btile_load_vnni, MM, NN)(B, tN, N, k, n, bData); + 
HELPER_NAME(btile_load_packed, MM, NN)(B, tN, N, k, n, bData); for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { @@ -228,7 +228,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { sum[nn][mm] = activation(sum[nn][mm]); - store_c_rowmajor_fp32_m8_nx(C, sum[nn][mm], m + mm * tM, n + nn * tN, N); + store_c_rowmajor_fp32_8rNc(C, sum[nn][mm], m + mm * tM, n + nn * tN, N); } } } @@ -239,7 +239,7 @@ void HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(global ushort* A, int tM, int { for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm+=2) { - prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K); + prefetch_a_rowmajor_16b_8x2r16x2c_sg16(A, m + mm * tM, prefetch_k + kk * tK, K); } } } @@ -248,16 +248,16 @@ void HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(global ushort* B, int tN, int { for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); + prefetch_b_rowmajor_16b_16r16x2c_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); } } } -void HELPER_NAME(btile_prefetch_vnni, MM, NN)(global ushort* B, int tN, int N, int prefetch_k, int n) +void HELPER_NAME(btile_prefetch_packed, MM, NN)(global ushort* B, int tN, int N, int prefetch_k, int n) { for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn++) { - prefetch_b_vnni_d16_k16v2_n16_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); + prefetch_b_packed_16b_16x2r16c_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); } } } @@ -267,7 +267,7 @@ void HELPER_NAME(atile_load_rowmajor, MM, NN)(global ushort* A, int tM, int K, i if (KK % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm++) { - short16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg16(A, m + mm * tM, k + kk * tK, K); + short16 aTemp = load_a_rowmajor_16b_8r16x2c_sg16(A, m + mm * tM, k + kk * tK, K); aData[kk + 
0][mm] = aTemp.lo; aData[kk + 1][mm] = aTemp.hi; } @@ -275,7 +275,7 @@ void HELPER_NAME(atile_load_rowmajor, MM, NN)(global ushort* A, int tM, int K, i } else { for (int kk = 0; kk < KK; kk++) { for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k + kk * tK, K); + aData[kk][mm] = load_a_rowmajor_16b_8r16c_sg16(A, m + mm * tM, k + kk * tK, K); } } } @@ -338,7 +338,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global f for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { sum[nn][mm] = activation(sum[nn][mm]); - store_c_rowmajor_fp32_m8_nx(C, sum[nn][mm], m + mm * tM, n + nn * tN, N); + store_c_rowmajor_fp32_8rNc(C, sum[nn][mm], m + mm * tM, n + nn * tN, N); } } } @@ -357,7 +357,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float int prefetch_k = 0; for (int p = 0; p < PREFETCH_DISTANCE; p++) { HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); - HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n); + HELPER_NAME(btile_prefetch_packed, MM, NN)(B, tN, N, prefetch_k, n); prefetch_k += tK * KK; } @@ -374,14 +374,14 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float // Next prefetch: // TODO: skip prefetch on the last iterations. 
HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); - HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n); + HELPER_NAME(btile_prefetch_packed, MM, NN)(B, tN, N, prefetch_k, n); prefetch_k += tK * KK; short8 aData[KK][MM]; HELPER_NAME(atile_load_rowmajor, MM, NN)(A, tM, K, m, k, aData); int8 bData[NN][KK]; - HELPER_NAME(btile_load_vnni, MM, NN)(B, tN, N, k, n, bData); + HELPER_NAME(btile_load_packed, MM, NN)(B, tN, N, k, n, bData); for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { @@ -400,7 +400,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { sum[nn][mm] = activation(sum[nn][mm]); - store_c_rowmajor_fp32_m8_nx(C, sum[nn][mm], m + mm * tM, n + nn * tN, N); + store_c_rowmajor_fp32_8rNc(C, sum[nn][mm], m + mm * tM, n + nn * tN, N); } } } @@ -416,7 +416,7 @@ void HELPER_NAME(atile_block_load_rowmajor, MM, NN)(global ushort* A, int tM, in // printf("atile block load : %d, %d, %2d: m = %3d, k = %3d, mm = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), m, k, mm, kk, k + kk * tK, m + mm * tM); //} ushort8 tmp[2][4]; - intel_subgroup_block_read_u16_m32k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); + intel_sub_group_block_read_16b_32r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); for (int tkk = 0; tkk < 2; tkk++) { for (int tmm = 0; tmm < 4; tmm++) { aData[kk + tkk][mm + tmm] = as_short8(tmp[tkk][tmm]); @@ -428,7 +428,7 @@ void HELPER_NAME(atile_block_load_rowmajor, MM, NN)(global ushort* A, int tM, in for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm+=2) { ushort8 tmp[2][2]; - intel_subgroup_block_read_u16_m16k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); + intel_sub_group_block_read_16b_16r16x2c(A, K * sizeof(ushort), M, 
K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); for (int tkk = 0; tkk < 2; tkk++) { for (int tmm = 0; tmm < 2; tmm++) { aData[kk + tkk][mm + tmm] = as_short8(tmp[tkk][tmm]); @@ -439,7 +439,7 @@ void HELPER_NAME(atile_block_load_rowmajor, MM, NN)(global ushort* A, int tM, in } else if (KK % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm++) { - short16 aTemp = as_short16(intel_subgroup_block_read_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); + short16 aTemp = as_short16(intel_sub_group_block_read_16b_8r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); aData[kk + 0][mm] = aTemp.lo; aData[kk + 1][mm] = aTemp.hi; } @@ -448,7 +448,7 @@ void HELPER_NAME(atile_block_load_rowmajor, MM, NN)(global ushort* A, int tM, in for (int kk = 0; kk < KK; kk++) { for (int mm = 0; mm < MM; mm+=4) { ushort8 tmp[4]; - intel_subgroup_block_read_u16_m32k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); + intel_sub_group_block_read_16b_32r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); for (int tmm = 0; tmm < 4; tmm++) { aData[kk][mm + tmm] = as_short8(tmp[tmm]); } @@ -457,7 +457,7 @@ void HELPER_NAME(atile_block_load_rowmajor, MM, NN)(global ushort* A, int tM, in } else { for (int kk = 0; kk < KK; kk++) { for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); + aData[kk][mm] = as_short8(intel_sub_group_block_read_16b_8r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); } } } @@ -472,7 +472,7 @@ void HELPER_NAME(btile_block_load_rowmajor, MM, NN)(global ushort* B, int tN, in // printf("btile block load: %d, %d, %2d: n = %3d, k = %3d, nn = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), 
get_sub_group_id(), n, k, nn, kk, n + nn * tN, k + kk * tK); //} int8 tmp[2][2]; - intel_subgroup_block_read_transform_u16_k32n16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK), tmp); + intel_sub_group_block_read_transform_16b_32r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK), tmp); for (int tnn = 0; tnn < 2; tnn++) { for (int tkk = 0; tkk < 2; tkk++) { bData[nn + tnn][kk + tkk] = tmp[tnn][tkk]; @@ -483,7 +483,7 @@ void HELPER_NAME(btile_block_load_rowmajor, MM, NN)(global ushort* B, int tN, in } else if (NN % 2 == 0) { for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn+=2) { - int16 bTemp = intel_subgroup_block_read_transform_u16_k16n16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + int16 bTemp = intel_sub_group_block_read_transform_16b_16r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); bData[nn + 0][kk] = bTemp.lo; bData[nn + 1][kk] = bTemp.hi; } @@ -491,7 +491,7 @@ void HELPER_NAME(btile_block_load_rowmajor, MM, NN)(global ushort* B, int tN, in } else if (KK % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn++) { - int16 bTemp = intel_subgroup_block_read_transform_u16_k32n16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + int16 bTemp = intel_sub_group_block_read_transform_16b_32r16c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); bData[nn][kk + 0] = bTemp.lo; bData[nn][kk + 1] = bTemp.hi; } @@ -499,18 +499,18 @@ void HELPER_NAME(btile_block_load_rowmajor, MM, NN)(global ushort* B, int tN, in } else { for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { - bData[nn][kk] = intel_subgroup_block_read_transform_u16_k16n16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + bData[nn][kk] = intel_sub_group_block_read_transform_16b_16r16c(B, N * sizeof(ushort), K, N * 
sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); } } } } -void HELPER_NAME(btile_block_load_vnni, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n, int8 bData[NN][KK]) +void HELPER_NAME(btile_block_load_packed, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n, int8 bData[NN][KK]) { if (KK % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn++) { - int16 bTemp = as_int16(intel_subgroup_block_read_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2))); + int16 bTemp = as_int16(intel_sub_group_block_read_32b_16r16c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2))); bData[nn][kk + 0] = bTemp.lo; bData[nn][kk + 1] = bTemp.hi; } @@ -518,7 +518,7 @@ void HELPER_NAME(btile_block_load_vnni, MM, NN)(global ushort* B, int tN, int K, } else { for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { - bData[nn][kk] = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2))); + bData[nn][kk] = as_int8(intel_sub_group_block_read_32b_8r16c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2))); } } } @@ -533,35 +533,35 @@ void HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(global ushort* A, int tM //if (get_sub_group_local_id() == 0) { // printf("atile block prefetch: %d, %d, %2d: sg_x = %d, m = %3d, k = %3d, mm = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), sg_index_x, m, k, mm, kk, k + kk * tK, m + mm * tM); //} - intel_subgroup_block_prefetch_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + intel_sub_group_block_prefetch_16b_8r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); } else if (KK % 2 == 0 & MM % 4 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm+=4) { - 
intel_subgroup_block_prefetch_u16_m32k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + intel_sub_group_block_prefetch_16b_32r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); } } } else if (KK % 2 == 0 & MM % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm+=2) { - intel_subgroup_block_prefetch_u16_m16k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + intel_sub_group_block_prefetch_16b_16r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); } } } else if (KK % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm++) { - intel_subgroup_block_prefetch_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + intel_sub_group_block_prefetch_16b_8r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); } } } else if (MM % 4 == 0) { for (int kk = 0; kk < KK; kk++) { for (int mm = 0; mm < MM; mm+=4) { - intel_subgroup_block_prefetch_u16_m32k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + intel_sub_group_block_prefetch_16b_32r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); } } } else { for (int kk = 0; kk < KK; kk++) { for (int mm = 0; mm < MM; mm++) { - intel_subgroup_block_prefetch_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + intel_sub_group_block_prefetch_16b_8r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); } } } @@ -576,51 +576,51 @@ void HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(global ushort* B, int tN //if (get_sub_group_local_id() == 0) { // printf("btile block prefetch: %d, %d, %2d: sg_y = %d, n = %3d, k = %3d, nn = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), sg_index_y, n, k, nn, kk, n 
+ nn * tN, k + kk * tK); //} - intel_subgroup_block_prefetch_u16_m16k16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + intel_sub_group_block_prefetch_16b_16r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); } else if (KK % 2 == 0 & NN % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn += 2) { - intel_subgroup_block_prefetch_u16_m32k16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + intel_sub_group_block_prefetch_16b_32r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); } } } else if (NN % 2 == 0) { for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn+=2) { - intel_subgroup_block_prefetch_u16_m16k16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + intel_sub_group_block_prefetch_16b_16r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); } } } else if (KK % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn++) { - intel_subgroup_block_prefetch_u16_m32k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + intel_sub_group_block_prefetch_16b_32r16c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); } } } else { for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { - intel_subgroup_block_prefetch_u16_m16k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + intel_sub_group_block_prefetch_16b_16r16c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); } } } } -void HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n) +void HELPER_NAME(btile_block_prefetch_packed, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n) { if (KK == 2 & NN == 4 & SGS_PER_WG_Y >= 4) { const int sg_index_y = get_sub_group_id() / 
SGS_PER_WG_X; // index in [0, SGS_PER_WG_Y) const int nn = sg_index_y % 4; // nn(sg_index_y) == 0, 1, 2, 3, 0, 1, 2, 3 const int kk = 0; // kk(sg_index_y) == 0, 0, 0, 0, 0, 0, 0, 0 - intel_subgroup_block_prefetch_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); + intel_sub_group_block_prefetch_32b_16r16c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); } else if (KK % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn++) { - intel_subgroup_block_prefetch_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); + intel_sub_group_block_prefetch_32b_16r16c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); } } } else { for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { - intel_subgroup_block_prefetch_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); + intel_sub_group_block_prefetch_32b_8r16c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); } } } @@ -681,7 +681,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { sum[nn][mm] = activation(sum[nn][mm]); - intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[nn][mm])); + intel_sub_group_block_write_32b_8r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[nn][mm])); } } } @@ -699,7 +699,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl int prefetch_k = 0; for (int p = 0; p < PREFETCH_DISTANCE; p++) { - HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(B, tN, K, N, prefetch_k, n); + HELPER_NAME(btile_block_prefetch_packed, MM, NN)(B, tN, K, N, prefetch_k, n); HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, 
M, K, m, prefetch_k); prefetch_k += tK * KK; } @@ -715,13 +715,13 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl for (int k = 0; k < K; k += tK * KK) { int8 bData[NN][KK]; - HELPER_NAME(btile_block_load_vnni, MM, NN)(B, tN, K, N, k, n, bData); + HELPER_NAME(btile_block_load_packed, MM, NN)(B, tN, K, N, k, n, bData); short8 aData[KK][MM]; HELPER_NAME(atile_block_load_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); // TODO: skip prefetch on the last iterations. - HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(B, tN, K, N, prefetch_k, n); + HELPER_NAME(btile_block_prefetch_packed, MM, NN)(B, tN, K, N, prefetch_k, n); HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); prefetch_k += tK * KK; @@ -742,7 +742,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { sum[nn][mm] = activation(sum[nn][mm]); - intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[nn][mm])); + intel_sub_group_block_write_32b_8r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[nn][mm])); } } } diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl index f2254553..7632633d 100644 --- a/samples/99_matrixexperiments/matrix_kernels.cl +++ b/samples/99_matrixexperiments/matrix_kernels.cl @@ -48,13 +48,13 @@ kernel void bfloat16_dpas_rowmajor_m1_n8(global float* C, global ushort* A, glob float sum = 0; for (int k = 0; k < K; k += tK) { - int aData = load_a_rowmajor_d16_m1_k16_sg8(A, m, k, K); - int8 bData = load_b_rowmajor_d16_k16_nx(B, k, n, N); + int aData = load_a_rowmajor_16b_1r16c_sg8(A, m, k, K); + int8 bData = load_b_rowmajor_16b_16rNc(B, k, n, N); sum = mat_mul_sg8(aData, bData, sum); } sum = activation(sum); - store_c_rowmajor_fp32_m1_nx(C, sum, m, n, N); + 
store_c_rowmajor_fp32_1rNc(C, sum, m, n, N); } __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) @@ -69,13 +69,13 @@ kernel void bfloat16_dpas_rowmajor_m2_n8(global float* C, global ushort* A, glob float2 sum = 0; for (int k = 0; k < K; k += tK) { - int2 aData = load_a_rowmajor_d16_m2_k16_sg8(A, m, k, K); - int8 bData = load_b_rowmajor_d16_k16_nx(B, k, n, N); + int2 aData = load_a_rowmajor_16b_2r16c_sg8(A, m, k, K); + int8 bData = load_b_rowmajor_16b_16rNc(B, k, n, N); sum = mat_mul_sg8(aData, bData, sum); } sum = activation(sum); - store_c_rowmajor_fp32_m2_nx(C, sum, m, n, N); + store_c_rowmajor_fp32_2rNc(C, sum, m, n, N); } __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) @@ -90,13 +90,13 @@ kernel void bfloat16_dpas_rowmajor_m4_n8(global float* C, global ushort* A, glob float4 sum = 0; for (int k = 0; k < K; k += tK) { - int4 aData = load_a_rowmajor_d16_m4_k16_sg8(A, m, k, K); - int8 bData = load_b_rowmajor_d16_k16_nx(B, k, n, N); + int4 aData = load_a_rowmajor_16b_4r16c_sg8(A, m, k, K); + int8 bData = load_b_rowmajor_16b_16rNc(B, k, n, N); sum = mat_mul_sg8(aData, bData, sum); } sum = activation(sum); - store_c_rowmajor_fp32_m4_nx(C, sum, m, n, N); + store_c_rowmajor_fp32_4rNc(C, sum, m, n, N); } __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) @@ -111,16 +111,16 @@ kernel void bfloat16_dpas_rowmajor_m8_n8(global float* C, global ushort* A, glob float8 sum = 0; for (int k = 0; k < K; k += tK) { - int8 aData = load_a_rowmajor_d16_m8_k16_sg8(A, m, k, K); - int8 bData = load_b_rowmajor_d16_k16_nx(B, k, n, N); + int8 aData = load_a_rowmajor_16b_8r16c_sg8(A, m, k, K); + int8 bData = load_b_rowmajor_16b_16rNc(B, k, n, N); sum = mat_mul_sg8(aData, bData, sum); } sum = activation(sum); - store_c_rowmajor_fp32_m8_nx(C, sum, m, n, N); + store_c_rowmajor_fp32_8rNc(C, sum, m, n, N); } -// vnni kernels: +// pre-packed kernels: 
__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) kernel void bfloat16_dpas_vnni_m1_n8(global float* C, global ushort* A, global ushort* B, int K) @@ -134,13 +134,13 @@ kernel void bfloat16_dpas_vnni_m1_n8(global float* C, global ushort* A, global u float sum = 0; for (int k = 0; k < K; k += tK) { - int aData = load_a_rowmajor_d16_m1_k16_sg8(A, m, k, K); - int8 bData = load_b_vnni_d16_k16_nx(B, k, n, N); + int aData = load_a_rowmajor_16b_1r16c_sg8(A, m, k, K); + int8 bData = load_b_packed_16b_16rNc(B, k, n, N); sum = mat_mul_sg8(aData, bData, sum); } sum = activation(sum); - store_c_rowmajor_fp32_m1_nx(C, sum, m, n, N); + store_c_rowmajor_fp32_1rNc(C, sum, m, n, N); } __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) @@ -155,13 +155,13 @@ kernel void bfloat16_dpas_vnni_m2_n8(global float* C, global ushort* A, global u float2 sum = 0; for (int k = 0; k < K; k += tK) { - int2 aData = load_a_rowmajor_d16_m2_k16_sg8(A, m, k, K); - int8 bData = load_b_vnni_d16_k16_nx(B, k, n, N); + int2 aData = load_a_rowmajor_16b_2r16c_sg8(A, m, k, K); + int8 bData = load_b_packed_16b_16rNc(B, k, n, N); sum = mat_mul_sg8(aData, bData, sum); } sum = activation(sum); - store_c_rowmajor_fp32_m2_nx(C, sum, m, n, N); + store_c_rowmajor_fp32_2rNc(C, sum, m, n, N); } __attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) @@ -176,13 +176,13 @@ kernel void bfloat16_dpas_vnni_m4_n8(global float* C, global ushort* A, global u float4 sum = 0; for (int k = 0; k < K; k += tK) { - int4 aData = load_a_rowmajor_d16_m4_k16_sg8(A, m, k, K); - int8 bData = load_b_vnni_d16_k16_nx(B, k, n, N); + int4 aData = load_a_rowmajor_16b_4r16c_sg8(A, m, k, K); + int8 bData = load_b_packed_16b_16rNc(B, k, n, N); sum = mat_mul_sg8(aData, bData, sum); } sum = activation(sum); - store_c_rowmajor_fp32_m4_nx(C, sum, m, n, N); + store_c_rowmajor_fp32_4rNc(C, sum, m, n, N); } 
__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8, 1, 1))) @@ -197,13 +197,13 @@ kernel void bfloat16_dpas_vnni_m8_n8(global float* C, global ushort* A, global u float8 sum = 0; for (int k = 0; k < K; k += tK) { - int8 aData = load_a_rowmajor_d16_m8_k16_sg8(A, m, k, K); - int8 bData = load_b_vnni_d16_k16_nx(B, k, n, N); + int8 aData = load_a_rowmajor_16b_8r16c_sg8(A, m, k, K); + int8 bData = load_b_packed_16b_16rNc(B, k, n, N); sum = mat_mul_sg8(aData, bData, sum); } sum = activation(sum); - store_c_rowmajor_fp32_m8_nx(C, sum, m, n, N); + store_c_rowmajor_fp32_8rNc(C, sum, m, n, N); } #endif // HAS_SIMD8 @@ -222,13 +222,13 @@ kernel void bfloat16_dpas_rowmajor_m1_n16(global float* C, global ushort* A, glo float sum = 0; for (int k = 0; k < K; k += tK) { - short aData = load_a_rowmajor_d16_m1_k16_sg16(A, m, k, K); - int8 bData = load_b_rowmajor_d16_k16_nx(B, k, n, N); + short aData = load_a_rowmajor_16b_1r16c_sg16(A, m, k, K); + int8 bData = load_b_rowmajor_16b_16rNc(B, k, n, N); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - store_c_rowmajor_fp32_m1_nx(C, sum, m, n, N); + store_c_rowmajor_fp32_1rNc(C, sum, m, n, N); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -243,13 +243,13 @@ kernel void bfloat16_dpas_rowmajor_m2_n16(global float* C, global ushort* A, glo float2 sum = 0; for (int k = 0; k < K; k += tK) { - short2 aData = load_a_rowmajor_d16_m2_k16_sg16(A, m, k, K); - int8 bData = load_b_rowmajor_d16_k16_nx(B, k, n, N); + short2 aData = load_a_rowmajor_16b_2r16c_sg16(A, m, k, K); + int8 bData = load_b_rowmajor_16b_16rNc(B, k, n, N); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - store_c_rowmajor_fp32_m2_nx(C, sum, m, n, N); + store_c_rowmajor_fp32_2rNc(C, sum, m, n, N); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -264,13 +264,13 @@ kernel void bfloat16_dpas_rowmajor_m4_n16(global 
float* C, global ushort* A, glo float4 sum = 0; for (int k = 0; k < K; k += tK) { - short4 aData = load_a_rowmajor_d16_m4_k16_sg16(A, m, k, K); - int8 bData = load_b_rowmajor_d16_k16_nx(B, k, n, N); + short4 aData = load_a_rowmajor_16b_4r16c_sg16(A, m, k, K); + int8 bData = load_b_rowmajor_16b_16rNc(B, k, n, N); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - store_c_rowmajor_fp32_m4_nx(C, sum, m, n, N); + store_c_rowmajor_fp32_4rNc(C, sum, m, n, N); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -285,16 +285,16 @@ kernel void bfloat16_dpas_rowmajor_m8_n16(global float* C, global ushort* A, glo float8 sum = 0; for (int k = 0; k < K; k += tK) { - short8 aData = load_a_rowmajor_d16_m8_k16_sg16(A, m, k, K); - int8 bData = load_b_rowmajor_d16_k16_nx(B, k, n, N); + short8 aData = load_a_rowmajor_16b_8r16c_sg16(A, m, k, K); + int8 bData = load_b_rowmajor_16b_16rNc(B, k, n, N); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - store_c_rowmajor_fp32_m8_nx(C, sum, m, n, N); + store_c_rowmajor_fp32_8rNc(C, sum, m, n, N); } -// vnni kernels: +// pre-packed kernels: __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_vnni_m1_n16(global float* C, global ushort* A, global ushort* B, int K) @@ -308,13 +308,13 @@ kernel void bfloat16_dpas_vnni_m1_n16(global float* C, global ushort* A, global float sum = 0; for (int k = 0; k < K; k += tK) { - short aData = load_a_rowmajor_d16_m1_k16_sg16(A, m, k, K); - int8 bData = load_b_vnni_d16_k16_nx(B, k, n, N); + short aData = load_a_rowmajor_16b_1r16c_sg16(A, m, k, K); + int8 bData = load_b_packed_16b_16rNc(B, k, n, N); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - store_c_rowmajor_fp32_m1_nx(C, sum, m, n, N); + store_c_rowmajor_fp32_1rNc(C, sum, m, n, N); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -329,13 
+329,13 @@ kernel void bfloat16_dpas_vnni_m2_n16(global float* C, global ushort* A, global float2 sum = 0; for (int k = 0; k < K; k += tK) { - short2 aData = load_a_rowmajor_d16_m2_k16_sg16(A, m, k, K); - int8 bData = load_b_vnni_d16_k16_nx(B, k, n, N); + short2 aData = load_a_rowmajor_16b_2r16c_sg16(A, m, k, K); + int8 bData = load_b_packed_16b_16rNc(B, k, n, N); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - store_c_rowmajor_fp32_m2_nx(C, sum, m, n, N); + store_c_rowmajor_fp32_2rNc(C, sum, m, n, N); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -350,13 +350,13 @@ kernel void bfloat16_dpas_vnni_m4_n16(global float* C, global ushort* A, global float4 sum = 0; for (int k = 0; k < K; k += tK) { - short4 aData = load_a_rowmajor_d16_m4_k16_sg16(A, m, k, K); - int8 bData = load_b_vnni_d16_k16_nx(B, k, n, N); + short4 aData = load_a_rowmajor_16b_4r16c_sg16(A, m, k, K); + int8 bData = load_b_packed_16b_16rNc(B, k, n, N); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - store_c_rowmajor_fp32_m4_nx(C, sum, m, n, N); + store_c_rowmajor_fp32_4rNc(C, sum, m, n, N); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -371,13 +371,13 @@ kernel void bfloat16_dpas_vnni_m8_n16(global float* C, global ushort* A, global float8 sum = 0; for (int k = 0; k < K; k += tK) { - short8 aData = load_a_rowmajor_d16_m8_k16_sg16(A, m, k, K); - int8 bData = load_b_vnni_d16_k16_nx(B, k, n, N); + short8 aData = load_a_rowmajor_16b_8r16c_sg16(A, m, k, K); + int8 bData = load_b_packed_16b_16rNc(B, k, n, N); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - store_c_rowmajor_fp32_m8_nx(C, sum, m, n, N); + store_c_rowmajor_fp32_8rNc(C, sum, m, n, N); } #ifdef cl_intel_subgroup_extended_block_read @@ -395,13 +395,13 @@ kernel void bfloat16_dpas_blockread_rowmajor_m1_n16(global float* C, global usho float sum = 0; for (int k = 0; k < K; k += tK) { - 
short aData = as_short(intel_subgroup_block_read_u16_m1k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + short aData = as_short(intel_sub_group_block_read_16b_1r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_subgroup_block_write_u32_m1k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum)); + intel_sub_group_block_write_32b_1r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum)); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -417,13 +417,13 @@ kernel void bfloat16_dpas_blockread_rowmajor_m2_n16(global float* C, global usho float2 sum = 0; for (int k = 0; k < K; k += tK) { - short2 aData = as_short2(intel_subgroup_block_read_u16_m2k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + short2 aData = as_short2(intel_sub_group_block_read_16b_2r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_subgroup_block_write_u32_m2k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum)); + intel_sub_group_block_write_32b_2r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum)); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -439,13 +439,13 @@ kernel void bfloat16_dpas_blockread_rowmajor_m4_n16(global float* C, global usho float4 sum = 0; for (int k = 0; k < K; k += tK) { - short4 aData = as_short4(intel_subgroup_block_read_u16_m4k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + short4 aData = 
as_short4(intel_sub_group_block_read_16b_4r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_subgroup_block_write_u32_m4k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum)); + intel_sub_group_block_write_32b_4r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum)); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -461,13 +461,13 @@ kernel void bfloat16_dpas_blockread_rowmajor_m8_n16(global float* C, global usho float8 sum = 0; for (int k = 0; k < K; k += tK) { - short8 aData = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + short8 aData = as_short8(intel_sub_group_block_read_16b_8r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); + intel_sub_group_block_write_32b_8r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -483,13 +483,13 @@ kernel void bfloat16_dpas_blockread_vnni_m1_n16(global float* C, global ushort* float sum = 0; for (int k = 0; k < K; k += tK) { - short aData = as_short(intel_subgroup_block_read_u16_m1k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); - int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); + short aData = 
as_short(intel_sub_group_block_read_16b_1r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_sub_group_block_read_32b_8r16c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_subgroup_block_write_u32_m1k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum)); + intel_sub_group_block_write_32b_1r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum)); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -505,13 +505,13 @@ kernel void bfloat16_dpas_blockread_vnni_m2_n16(global float* C, global ushort* float2 sum = 0; for (int k = 0; k < K; k += tK) { - short2 aData = as_short2(intel_subgroup_block_read_u16_m2k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); - int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); + short2 aData = as_short2(intel_sub_group_block_read_16b_2r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_sub_group_block_read_32b_8r16c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_subgroup_block_write_u32_m2k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum)); + intel_sub_group_block_write_32b_2r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum)); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -527,13 +527,13 @@ kernel void bfloat16_dpas_blockread_vnni_m4_n16(global float* C, global ushort* float4 sum = 0; for (int k = 0; k < K; k += tK) { - short4 aData = as_short4(intel_subgroup_block_read_u16_m4k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); - int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * 
sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); + short4 aData = as_short4(intel_sub_group_block_read_16b_4r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_sub_group_block_read_32b_8r16c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_subgroup_block_write_u32_m4k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum)); + intel_sub_group_block_write_32b_4r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum)); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -549,13 +549,13 @@ kernel void bfloat16_dpas_blockread_vnni_m8_n16(global float* C, global ushort* float8 sum = 0; for (int k = 0; k < K; k += tK) { - short8 aData = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); - int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); + short8 aData = as_short8(intel_sub_group_block_read_16b_8r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); + int8 bData = as_int8(intel_sub_group_block_read_32b_8r16c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); + intel_sub_group_block_write_32b_8r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); } #endif // cl_intel_subgroup_extended_block_read From 72d7f3b51b76a51a407012ba2fcc2cc2a7a1a24c Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Tue, 30 Jul 2024 08:41:02 -0700 Subject: [PATCH 78/99] add 32x1 prefetch variants in addition to 16x2 variants --- .../99_matrixexperiments/matrix_helpers.cl | 21 +++++++++++++++++++ .../matrix_kernel_tiled.cl | 6 ++++-- 2 
files changed, 25 insertions(+), 2 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_helpers.cl b/samples/99_matrixexperiments/matrix_helpers.cl index f1f5563b..ff3010be 100644 --- a/samples/99_matrixexperiments/matrix_helpers.cl +++ b/samples/99_matrixexperiments/matrix_helpers.cl @@ -639,6 +639,9 @@ void __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2(long baseoffset, int void __builtin_IB_subgroup_block_read_prefetch_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); void __builtin_IB_subgroup_block_read_prefetch_u32_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); +void __builtin_IB_subgroup_block_read_prefetch_u16_m8k32v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); +void __builtin_IB_subgroup_block_read_prefetch_u16_m16k32v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); +void __builtin_IB_subgroup_block_read_prefetch_u16_m32k32v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint data); void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint2 data); @@ -815,6 +818,24 @@ void intel_sub_group_block_prefetch_32b_16r16c(const __global void *base_address __builtin_IB_subgroup_block_read_prefetch_u32_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); #endif // defined(PREFETCH_DEFAULT) } +void intel_sub_group_block_prefetch_16b_8r32c(__global void *base_address, 
int width, int height, int pitch, int2 coord) +{ +#if defined(PREFETCH_DEFAULT) + __builtin_IB_subgroup_block_read_prefetch_u16_m8k32v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +#endif // defined(PREFETCH_DEFAULT) +} +void intel_sub_group_block_prefetch_16b_16r32c(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ +#if defined(PREFETCH_DEFAULT) + __builtin_IB_subgroup_block_read_prefetch_u16_m16k32v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +#endif // defined(PREFETCH_DEFAULT) +} +void intel_sub_group_block_prefetch_16b_32r32c(const __global void *base_address, int width, int height, int pitch, int2 coord) +{ +#if defined(PREFETCH_DEFAULT) + __builtin_IB_subgroup_block_read_prefetch_u16_m32k32v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); +#endif // defined(PREFETCH_DEFAULT) +} void intel_sub_group_block_write_32b_1r16c(__global void* base_address, int width, int height, int pitch, int2 coord, uint data) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 823a03e1..e8b9ac87 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -533,7 +533,8 @@ void HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(global ushort* A, int tM //if (get_sub_group_local_id() == 0) { // printf("atile block prefetch: %d, %d, %2d: sg_x = %d, m = %3d, k = %3d, mm = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), sg_index_x, m, k, mm, kk, k + kk * tK, m + mm * tM); //} - intel_sub_group_block_prefetch_16b_8r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + //intel_sub_group_block_prefetch_16b_8r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); 
+ intel_sub_group_block_prefetch_16b_8r32c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); } else if (KK % 2 == 0 & MM % 4 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm+=4) { @@ -576,7 +577,8 @@ void HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(global ushort* B, int tN //if (get_sub_group_local_id() == 0) { // printf("btile block prefetch: %d, %d, %2d: sg_y = %d, n = %3d, k = %3d, nn = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), sg_index_y, n, k, nn, kk, n + nn * tN, k + kk * tK); //} - intel_sub_group_block_prefetch_16b_16r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + //intel_sub_group_block_prefetch_16b_16r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + intel_sub_group_block_prefetch_16b_16r32c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); } else if (KK % 2 == 0 & NN % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn += 2) { From d3d42f0677aa4521d41d2f1ec1e1d64b48db3ef3 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Tue, 30 Jul 2024 09:39:56 -0700 Subject: [PATCH 79/99] add a define for 32x1 prefetch variants --- samples/99_matrixexperiments/matrix_kernel_tiled.cl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index e8b9ac87..9f14b94f 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -533,8 +533,11 @@ void HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(global ushort* A, int tM //if (get_sub_group_local_id() == 0) { // printf("atile block prefetch: %d, %d, %2d: sg_x = %d, m = %3d, k = %3d, mm = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), 
get_sub_group_id(), sg_index_x, m, k, mm, kk, k + kk * tK, m + mm * tM); //} - //intel_sub_group_block_prefetch_16b_8r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); +#ifdef USE_32C intel_sub_group_block_prefetch_16b_8r32c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); +#else + intel_sub_group_block_prefetch_16b_8r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); +#endif } else if (KK % 2 == 0 & MM % 4 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm+=4) { @@ -577,8 +580,11 @@ void HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(global ushort* B, int tN //if (get_sub_group_local_id() == 0) { // printf("btile block prefetch: %d, %d, %2d: sg_y = %d, n = %3d, k = %3d, nn = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), sg_index_y, n, k, nn, kk, n + nn * tN, k + kk * tK); //} - //intel_sub_group_block_prefetch_16b_16r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); +#ifdef USE_32C intel_sub_group_block_prefetch_16b_16r32c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); +#else + intel_sub_group_block_prefetch_16b_16r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); +#endif } else if (KK % 2 == 0 & NN % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn += 2) { From eff9d19ed7abe7cfdd3e2e85869af6b175a0f46b Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Tue, 30 Jul 2024 10:46:18 -0700 Subject: [PATCH 80/99] enable support for the native tf32 dpas --- samples/99_matrixexperimentstf32/main.cpp | 7 ++----- samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl | 7 ------- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/samples/99_matrixexperimentstf32/main.cpp b/samples/99_matrixexperimentstf32/main.cpp index de21623b..d11dc866 100644 --- 
a/samples/99_matrixexperimentstf32/main.cpp +++ b/samples/99_matrixexperimentstf32/main.cpp @@ -517,17 +517,14 @@ int main(int argc, char** argv) auto minSubGroupSize = findMinSubGroupSize(device); bool emulate_tN16 = true; - if (!emulate && checkDeviceForExtension(device, "cl_intel_subgroup_matrix_multiply_accumulate")) { - printf("Found support for cl_intel_subgroup_matrix_multiply_accumulate, min sub-group size is: %zu\n", minSubGroupSize); + if (!emulate && checkDeviceForExtension(device, "cl_intel_subgroup_matrix_multiply_accumulate_tf32")) { + printf("Found support for cl_intel_subgroup_matrix_multiply_accumulate_tf32, min sub-group size is: %zu\n", minSubGroupSize); switch(minSubGroupSize) { case 16: emulate_tN16 = false; break; default: break; } } - printf("NOTE: dpas is unconditionally emulated, currently!\n"); - emulate_tN16 = true; - buildOptions += " -DEMULATE_tN16=" + std::to_string(emulate_tN16); printf("Config:\n"); diff --git a/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl b/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl index 43174515..1f62e78a 100644 --- a/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl +++ b/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl @@ -71,7 +71,6 @@ float emu_sub_group_tf32_tf32_matrix_mad_k8(float a, float8 b, float acc) { float res = acc; -#if 1 res = fma(sub_group_broadcast(a, 0), b.s0, res); res = fma(sub_group_broadcast(a, 1), b.s1, res); res = fma(sub_group_broadcast(a, 2), b.s2, res); @@ -80,12 +79,6 @@ float emu_sub_group_tf32_tf32_matrix_mad_k8(float a, float8 b, float acc) res = fma(sub_group_broadcast(a, 5), b.s5, res); res = fma(sub_group_broadcast(a, 6), b.s6, res); res = fma(sub_group_broadcast(a, 7), b.s7, res); -#else -float __attribute__((overloadable)) intel_sub_group_tf32_tf32_matrix_mad_k8_f32(short a, int8 b, float acc); - uint a_ui = as_uint(sub_group_shuffle(a, get_sub_group_local_id() / 2)); - short aData = get_sub_group_local_id() % 2 ? 
as_short2(a_ui).hi : as_short2(a_ui).lo; - res = intel_sub_group_tf32_tf32_matrix_mad_k8_f32(aData, as_int8(b), res); -#endif return res; } From af8b5e140bf423aef2e392ad7ec4ef61e83c2568 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Tue, 30 Jul 2024 14:18:19 -0700 Subject: [PATCH 81/99] update tf32 function names to be closer to the final versions --- .../99_matrixexperiments/matrix_kernels.cl | 8 +- .../matrix_helpers_tf32.cl | 96 +++++++++---------- .../matrix_kernel_tiled_tf32.cl | 16 ++-- .../matrix_kernels_tf32.cl | 48 +++++----- 4 files changed, 83 insertions(+), 85 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl index 7632633d..724e80d1 100644 --- a/samples/99_matrixexperiments/matrix_kernels.cl +++ b/samples/99_matrixexperiments/matrix_kernels.cl @@ -396,7 +396,7 @@ kernel void bfloat16_dpas_blockread_rowmajor_m1_n16(global float* C, global usho float sum = 0; for (int k = 0; k < K; k += tK) { short aData = as_short(intel_sub_group_block_read_16b_1r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); - int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); + int8 bData = as_int8(intel_sub_group_block_read_transform_16b_16r16c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); sum = mat_mul_sg16(aData, bData, sum); } @@ -418,7 +418,7 @@ kernel void bfloat16_dpas_blockread_rowmajor_m2_n16(global float* C, global usho float2 sum = 0; for (int k = 0; k < K; k += tK) { short2 aData = as_short2(intel_sub_group_block_read_16b_2r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); - int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); + int8 bData = as_int8(intel_sub_group_block_read_transform_16b_16r16c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); sum = mat_mul_sg16(aData, bData, sum); } @@ 
-440,7 +440,7 @@ kernel void bfloat16_dpas_blockread_rowmajor_m4_n16(global float* C, global usho float4 sum = 0; for (int k = 0; k < K; k += tK) { short4 aData = as_short4(intel_sub_group_block_read_16b_4r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); - int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); + int8 bData = as_int8(intel_sub_group_block_read_transform_16b_16r16c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); sum = mat_mul_sg16(aData, bData, sum); } @@ -462,7 +462,7 @@ kernel void bfloat16_dpas_blockread_rowmajor_m8_n16(global float* C, global usho float8 sum = 0; for (int k = 0; k < K; k += tK) { short8 aData = as_short8(intel_sub_group_block_read_16b_8r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); - int8 bData = as_int8(intel_subgroup_block_read_transform_u16_k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); + int8 bData = as_int8(intel_sub_group_block_read_transform_16b_16r16c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); sum = mat_mul_sg16(aData, bData, sum); } diff --git a/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl b/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl index 1f62e78a..53ca38fc 100644 --- a/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl +++ b/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl @@ -134,7 +134,7 @@ float8 emu_sub_group_tf32_tf32_matrix_mad_k8(float4 a, float8 b, float8 acc) } // M rows x K columns -float load_a_rowmajor_d32_m1_k8_sg16(global float* A, int rowStart, int colStart, int stride) +float load_a_rowmajor_32b_1r8c_sg16(global float* A, int rowStart, int colStart, int stride) { float ret; @@ -148,7 +148,7 @@ float load_a_rowmajor_d32_m1_k8_sg16(global float* A, int rowStart, int colStart } // M rows x K columns -float load_a_rowmajor_d32_m2_k8_sg16(global float* A, int rowStart, int colStart, int stride) +float 
load_a_rowmajor_32b_2r8c_sg16(global float* A, int rowStart, int colStart, int stride) { float ret; @@ -162,7 +162,7 @@ float load_a_rowmajor_d32_m2_k8_sg16(global float* A, int rowStart, int colStart } // M rows x K columns -float2 load_a_rowmajor_d32_m4_k8_sg16(global float* A, int rowStart, int colStart, int stride) +float2 load_a_rowmajor_32b_4r8c_sg16(global float* A, int rowStart, int colStart, int stride) { float2 ret; @@ -177,7 +177,7 @@ float2 load_a_rowmajor_d32_m4_k8_sg16(global float* A, int rowStart, int colStar } // M rows x K columns -float4 load_a_rowmajor_d32_m8_k8_sg16(global float* A, int rowStart, int colStart, int stride) +float4 load_a_rowmajor_32b_8r8c_sg16(global float* A, int rowStart, int colStart, int stride) { float4 ret; @@ -194,7 +194,7 @@ float4 load_a_rowmajor_d32_m8_k8_sg16(global float* A, int rowStart, int colStar } // M rows x K columns x V tiles (in the M and K dimensions) -void prefetch_a_rowmajor_d32_m8v2_k8v2_sg16(global float* A, int rowStart, int colStart, int stride) +void prefetch_a_rowmajor_32b_8x2r8x2c_sg16(global float* A, int rowStart, int colStart, int stride) { #if defined(PREFETCH_DEFAULT) uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; @@ -205,7 +205,7 @@ void prefetch_a_rowmajor_d32_m8v2_k8v2_sg16(global float* A, int rowStart, int c // K rows x N columns: // Each work-item loads K values. // Stride is in units of elements. 
-float8 load_b_rowmajor_d32_k8_nx(global float* B, int rowStart, int colStart, int stride) +float8 load_b_rowmajor_32b_8rNc(global float* B, int rowStart, int colStart, int stride) { float8 ret; @@ -224,7 +224,7 @@ float8 load_b_rowmajor_d32_k8_nx(global float* B, int rowStart, int colStart, in } // K rows x N columns x V tiles (in the K and N dimensions) -void prefetch_b_rowmajor_d32_k8v2_n8v2_sg16(global float* B, int rowStart, int colStart, int stride) +void prefetch_b_rowmajor_32b_8x2r8x2c_sg16(global float* B, int rowStart, int colStart, int stride) { #if defined(PREFETCH_DEFAULT) uint offset = colStart + (rowStart + get_sub_group_local_id()) * stride; @@ -232,7 +232,7 @@ void prefetch_b_rowmajor_d32_k8v2_n8v2_sg16(global float* B, int rowStart, int c #endif // defined(PREFETCH_DEFAULT) } -void store_c_rowmajor_fp32_m1_nx(global float* C, float v, int rowStart, int colStart, int stride) +void store_c_rowmajor_fp32_1rNc(global float* C, float v, int rowStart, int colStart, int stride) { global uint* C_ui = (global uint*)C; uint v_ui = as_uint(v); @@ -242,7 +242,7 @@ void store_c_rowmajor_fp32_m1_nx(global float* C, float v, int rowStart, int col intel_sub_group_block_write(C_ui + offset, v_ui); offset += stride; } -void store_c_rowmajor_fp32_m2_nx(global float* C, float2 v, int rowStart, int colStart, int stride) +void store_c_rowmajor_fp32_2rNc(global float* C, float2 v, int rowStart, int colStart, int stride) { global uint* C_ui = (global uint*)C; uint2 v_ui = as_uint2(v); @@ -253,7 +253,7 @@ void store_c_rowmajor_fp32_m2_nx(global float* C, float2 v, int rowStart, int co intel_sub_group_block_write(C_ui + offset, v_ui.s1); offset += stride; } -void store_c_rowmajor_fp32_m4_nx(global float* C, float4 v, int rowStart, int colStart, int stride) +void store_c_rowmajor_fp32_4rNc(global float* C, float4 v, int rowStart, int colStart, int stride) { global uint* C_ui = (global uint*)C; uint4 v_ui = as_uint4(v); @@ -266,7 +266,7 @@ void 
store_c_rowmajor_fp32_m4_nx(global float* C, float4 v, int rowStart, int co intel_sub_group_block_write(C_ui + offset, v_ui.s3); offset += stride; } -void store_c_rowmajor_fp32_m8_nx(global float* C, float8 v, int rowStart, int colStart, int stride) +void store_c_rowmajor_fp32_8rNc(global float* C, float8 v, int rowStart, int colStart, int stride) { global uint* C_ui = (global uint*)C; uint8 v_ui = as_uint8(v); @@ -295,24 +295,6 @@ void store_c_rowmajor_fp32_m8_nx(global float* C, float8 v, int rowStart, int co // - pitch is the number of bytes between rows of the entire matrix. Must be >= 64B. Must be a multiple of 8 bytes. // - coord is the number of elements (x coord) and row (y coord) to read from. X coord must be multiple 4 for for 1B data and 2 for 2B data. -// Built-in functions are: - -// #ifdef cl_intel_subgroup_extended_block_read -// ushort2 intel_subgroup_block_read_u8_m1k32v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// ushort4 intel_subgroup_block_read_u8_m2k32v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// ushort8 intel_subgroup_block_read_u8_m4k32v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// ushort16 intel_subgroup_block_read_u8_m8k32v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// ushort2 intel_subgroup_block_read_u16_m1k16v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// ushort4 intel_subgroup_block_read_u16_m2k16v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// ushort8 intel_subgroup_block_read_u16_m4k16v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// ushort16 intel_subgroup_block_read_u16_m8k16v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// uint8 intel_subgroup_block_read_transform_u8_k32(__global void *base_address, int width, int height, int pitch, int2 coord); -// uint8 
intel_subgroup_block_read_transform_u16_k16(__global void *base_address, int width, int height, int pitch, int2 coord); -// uint8 intel_subgroup_block_read_transpose_u32_k8(__global void *base_address, int width, int height, int pitch, int2 coord); -// ulong4 intel_subgroup_block_read_transpose_u64_k4(__global void *base_address, int width, int height, int pitch, int2 coord); -// #endif //defined(cl_intel_subgroup_extended_block_read) - - // For intrinsics, the pattern is: // - prefix: __builtin_IB_subgroup_block_read_flat or __builtin_IB_subgroup_block_write_flat // - operation (optional): _transpose or _transform @@ -332,7 +314,18 @@ void store_c_rowmajor_fp32_m8_nx(global float* C, float8 v, int rowStart, int co // - tile width: subgroup size (16) // - number of tiles: 1 -// Define additional "non-vector" block read and writes. These are supported by the hardware but are not in the headers: +enum LSC_LDCC { + LSC_LDCC_DEFAULT = 0, + LSC_LDCC_L1UC_L3UC = 1, // Override to L1 uncached and L3 uncached + LSC_LDCC_L1UC_L3C = 2, // Override to L1 uncached and L3 cached + LSC_LDCC_L1C_L3UC = 3, // Override to L1 cached and L3 uncached + LSC_LDCC_L1C_L3C = 4, // Override to L1 cached and L3 cached + LSC_LDCC_L1S_L3UC = 5, // Override to L1 streaming load and L3 uncached + LSC_LDCC_L1S_L3C = 6, // Override to L1 streaming load and L3 cached + LSC_LDCC_L1IAR_L3C = 7, // Override to L1 invalidate-after-read, and L3 cached +}; + +// Define block reads, prefetches, and writes. 
These are supported by the hardware but are not in the headers: uint __builtin_IB_subgroup_block_read_flat_u32_m1k8v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); uint2 __builtin_IB_subgroup_block_read_flat_u32_m2k8v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); @@ -344,65 +337,70 @@ uint2 __builtin_IB_subgroup_block_read_flat_u32_m2k16v1(long baseoffset, int wi uint4 __builtin_IB_subgroup_block_read_flat_u32_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -uint intel_subgroup_block_read_u32_m1k8(const __global void *base_address, int width, int height, int pitch, int2 coord) +uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k8v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); + +void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint data); +void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint2 data); +void __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint4 data); +void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint8 data); + +uint intel_sub_group_block_read_32b_1r8c(const __global void *base_address, int width, int height, int pitch, int2 coord) { return __builtin_IB_subgroup_block_read_flat_u32_m1k8v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } -uint intel_subgroup_block_read_u32_m2k8(const __global void 
*base_address, int width, int height, int pitch, int2 coord) +uint intel_sub_group_block_read_32b_2r8c(const __global void *base_address, int width, int height, int pitch, int2 coord) { return __builtin_IB_subgroup_block_read_flat_u32_m2k8v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord).lo; } -uint2 intel_subgroup_block_read_u32_m4k8(const __global void *base_address, int width, int height, int pitch, int2 coord) +uint2 intel_sub_group_block_read_32b_4r8c(const __global void *base_address, int width, int height, int pitch, int2 coord) { return __builtin_IB_subgroup_block_read_flat_u32_m4k8v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord).lo; } -uint4 intel_subgroup_block_read_u32_m8k8(const __global void *base_address, int width, int height, int pitch, int2 coord) +uint4 intel_sub_group_block_read_32b_8r8c(const __global void *base_address, int width, int height, int pitch, int2 coord) { return __builtin_IB_subgroup_block_read_flat_u32_m8k8v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord).lo; } -uint intel_subgroup_block_read_u32_m1k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +uint intel_sub_group_block_read_32b_1r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) { return __builtin_IB_subgroup_block_read_flat_u32_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } -uint2 intel_subgroup_block_read_u32_m2k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +uint2 intel_sub_group_block_read_32b_2r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) { return __builtin_IB_subgroup_block_read_flat_u32_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } -uint4 intel_subgroup_block_read_u32_m4k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +uint4 intel_sub_group_block_read_32b_4r16c(const __global void 
*base_address, int width, int height, int pitch, int2 coord) { return __builtin_IB_subgroup_block_read_flat_u32_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } -uint8 intel_subgroup_block_read_u32_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord) +uint8 intel_sub_group_block_read_32b_8r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) { return __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } -uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k8v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); - -uint8 intel_subgroup_block_read_u32_m8k8v2(const __global void* base_address, int width, int height, int pitch, int2 coord) +uint8 intel_sub_group_block_read_32b_8r8x2c(const __global void* base_address, int width, int height, int pitch, int2 coord) { return __builtin_IB_subgroup_block_read_flat_u32_m8k8v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord); } -void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint data); -void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint2 data); -void __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint4 data); -void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint8 data); -void intel_subgroup_block_write_u32_m1k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint data) +#if !defined(BLOCK_PREFETCH_CACHE_TYPE) +#define BLOCK_PREFETCH_CACHE_TYPE LSC_LDCC_L1C_L3C +#endif + +void 
intel_sub_group_block_write_32b_1r16c(__global void* base_address, int width, int height, int pitch, int2 coord, uint data) { __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); } -void intel_subgroup_block_write_u32_m2k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint2 data) +void intel_sub_group_block_write_32b_2r16c(__global void* base_address, int width, int height, int pitch, int2 coord, uint2 data) { __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); } -void intel_subgroup_block_write_u32_m4k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint4 data) +void intel_sub_group_block_write_32b_4r16c(__global void* base_address, int width, int height, int pitch, int2 coord, uint4 data) { __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); } -void intel_subgroup_block_write_u32_m8k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint8 data) +void intel_sub_group_block_write_32b_8r16c(__global void* base_address, int width, int height, int pitch, int2 coord, uint8 data) { __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); } diff --git a/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl b/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl index db0bb27c..118d382b 100644 --- a/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl +++ b/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl @@ -47,7 +47,7 @@ void HELPER_NAME(btile_load_rowmajor, MM, NN)(global float* B, int tN, int N, in { for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { - bData[nn][kk] = load_b_rowmajor_d32_k8_nx(B, k + kk * tK, n + nn * tN, N); + bData[nn][kk] = 
load_b_rowmajor_32b_8rNc(B, k + kk * tK, n + nn * tN, N); } } } @@ -56,7 +56,7 @@ void HELPER_NAME(atile_prefetch_rowmajor_sg16, MM, NN)(global float* A, int tM, { for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm+=2) { - prefetch_a_rowmajor_d32_m8v2_k8v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K); + prefetch_a_rowmajor_32b_8x2r8x2c_sg16(A, m + mm * tM, prefetch_k + kk * tK, K); } } } @@ -65,7 +65,7 @@ void HELPER_NAME(btile_prefetch_rowmajor_sg16, MM, NN)(global float* B, int tN, { for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_rowmajor_d32_k8v2_n8v2_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); + prefetch_b_rowmajor_32b_8x2r8x2c_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); } } } @@ -74,7 +74,7 @@ void HELPER_NAME(atile_load_rowmajor_sg16, MM, NN)(global float* A, int tM, int { for (int kk = 0; kk < KK; kk++) { for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = load_a_rowmajor_d32_m8_k8_sg16(A, m + mm * tM, k + kk * tK, K); + aData[kk][mm] = load_a_rowmajor_32b_8r8c_sg16(A, m + mm * tM, k + kk * tK, K); } } } @@ -136,7 +136,7 @@ kernel void MM_KERNEL_NAME(tf32_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float for (int nn = 0; nn < NN; nn++) { for (int mm = 0; mm < MM; mm++) { sum[nn][mm] = activation(sum[nn][mm]); - store_c_rowmajor_fp32_m8_nx(C, sum[nn][mm], m + mm * tM, n + nn * tN, N); + store_c_rowmajor_fp32_8rNc(C, sum[nn][mm], m + mm * tM, n + nn * tN, N); } } } @@ -147,7 +147,7 @@ void HELPER_NAME(atile_block_load_rowmajor, MM, NN)(global float* A, int tM, int { for (int kk = 0; kk < KK; kk++) { for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = as_float4(intel_subgroup_block_read_u32_m8k8(A, K * sizeof(float), M, K * sizeof(float), (int2)(k + kk * tK, m + mm * tM))); + aData[kk][mm] = as_float4(intel_sub_group_block_read_32b_8r8c(A, K * sizeof(float), M, K * sizeof(float), (int2)(k + kk * tK, m + mm * tM))); } } } @@ -156,7 +156,7 @@ void HELPER_NAME(btile_block_load_rowmajor, MM, NN)(global float* 
B, int tN, int { for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { - bData[nn][kk] = as_float8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(float), K, N * sizeof(float), (int2)(n + nn * tN, k + kk * tK))); + bData[nn][kk] = as_float8(intel_sub_group_block_read_32b_8r16c(B, N * sizeof(float), K, N * sizeof(float), (int2)(n + nn * tN, k + kk * tK))); } } } @@ -217,7 +217,7 @@ kernel void MM_KERNEL_NAME(tf32_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(gl for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { sum[nn][mm] = activation(sum[nn][mm]); - intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[nn][mm])); + intel_sub_group_block_write_32b_8r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[nn][mm])); } } } diff --git a/samples/99_matrixexperimentstf32/matrix_kernels_tf32.cl b/samples/99_matrixexperimentstf32/matrix_kernels_tf32.cl index a0f73eb3..aa7ce065 100644 --- a/samples/99_matrixexperimentstf32/matrix_kernels_tf32.cl +++ b/samples/99_matrixexperimentstf32/matrix_kernels_tf32.cl @@ -40,13 +40,13 @@ kernel void tf32_dpas_rowmajor_m1_n16(global float* C, global float* A, global f float sum = 0; for (int k = 0; k < K; k += tK) { - float aData = load_a_rowmajor_d32_m1_k8_sg16(A, m, k, K); - float8 bData = load_b_rowmajor_d32_k8_nx(B, k, n, N); + float aData = load_a_rowmajor_32b_1r8c_sg16(A, m, k, K); + float8 bData = load_b_rowmajor_32b_8rNc(B, k, n, N); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - store_c_rowmajor_fp32_m1_nx(C, sum, m, n, N); + store_c_rowmajor_fp32_1rNc(C, sum, m, n, N); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -61,13 +61,13 @@ kernel void tf32_dpas_rowmajor_m2_n16(global float* C, global float* A, global f float2 sum = 0; for (int k = 0; k < K; k += tK) { - float aData = load_a_rowmajor_d32_m2_k8_sg16(A, m, 
k, K); - float8 bData = load_b_rowmajor_d32_k8_nx(B, k, n, N); + float aData = load_a_rowmajor_32b_2r8c_sg16(A, m, k, K); + float8 bData = load_b_rowmajor_32b_8rNc(B, k, n, N); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - store_c_rowmajor_fp32_m2_nx(C, sum, m, n, N); + store_c_rowmajor_fp32_2rNc(C, sum, m, n, N); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -82,13 +82,13 @@ kernel void tf32_dpas_rowmajor_m4_n16(global float* C, global float* A, global f float4 sum = 0; for (int k = 0; k < K; k += tK) { - float2 aData = load_a_rowmajor_d32_m4_k8_sg16(A, m, k, K); - float8 bData = load_b_rowmajor_d32_k8_nx(B, k, n, N); + float2 aData = load_a_rowmajor_32b_4r8c_sg16(A, m, k, K); + float8 bData = load_b_rowmajor_32b_8rNc(B, k, n, N); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - store_c_rowmajor_fp32_m4_nx(C, sum, m, n, N); + store_c_rowmajor_fp32_4rNc(C, sum, m, n, N); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -103,13 +103,13 @@ kernel void tf32_dpas_rowmajor_m8_n16(global float* C, global float* A, global f float8 sum = 0; for (int k = 0; k < K; k += tK) { - float4 aData = load_a_rowmajor_d32_m8_k8_sg16(A, m, k, K); - float8 bData = load_b_rowmajor_d32_k8_nx(B, k, n, N); + float4 aData = load_a_rowmajor_32b_8r8c_sg16(A, m, k, K); + float8 bData = load_b_rowmajor_32b_8rNc(B, k, n, N); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - store_c_rowmajor_fp32_m8_nx(C, sum, m, n, N); + store_c_rowmajor_fp32_8rNc(C, sum, m, n, N); } #ifdef cl_intel_subgroup_extended_block_read @@ -127,13 +127,13 @@ kernel void tf32_dpas_blockread_rowmajor_m1_n16(global float* C, global float* A float sum = 0; for (int k = 0; k < K; k += tK) { - float aData = as_float(intel_subgroup_block_read_u32_m1k8(A, K * sizeof(float), M, K * sizeof(float), (int2)(k, m))); - float8 bData = 
as_float8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(float), K, N * sizeof(float), (int2)(n, k))); + float aData = as_float(intel_sub_group_block_read_32b_1r8c(A, K * sizeof(float), M, K * sizeof(float), (int2)(k, m))); + float8 bData = as_float8(intel_sub_group_block_read_32b_8r16c(B, N * sizeof(float), K, N * sizeof(float), (int2)(n, k))); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_subgroup_block_write_u32_m1k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum)); + intel_sub_group_block_write_32b_1r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum)); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -149,13 +149,13 @@ kernel void tf32_dpas_blockread_rowmajor_m2_n16(global float* C, global float* A float2 sum = 0; for (int k = 0; k < K; k += tK) { - float aData = as_float(intel_subgroup_block_read_u32_m2k8(A, K * sizeof(float), M, K * sizeof(float), (int2)(k, m))); - float8 bData = as_float8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(float), K, N * sizeof(float), (int2)(n, k))); + float aData = as_float(intel_sub_group_block_read_32b_2r8c(A, K * sizeof(float), M, K * sizeof(float), (int2)(k, m))); + float8 bData = as_float8(intel_sub_group_block_read_32b_8r16c(B, N * sizeof(float), K, N * sizeof(float), (int2)(n, k))); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_subgroup_block_write_u32_m2k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum)); + intel_sub_group_block_write_32b_2r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum)); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -171,13 +171,13 @@ kernel void tf32_dpas_blockread_rowmajor_m4_n16(global float* C, global float* A float4 sum = 0; for (int k = 0; k < K; k += tK) { - float2 aData = as_float2(intel_subgroup_block_read_u32_m4k8(A, K * 
sizeof(float), M, K * sizeof(float), (int2)(k, m))); - float8 bData = as_float8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(float), K, N * sizeof(float), (int2)(n, k))); + float2 aData = as_float2(intel_sub_group_block_read_32b_4r8c(A, K * sizeof(float), M, K * sizeof(float), (int2)(k, m))); + float8 bData = as_float8(intel_sub_group_block_read_32b_8r16c(B, N * sizeof(float), K, N * sizeof(float), (int2)(n, k))); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_subgroup_block_write_u32_m4k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum)); + intel_sub_group_block_write_32b_4r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum)); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -193,13 +193,13 @@ kernel void tf32_dpas_blockread_rowmajor_m8_n16(global float* C, global float* A float8 sum = 0; for (int k = 0; k < K; k += tK) { - float4 aData = as_float4(intel_subgroup_block_read_u32_m8k8(A, K * sizeof(float), M, K * sizeof(float), (int2)(k, m))); - float8 bData = as_float8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(float), K, N * sizeof(float), (int2)(n, k))); + float4 aData = as_float4(intel_sub_group_block_read_32b_8r8c(A, K * sizeof(float), M, K * sizeof(float), (int2)(k, m))); + float8 bData = as_float8(intel_sub_group_block_read_32b_8r16c(B, N * sizeof(float), K, N * sizeof(float), (int2)(n, k))); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); + intel_sub_group_block_write_32b_8r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); } #endif // cl_intel_subgroup_extended_block_read From 83a06909c773fdc15f689aa6616bade31b38aaa8 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Wed, 26 Feb 2025 11:32:28 -0800 Subject: [PATCH 82/99] revert change to tf32 kernel --- 
.../matrix_kernel_tiled_tf32.cl | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl b/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl index a4822f60..118d382b 100644 --- a/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl +++ b/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl @@ -145,22 +145,9 @@ kernel void MM_KERNEL_NAME(tf32_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float void HELPER_NAME(atile_block_load_rowmajor, MM, NN)(global float* A, int tM, int M, int K, int m, int k, float4 aData[KK][MM]) { - if (KK % 2 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm++) { - //if (get_sub_group_local_id() == 0) { - // printf("atile block load : %d, %d, %2d: m = %3d, k = %3d, mm = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), m, k, mm, kk, k + kk * tK, m + mm * tM); - //} - float8 aTemp = as_float8(intel_subgroup_block_read_u32_m8k8v2(A, K * sizeof(float), M, K * sizeof(float), (int2)(k + kk * tK, m + mm * tM))); - aData[kk + 0][mm] = aTemp.lo; - aData[kk + 1][mm] = aTemp.hi; - } - } - } else { - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = as_float4(intel_subgroup_block_read_u32_m8k8(A, K * sizeof(float), M, K * sizeof(float), (int2)(k + kk * tK, m + mm * tM))); - } + for (int kk = 0; kk < KK; kk++) { + for (int mm = 0; mm < MM; mm++) { + aData[kk][mm] = as_float4(intel_sub_group_block_read_32b_8r8c(A, K * sizeof(float), M, K * sizeof(float), (int2)(k + kk * tK, m + mm * tM))); } } } From 7e9583146a9ae907c48171cc2b5e25325ee6cbeb Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Wed, 26 Feb 2025 11:36:34 -0800 Subject: [PATCH 83/99] fix typo --- samples/99_matrixexperimentsi8/matrix_kernel_tiled_i8.cl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/samples/99_matrixexperimentsi8/matrix_kernel_tiled_i8.cl b/samples/99_matrixexperimentsi8/matrix_kernel_tiled_i8.cl index 75940825..9d9225eb 100644 --- a/samples/99_matrixexperimentsi8/matrix_kernel_tiled_i8.cl +++ b/samples/99_matrixexperimentsi8/matrix_kernel_tiled_i8.cl @@ -1,5 +1,5 @@ #if !defined(tK) -#error "tK is undefined! This should be defined as the K dimension of the matrix tiles, which is dependent on the elemement type, likely 16 or 32." +#error "tK is undefined! This should be defined as the K dimension of the matrix tiles, which is dependent on the element type, likely 16 or 32." #endif #if !defined(MM) From 41159a8f999e2dcdcefb5dfbde49e87ef9dca36a Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Wed, 26 Feb 2025 17:01:31 -0800 Subject: [PATCH 84/99] switch block read functions to the production names Tiled kernels still need to be enabled and ported. --- samples/99_matrixexperimentsi8/main.cpp | 8 +++ .../matrix_kernels_i8.cl | 68 ++++++++++++------- 2 files changed, 50 insertions(+), 26 deletions(-) diff --git a/samples/99_matrixexperimentsi8/main.cpp b/samples/99_matrixexperimentsi8/main.cpp index ff99d3c4..dfdcd243 100644 --- a/samples/99_matrixexperimentsi8/main.cpp +++ b/samples/99_matrixexperimentsi8/main.cpp @@ -443,6 +443,8 @@ static void i8_dpas_blockread_rowmajor( cl::Kernel kernel{program, kernelName.c_str()}; if (kernel() == nullptr) { printf("unsupported.\n"); + } else if (K < 64 || N < 64) { + printf("matrix pitch for block reads must be >= 64 bytes.\n"); } else { kernel.setArg(0, C); kernel.setArg(1, A); @@ -502,6 +504,8 @@ static void i8_dpas_blockread_rowmajor_tiled( printf("M is too small.\n"); } else if (tN * NN > N) { printf("N is too small.\n"); + } else if (K < 64 || N < 64) { + printf("matrix pitch for block reads must be >= 64 bytes.\n"); } else { kernel.setArg(0, C); kernel.setArg(1, A); @@ -555,6 +559,8 @@ static void i8_dpas_blockread_vnni( cl::Kernel kernel{program, kernelName.c_str()}; if (kernel() == nullptr) { 
printf("unsupported.\n"); + } else if (K < 64 || N < 64/4) { + printf("matrix pitch for block reads must be >= 64 bytes.\n"); } else { kernel.setArg(0, C); kernel.setArg(1, A); @@ -614,6 +620,8 @@ static void i8_dpas_blockread_vnni_tiled( printf("M is too small.\n"); } else if (tN * NN > N) { printf("N is too small.\n"); + } else if (K < 64 || N < 64/4) { + printf("matrix pitch for block reads must be >= 64 bytes.\n"); } else { kernel.setArg(0, C); kernel.setArg(1, A); diff --git a/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl b/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl index 5e6f5735..f75425e1 100644 --- a/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl +++ b/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl @@ -380,7 +380,7 @@ kernel void i8_dpas_vnni_m8_n16(global int* C, global char* A, global char* B, i store_c_rowmajor_int32_m8_nx(C, sum, m, n, N); } -#ifdef cl_intel_subgroup_extended_block_read +#ifdef cl_intel_subgroup_2d_block_io __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void i8_dpas_blockread_rowmajor_m1_n16(global int* C, global char* A, global char* B, int K) @@ -395,13 +395,15 @@ kernel void i8_dpas_blockread_rowmajor_m1_n16(global int* C, global char* A, glo int sum = 0; for (int k = 0; k < K; k += tK) { - short aData = as_short(intel_subgroup_block_read_u8_m1k32(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m))); - int8 bData = as_int8(intel_subgroup_block_read_transform_u8_k32(B, N * sizeof(char), K, N * sizeof(char), (int2)(n, k))); + short aData; + intel_sub_group_2d_block_read_8b_1r32x1c(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m), (ushort*)&aData); + int8 bData; + intel_sub_group_2d_block_read_transform_8b_32r16x1c(B, N * sizeof(char), K, N * sizeof(char), (int2)(n, k), (uint*)&bData); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_subgroup_block_write_u32_m1k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), 
as_uint(sum)); + intel_sub_group_2d_block_write_32b_1r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -417,13 +419,15 @@ kernel void i8_dpas_blockread_rowmajor_m2_n16(global int* C, global char* A, glo int2 sum = 0; for (int k = 0; k < K; k += tK) { - short2 aData = as_short2(intel_subgroup_block_read_u8_m2k32(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m))); - int8 bData = as_int8(intel_subgroup_block_read_transform_u8_k32(B, N * sizeof(char), K, N * sizeof(char), (int2)(n, k))); + short2 aData; + intel_sub_group_2d_block_read_8b_2r32x1c(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m), (ushort*)&aData); + int8 bData; + intel_sub_group_2d_block_read_transform_8b_32r16x1c(B, N * sizeof(char), K, N * sizeof(char), (int2)(n, k), (uint*)&bData); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_subgroup_block_write_u32_m2k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum)); + intel_sub_group_2d_block_write_32b_2r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -439,13 +443,15 @@ kernel void i8_dpas_blockread_rowmajor_m4_n16(global int* C, global char* A, glo int4 sum = 0; for (int k = 0; k < K; k += tK) { - short4 aData = as_short4(intel_subgroup_block_read_u8_m4k32(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m))); - int8 bData = as_int8(intel_subgroup_block_read_transform_u8_k32(B, N * sizeof(char), K, N * sizeof(char), (int2)(n, k))); + short4 aData; + intel_sub_group_2d_block_read_8b_4r32x1c(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m), (ushort*)&aData); + int8 bData; + intel_sub_group_2d_block_read_transform_8b_32r16x1c(B, N * sizeof(char), K, N * sizeof(char), (int2)(n, k), (uint*)&bData); sum = mat_mul_sg16(aData, bData, sum); } 
sum = activation(sum); - intel_subgroup_block_write_u32_m4k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum)); + intel_sub_group_2d_block_write_32b_4r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -461,13 +467,15 @@ kernel void i8_dpas_blockread_rowmajor_m8_n16(global int* C, global char* A, glo int8 sum = 0; for (int k = 0; k < K; k += tK) { - short8 aData = as_short8(intel_subgroup_block_read_u8_m8k32(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m))); - int8 bData = as_int8(intel_subgroup_block_read_transform_u8_k32(B, N * sizeof(char), K, N * sizeof(char), (int2)(n, k))); + short8 aData; + intel_sub_group_2d_block_read_8b_8r32x1c(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m), (ushort*)&aData); + int8 bData; + intel_sub_group_2d_block_read_transform_8b_32r16x1c(B, N * sizeof(char), K, N * sizeof(char), (int2)(n, k), (uint*)&bData); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); + intel_sub_group_2d_block_write_32b_8r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -483,13 +491,15 @@ kernel void i8_dpas_blockread_vnni_m1_n16(global int* C, global char* A, global int sum = 0; for (int k = 0; k < K; k += tK) { - short aData = as_short(intel_subgroup_block_read_u8_m1k32(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m))); - int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 4))); + short aData; + intel_sub_group_2d_block_read_8b_1r32x1c(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m), (ushort*)&aData); + int8 bData; + intel_sub_group_2d_block_read_32b_8r16x1c(B, N * 
sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 4), (uint*)&bData); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_subgroup_block_write_u32_m1k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum)); + intel_sub_group_2d_block_write_32b_1r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -505,13 +515,15 @@ kernel void i8_dpas_blockread_vnni_m2_n16(global int* C, global char* A, global int2 sum = 0; for (int k = 0; k < K; k += tK) { - short2 aData = as_short2(intel_subgroup_block_read_u8_m2k32(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m))); - int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 4))); + short2 aData; + intel_sub_group_2d_block_read_8b_2r32x1c(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m), (ushort*)&aData); + int8 bData; + intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 4), (uint*)&bData); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_subgroup_block_write_u32_m2k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum)); + intel_sub_group_2d_block_write_32b_2r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -527,13 +539,15 @@ kernel void i8_dpas_blockread_vnni_m4_n16(global int* C, global char* A, global int4 sum = 0; for (int k = 0; k < K; k += tK) { - short4 aData = as_short4(intel_subgroup_block_read_u8_m4k32(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m))); - int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 4))); + short4 aData; + intel_sub_group_2d_block_read_8b_4r32x1c(A, K * sizeof(char), M, K * sizeof(char), 
(int2)(k, m), (ushort*)&aData); + int8 bData; + intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 4), (uint*)&bData); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_subgroup_block_write_u32_m4k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum)); + intel_sub_group_2d_block_write_32b_4r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -549,16 +563,18 @@ kernel void i8_dpas_blockread_vnni_m8_n16(global int* C, global char* A, global int8 sum = 0; for (int k = 0; k < K; k += tK) { - short8 aData = as_short8(intel_subgroup_block_read_u8_m8k32(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m))); - int8 bData = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 4))); + short8 aData; + intel_sub_group_2d_block_read_8b_8r32x1c(A, K * sizeof(char), M, K * sizeof(char), (int2)(k, m), (ushort*)&aData); + int8 bData; + intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 4), (uint*)&bData); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); + intel_sub_group_2d_block_write_32b_8r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); } -#endif // cl_intel_subgroup_extended_block_read +#endif // cl_intel_subgroup_2d_block_io #if 0 // disable the tiled cases for now From ddc93ff81feb24c06c549d4e230792a7fc34208c Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Thu, 27 Feb 2025 15:13:21 -0800 Subject: [PATCH 85/99] add transpose block read variant --- samples/99_matrixexperimentsi8/main.cpp | 126 ++++++++++++++++++ .../matrix_helpers_i8.cl | 16 +++ .../matrix_kernels_i8.cl | 62 +++++++++ 3 files changed, 204 
insertions(+) diff --git a/samples/99_matrixexperimentsi8/main.cpp b/samples/99_matrixexperimentsi8/main.cpp index dfdcd243..04520bfc 100644 --- a/samples/99_matrixexperimentsi8/main.cpp +++ b/samples/99_matrixexperimentsi8/main.cpp @@ -142,6 +142,23 @@ static void compute_reference( } } +template +static void compute_reference_TN( + std::vector& C, + const std::vector& A, const std::vector& B, + size_t M, size_t N, size_t K) +{ + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + DstT sum = 0; + for (size_t k = 0; k < K; k++) { + sum = A[k * K + m] * B[k * N + n] + sum; + } + C[m * N + n] = sum; + } + } +} + template void check_results( size_t M, @@ -660,6 +677,107 @@ static void i8_dpas_blockread_vnni_tiled( } } +static void i8_naive_TN( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, M, N, K).c_str()); fflush(stdout); + + cl::Kernel kernel{program, "i8_naive_TN"}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M}, cl::NullRange, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... 
"); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + +template +static void i8_dpas_blockread_rowmajor_TN( + cl::Context& context, cl::Program& program, cl::CommandQueue& queue, + cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, + size_t M, size_t N, size_t K, + const std::vector& C_ref) +{ + printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); + + std::string kernelName = "i8_dpas_blockread_rowmajor_TN"; + kernelName += "_m" + std::to_string(tM); + kernelName += "_n" + std::to_string(tN); + cl::Kernel kernel{program, kernelName.c_str()}; + if (kernel() == nullptr) { + printf("unsupported.\n"); + } else if (K < 64 || N < 64/4) { + printf("matrix pitch for block reads must be >= 64 bytes.\n"); + } else { + kernel.setArg(0, C); + kernel.setArg(1, A); + kernel.setArg(2, B); + kernel.setArg(3, static_cast(K)); + if (roundRobin) { + setRoundRobin(kernel); + } + + if (!skipinit) { + queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); + } + + float best = 999.0f; + for (int test = 0; test < testIterations; test++) { + cl::Event event; + auto start = test_clock::now(); + queue.enqueueNDRangeKernel(kernel, cl::NullRange, + cl::NDRange{N, M/tM}, cl::NullRange, nullptr, &event); + queue.finish(); + auto end = test_clock::now(); + std::chrono::duration sw_time = end - start; + auto elapsed = wallclock ? sw_time.count() : hw_time(event); + best = std::min(best, elapsed); + } + auto gops = 2.0 * M * N * K / best / 1e9; + printf("Best in %f seconds (%f gops)\n", best, gops); + + if (validate) { + printf("Checking results... 
"); fflush(stdout); + std::vector C_check(C_ref.size()); + queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); + check_results(M, N, C_check, C_ref); + printf(" done!\n"); + } + } +} + int main(int argc, char** argv) { int platformIndex = 0; @@ -784,6 +902,7 @@ int main(int argc, char** argv) std::vector Bvnni_vec(K * N); std::vector C_ref(M * N); + std::vector C_TN_ref(M * N); printf("Initializing source matrices...\n"); fill_matrix(A_vec, M, K); @@ -794,6 +913,8 @@ int main(int argc, char** argv) if (validate) { printf("Computing reference...\n"); compute_reference(C_ref, A_vec, B_vec, M, N, K); + printf("Computing transposed reference...\n"); + compute_reference_TN(C_TN_ref, A_vec, B_vec, M, N, K); } printf("Creating source buffers...\n"); @@ -910,6 +1031,11 @@ int main(int argc, char** argv) i8_dpas_blockread_vnni_tiled<8, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); } + if (mask & 0x2000) { + //i8_naive_TN(context, program, queue, C, A, B, M, N, K, C_TN_ref); + i8_dpas_blockread_rowmajor_TN<4, 16>(context, program, queue, C, A, B, M, N, K, C_TN_ref); + } + printf("Done.\n"); return 0; diff --git a/samples/99_matrixexperimentsi8/matrix_helpers_i8.cl b/samples/99_matrixexperimentsi8/matrix_helpers_i8.cl index 26c916c4..d7231d95 100644 --- a/samples/99_matrixexperimentsi8/matrix_helpers_i8.cl +++ b/samples/99_matrixexperimentsi8/matrix_helpers_i8.cl @@ -742,4 +742,20 @@ void intel_subgroup_block_write_u32_m8k16(__global void* base_address, int width __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); } +uint __builtin_IB_subgroup_block_read_flat_transpose_u32_k1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); +uint2 __builtin_IB_subgroup_block_read_flat_transpose_u32_m32k1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); + +void 
intel_sub_group_2d_block_read_transpose_32b_16r1x1c(global void* base_address, int width, int height, int pitch, int2 coord, private uint* destination) +{ + uint temp = __builtin_IB_subgroup_block_read_flat_transpose_u32_k1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); + destination[0] = temp; +} + +void intel_sub_group_2d_block_read_transpose_32b_32r1x1c(global void* base_address, int width, int height, int pitch, int2 coord, private uint* destination) +{ + uint2 temp = __builtin_IB_subgroup_block_read_flat_transpose_u32_m32k1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); + destination[0] = temp.s0; + destination[1] = temp.s1; +} + #endif // cl_intel_subgroup_extended_block_read diff --git a/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl b/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl index f75425e1..ca841f1c 100644 --- a/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl +++ b/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl @@ -27,6 +27,24 @@ kernel void i8_naive(global int* C, global char* A, global char* B, int K) C[m * N + n] = sum; } +kernel void i8_naive_TN(global int* C, global char* A, global char* B, int K) +{ + const int N = get_global_size(0); + const int m = get_global_id(1); + const int n = get_global_id(0); + + int sum = 0; + for (int k = 0; k < K; k++) { + sum = A[k * K + m] * B[k * N + n] + sum; + if (get_global_id(0) == 1 && get_global_id(1) == 0) { + printf("after iteration %d: sum is %d\n", k, sum); + } + } + + sum = activation(sum); + C[m * N + n] = sum; +} + // For all i8 kernels tK == 32: #define tK 32 @@ -574,6 +592,50 @@ kernel void i8_dpas_blockread_vnni_m8_n16(global int* C, global char* A, global intel_sub_group_2d_block_write_32b_8r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); } +__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) +kernel void i8_dpas_blockread_rowmajor_TN_m4_n16(global int* C, global char* A, 
global char* B, int K) +{ + __builtin_assume(K > 0); // Always at least one K iteration. + const int tM = 4; + const int tN = 16; + const int M = get_global_size(1) * tM; + const int N = get_global_size(0); + const int m = get_group_id(1) * tM; + const int n = get_group_id(0) * tN; + const int sglid = get_sub_group_local_id(); + + int4 sum = 0; + for (int k = 0; k < K; k += tK) { + int2 readData; + intel_sub_group_2d_block_read_transpose_32b_32r1x1c(A, M * sizeof(char), K, M * sizeof(char), (int2)(m / 4, k), (uint*)&readData); + + // Note: after the transpose block read: + // readData.s0 contains row 0-15 + // readData.s1 contains row 16-31 + // So, WI0 has rows 0 and 16, WI1 has rows 1 and 17, etc. + // We want WI0 to have rows 0 and 1, WI1 to have rows 2 and 3, etc. + int shuffledData0 = (sglid < 8) ? + sub_group_shuffle(readData.s0, (sglid * 2)) : + sub_group_shuffle(readData.s1, (sglid * 2) % 16); + int shuffledData1 = (sglid < 8) ? + sub_group_shuffle(readData.s0, (sglid * 2) + 1) : + sub_group_shuffle(readData.s1, (sglid * 2) % 16 + 1); + + short4 aData; + aData.s0 = as_short((char2)(as_char4(shuffledData0).s0, as_char4(shuffledData1).s0)); + aData.s1 = as_short((char2)(as_char4(shuffledData0).s1, as_char4(shuffledData1).s1)); + aData.s2 = as_short((char2)(as_char4(shuffledData0).s2, as_char4(shuffledData1).s2)); + aData.s3 = as_short((char2)(as_char4(shuffledData0).s3, as_char4(shuffledData1).s3)); + + int8 bData; + intel_sub_group_2d_block_read_transform_8b_32r16x1c(B, N * sizeof(char), K, N * sizeof(char), (int2)(n, k), (uint*)&bData); + sum = mat_mul_sg16(aData, bData, sum); + } + + sum = activation(sum); + intel_sub_group_2d_block_write_32b_4r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); +} + #endif // cl_intel_subgroup_2d_block_io #if 0 // disable the tiled cases for now From 07340978ed8ae50f72035b007934e4cd52c90e17 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Thu, 27 Feb 2025 16:23:33 -0800 Subject: [PATCH 86/99] 
switch to a more efficient sequence with conditional movs --- .../99_matrixexperimentsi8/matrix_kernels_i8.cl | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl b/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl index ca841f1c..555ac94c 100644 --- a/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl +++ b/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl @@ -610,16 +610,17 @@ kernel void i8_dpas_blockread_rowmajor_TN_m4_n16(global int* C, global char* A, intel_sub_group_2d_block_read_transpose_32b_32r1x1c(A, M * sizeof(char), K, M * sizeof(char), (int2)(m / 4, k), (uint*)&readData); // Note: after the transpose block read: - // readData.s0 contains row 0-15 - // readData.s1 contains row 16-31 + // readData.s0 contains rows 0-15 + // readData.s1 contains rows 16-31 // So, WI0 has rows 0 and 16, WI1 has rows 1 and 17, etc. // We want WI0 to have rows 0 and 1, WI1 to have rows 2 and 3, etc. - int shuffledData0 = (sglid < 8) ? - sub_group_shuffle(readData.s0, (sglid * 2)) : - sub_group_shuffle(readData.s1, (sglid * 2) % 16); - int shuffledData1 = (sglid < 8) ? - sub_group_shuffle(readData.s0, (sglid * 2) + 1) : - sub_group_shuffle(readData.s1, (sglid * 2) % 16 + 1); + int shuffleIndex = sglid * 2 % 16; + int loData0 = sub_group_shuffle(readData.s0, shuffleIndex); + int hiData0 = sub_group_shuffle(readData.s1, shuffleIndex); + int shuffledData0 = (sglid < 8) ? loData0 : hiData0; + int loData1 = sub_group_shuffle(readData.s0, shuffleIndex + 1); + int hiData1 = sub_group_shuffle(readData.s1, shuffleIndex + 1); + int shuffledData1 = (sglid < 8) ? 
loData1 : hiData1; short4 aData; aData.s0 = as_short((char2)(as_char4(shuffledData0).s0, as_char4(shuffledData1).s0)); From 3324a53e3d35a2ada834cc2c2c949b270eaf5af4 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Thu, 27 Feb 2025 16:24:44 -0800 Subject: [PATCH 87/99] cleanup --- samples/99_matrixexperimentsi8/matrix_kernels_i8.cl | 3 --- 1 file changed, 3 deletions(-) diff --git a/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl b/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl index 555ac94c..0a707bc3 100644 --- a/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl +++ b/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl @@ -36,9 +36,6 @@ kernel void i8_naive_TN(global int* C, global char* A, global char* B, int K) int sum = 0; for (int k = 0; k < K; k++) { sum = A[k * K + m] * B[k * N + n] + sum; - if (get_global_id(0) == 1 && get_global_id(1) == 0) { - printf("after iteration %d: sum is %d\n", k, sum); - } } sum = activation(sum); From 62a1fd896b196a5d66c831cd6ac0adf164d90e0b Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Sat, 3 May 2025 21:33:47 -0700 Subject: [PATCH 88/99] switch to production 2d block io functions --- .../99_matrixexperiments/matrix_helpers.cl | 307 ------------------ .../matrix_kernel_tiled.cl | 108 +++--- .../99_matrixexperiments/matrix_kernels.cl | 68 ++-- .../matrix_helpers_tf32.cl | 4 +- .../matrix_kernel_tiled_tf32.cl | 4 +- .../matrix_kernels_tf32.cl | 4 +- 6 files changed, 103 insertions(+), 392 deletions(-) diff --git a/samples/99_matrixexperiments/matrix_helpers.cl b/samples/99_matrixexperiments/matrix_helpers.cl index ff3010be..d8580aa7 100644 --- a/samples/99_matrixexperiments/matrix_helpers.cl +++ b/samples/99_matrixexperiments/matrix_helpers.cl @@ -553,310 +553,3 @@ void store_c_rowmajor_fp32_8rNc(global float* C, float8 v, int rowStart, int col } #endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) - -#ifdef cl_intel_subgroup_extended_block_read - -// Note for 2D block reads: -// - the tile 
width and height is encoded into the function name. -// - base_address is the byte address. Must be 64B aligned. -// - width is the width of the entire matrix, in bytes. Must be >= 64B. Must be 4B aligned. -// - height is the height of the entire matrix, or equivalently the number of rows. -// - pitch is the number of bytes between rows of the entire matrix. Must be >= 64B. Must be a multiple of 8 bytes. -// - coord is the number of elements (x coord) and row (y coord) to read from. X coord must be multiple 4 for for 1B data and 2 for 2B data. - -// For intrinsics, the pattern is: -// - prefix: __builtin_IB_subgroup_block_read_flat or __builtin_IB_subgroup_block_write_flat -// - operation (optional): _transpose or _transform -// - for no transpose or transform: -// - type / elements size: _u8 or _u16 or _u32 or _u64 -// - number of tile rows: _m32 or _m16 or _m8 or _m4 or _m2 or _m1 -// - tile width: _k64 or _k32 or _k16 or _k8 -// - number of tiles: _v2 or _v1 -// - for transpose: -// - type / element size: _u64 or _u32 -// - number of tile rows: subgroup size (16) -// - tile width: _k4 (for _u64) or _k8 (for _u32) -// - number of tiles: 1 -// - for transform: -// - type / element size: _u16 or _u8 -// - number of tile rows: _k32 (for _u8) or _k16 (for _u16) -// - tile width: subgroup size (16) -// - number of tiles: 1 - -enum LSC_LDCC { - LSC_LDCC_DEFAULT = 0, - LSC_LDCC_L1UC_L3UC = 1, // Override to L1 uncached and L3 uncached - LSC_LDCC_L1UC_L3C = 2, // Override to L1 uncached and L3 cached - LSC_LDCC_L1C_L3UC = 3, // Override to L1 cached and L3 uncached - LSC_LDCC_L1C_L3C = 4, // Override to L1 cached and L3 cached - LSC_LDCC_L1S_L3UC = 5, // Override to L1 streaming load and L3 uncached - LSC_LDCC_L1S_L3C = 6, // Override to L1 streaming load and L3 cached - LSC_LDCC_L1IAR_L3C = 7, // Override to L1 invalidate-after-read, and L3 cached -}; - -typedef ushort __attribute__((ext_vector_type(32))) ushort32; -typedef ushort __attribute__((ext_vector_type(64))) 
ushort64; - -typedef uint __attribute__((ext_vector_type(32))) uint32; - -// Define block reads, prefetches, and writes. These are supported by the hardware but are not in the headers: - -ushort __builtin_IB_subgroup_block_read_flat_u16_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -ushort2 __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -ushort4 __builtin_IB_subgroup_block_read_flat_u16_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -ushort8 __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -ushort16 __builtin_IB_subgroup_block_read_flat_u16_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -ushort32 __builtin_IB_subgroup_block_read_flat_u16_m32k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); - -ushort2 __builtin_IB_subgroup_block_read_flat_u16_m1k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -ushort4 __builtin_IB_subgroup_block_read_flat_u16_m2k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -ushort8 __builtin_IB_subgroup_block_read_flat_u16_m4k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -ushort16 __builtin_IB_subgroup_block_read_flat_u16_m8k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -ushort32 __builtin_IB_subgroup_block_read_flat_u16_m16k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -ushort64 __builtin_IB_subgroup_block_read_flat_u16_m32k16v2(long baseoffset, int width_minus_one, int height_minus_one, 
int pitch_minus_one, int2 coord); - -uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -uint16 __builtin_IB_subgroup_block_read_flat_u32_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); - -uint8 __builtin_IB_subgroup_block_read_flat_transform_u16_k16(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -uint16 __builtin_IB_subgroup_block_read_flat_transform_u16_k32(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); - -uint16 __builtin_IB_subgroup_block_read_flat_transform_u16_k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -uint32 __builtin_IB_subgroup_block_read_flat_transform_u16_k32v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); - - -void __builtin_IB_subgroup_block_read_prefetch_u16_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); -void __builtin_IB_subgroup_block_read_prefetch_u16_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); -void __builtin_IB_subgroup_block_read_prefetch_u16_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); -void __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); -void __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); -void __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1(long baseoffset, int width_minus_one, 
int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); -void __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); - -void __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); -void __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); - -void __builtin_IB_subgroup_block_read_prefetch_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); -void __builtin_IB_subgroup_block_read_prefetch_u32_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); - -void __builtin_IB_subgroup_block_read_prefetch_u16_m8k32v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); -void __builtin_IB_subgroup_block_read_prefetch_u16_m16k32v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); -void __builtin_IB_subgroup_block_read_prefetch_u16_m32k32v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, enum LSC_LDCC cache_control); - -void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint data); -void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint2 data); -void __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(long baseoffset, int width_minus_one, int 
height_minus_one, int pitch_minus_one, int2 coord, uint4 data); -void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint8 data); -void __builtin_IB_subgroup_block_write_flat_u32_m16k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint16 data); - -ushort intel_sub_group_block_read_16b_1r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ - return __builtin_IB_subgroup_block_read_flat_u16_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); -} -ushort2 intel_sub_group_block_read_16b_2r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ - return __builtin_IB_subgroup_block_read_flat_u16_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); -} -ushort4 intel_sub_group_block_read_16b_4r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ - return __builtin_IB_subgroup_block_read_flat_u16_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); -} -ushort8 intel_sub_group_block_read_16b_8r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ - return __builtin_IB_subgroup_block_read_flat_u16_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); -} -ushort16 intel_sub_group_block_read_16b_16r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ - return __builtin_IB_subgroup_block_read_flat_u16_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); -} -void intel_sub_group_block_read_16b_32r16c(const __global void *base_address, int width, int height, int pitch, int2 coord, ushort8 dst[4]) -{ - ushort32 tmp = __builtin_IB_subgroup_block_read_flat_u16_m32k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); - dst[0] = tmp.lo.lo; - dst[1] = 
tmp.lo.hi; - dst[2] = tmp.hi.lo; - dst[3] = tmp.hi.hi; -} - -ushort2 intel_sub_group_block_read_16b_1r16x2c(__global void *base_address, int width, int height, int pitch, int2 coord) -{ - return __builtin_IB_subgroup_block_read_flat_u16_m1k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord); -} -ushort4 intel_sub_group_block_read_16b_2r16x2c(__global void *base_address, int width, int height, int pitch, int2 coord) -{ - return __builtin_IB_subgroup_block_read_flat_u16_m2k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord); -} -ushort8 intel_sub_group_block_read_16b_4r16x2c(__global void *base_address, int width, int height, int pitch, int2 coord) -{ - return __builtin_IB_subgroup_block_read_flat_u16_m4k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord); -} -ushort16 intel_sub_group_block_read_16b_8r16x2c(__global void *base_address, int width, int height, int pitch, int2 coord) -{ - return __builtin_IB_subgroup_block_read_flat_u16_m8k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord); -} - -void intel_sub_group_block_read_16b_16r16x2c(const __global void *base_address, int width, int height, int pitch, int2 coord, ushort8 dst[2][2]) -{ - ushort32 tmp = __builtin_IB_subgroup_block_read_flat_u16_m16k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord); - dst[0][0] = tmp.lo.lo; - dst[0][1] = tmp.lo.hi; - dst[1][0] = tmp.hi.lo; - dst[1][1] = tmp.hi.hi; -} -void intel_sub_group_block_read_16b_32r16x2c(const __global void *base_address, int width, int height, int pitch, int2 coord, ushort8 dst[2][4]) -{ - ushort64 tmp = __builtin_IB_subgroup_block_read_flat_u16_m32k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord); - dst[0][0] = tmp.lo.lo.lo; - dst[0][1] = tmp.lo.lo.hi; - dst[0][2] = tmp.lo.hi.lo; - dst[0][3] = tmp.lo.hi.hi; - dst[1][0] = tmp.hi.lo.lo; - dst[1][1] = tmp.hi.lo.hi; - dst[1][2] = tmp.hi.hi.lo; - dst[1][3] = tmp.hi.hi.hi; -} - -uint8 
intel_sub_group_block_read_32b_8r16c(const __global void* base_address, int width, int height, int pitch, int2 coord) -{ - return __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); -} -uint16 intel_sub_group_block_read_32b_16r16c(const __global void* base_address, int width, int height, int pitch, int2 coord) -{ - return __builtin_IB_subgroup_block_read_flat_u32_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); -} - -// Each block is K rows x N columns, where the K rows are returned packed into 32-bits. -int8 intel_sub_group_block_read_transform_16b_16r16c(__global void *base_address, int width, int height, int pitch, int2 coord) -{ - return as_int8(__builtin_IB_subgroup_block_read_flat_transform_u16_k16(as_long(base_address), width - 1, height - 1, pitch - 1, coord)); -} -int16 intel_sub_group_block_read_transform_16b_32r16c(__global void *base_address, int width, int height, int pitch, int2 coord) -{ - return as_int16(__builtin_IB_subgroup_block_read_flat_transform_u16_k32(as_long(base_address), width - 1, height - 1, pitch - 1, coord)); -} -int16 intel_sub_group_block_read_transform_16b_16r16x2c(__global void *base_address, int width, int height, int pitch, int2 coord) -{ - return as_int16(__builtin_IB_subgroup_block_read_flat_transform_u16_k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord)); -} -void intel_sub_group_block_read_transform_16b_32r16x2c(__global void *base_address, int width, int height, int pitch, int2 coord, int8 dst[2][2]) -{ - uint32 tmp = __builtin_IB_subgroup_block_read_flat_transform_u16_k32v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord); - dst[0][0] = as_int8(tmp.lo.lo); - dst[0][1] = as_int8(tmp.lo.hi); - dst[1][0] = as_int8(tmp.hi.lo); - dst[1][1] = as_int8(tmp.hi.hi); -} - - -#if !defined(BLOCK_PREFETCH_CACHE_TYPE) -#define BLOCK_PREFETCH_CACHE_TYPE LSC_LDCC_L1C_L3C -#endif - -void 
intel_sub_group_block_prefetch_16b_1r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ -#if defined(PREFETCH_DEFAULT) - __builtin_IB_subgroup_block_read_prefetch_u16_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); -#endif // defined(PREFETCH_DEFAULT) -} -void intel_sub_group_block_prefetch_16b_2r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ -#if defined(PREFETCH_DEFAULT) - __builtin_IB_subgroup_block_read_prefetch_u16_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); -#endif // defined(PREFETCH_DEFAULT) -} -void intel_sub_group_block_prefetch_16b_4r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ -#if defined(PREFETCH_DEFAULT) - __builtin_IB_subgroup_block_read_prefetch_u16_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); -#endif // defined(PREFETCH_DEFAULT) -} -void intel_sub_group_block_prefetch_16b_8r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ -#if defined(PREFETCH_DEFAULT) - __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); -#endif // defined(PREFETCH_DEFAULT) -} -void intel_sub_group_block_prefetch_16b_8r16x2c(__global void *base_address, int width, int height, int pitch, int2 coord) -{ -#if defined(PREFETCH_DEFAULT) - __builtin_IB_subgroup_block_read_prefetch_u16_m8k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); -#endif // defined(PREFETCH_DEFAULT) -} -void intel_sub_group_block_prefetch_16b_16r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ -#if defined(PREFETCH_DEFAULT) - __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v1(as_long(base_address), width - 1, 
height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); -#endif // defined(PREFETCH_DEFAULT) -} -void intel_sub_group_block_prefetch_16b_32r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ -#if defined(PREFETCH_DEFAULT) - __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); -#endif // defined(PREFETCH_DEFAULT) -} -void intel_sub_group_block_prefetch_16b_16r16x2c(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ -#if defined(PREFETCH_DEFAULT) - __builtin_IB_subgroup_block_read_prefetch_u16_m16k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); -#endif // defined(PREFETCH_DEFAULT) -} -void intel_sub_group_block_prefetch_16b_32r16x2c(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ -#if defined(PREFETCH_DEFAULT) - __builtin_IB_subgroup_block_read_prefetch_u16_m32k16v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); -#endif // defined(PREFETCH_DEFAULT) -} -void intel_sub_group_block_prefetch_32b_8r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ -#if defined(PREFETCH_DEFAULT) - __builtin_IB_subgroup_block_read_prefetch_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); -#endif // defined(PREFETCH_DEFAULT) -} -void intel_sub_group_block_prefetch_32b_16r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ -#if defined(PREFETCH_DEFAULT) - __builtin_IB_subgroup_block_read_prefetch_u32_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); -#endif // defined(PREFETCH_DEFAULT) -} -void intel_sub_group_block_prefetch_16b_8r32c(__global void *base_address, int width, int height, int pitch, int2 coord) -{ -#if 
defined(PREFETCH_DEFAULT) - __builtin_IB_subgroup_block_read_prefetch_u16_m8k32v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); -#endif // defined(PREFETCH_DEFAULT) -} -void intel_sub_group_block_prefetch_16b_16r32c(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ -#if defined(PREFETCH_DEFAULT) - __builtin_IB_subgroup_block_read_prefetch_u16_m16k32v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); -#endif // defined(PREFETCH_DEFAULT) -} -void intel_sub_group_block_prefetch_16b_32r32c(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ -#if defined(PREFETCH_DEFAULT) - __builtin_IB_subgroup_block_read_prefetch_u16_m32k32v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, BLOCK_PREFETCH_CACHE_TYPE); -#endif // defined(PREFETCH_DEFAULT) -} - - -void intel_sub_group_block_write_32b_1r16c(__global void* base_address, int width, int height, int pitch, int2 coord, uint data) -{ - __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); -} -void intel_sub_group_block_write_32b_2r16c(__global void* base_address, int width, int height, int pitch, int2 coord, uint2 data) -{ - __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); -} -void intel_sub_group_block_write_32b_4r16c(__global void* base_address, int width, int height, int pitch, int2 coord, uint4 data) -{ - __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); -} -void intel_sub_group_block_write_32b_8r16c(__global void* base_address, int width, int height, int pitch, int2 coord, uint8 data) -{ - __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); -} -void 
intel_sub_group_block_write_32b_16r16c(__global void* base_address, int width, int height, int pitch, int2 coord, uint16 data) -{ - __builtin_IB_subgroup_block_write_flat_u32_m16k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); -} - -#endif // cl_intel_subgroup_extended_block_read diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/99_matrixexperiments/matrix_kernel_tiled.cl index 9f14b94f..9007c3b9 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/99_matrixexperiments/matrix_kernel_tiled.cl @@ -405,7 +405,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float } } -#ifdef cl_intel_subgroup_extended_block_read +#ifdef cl_intel_subgroup_2d_block_io void HELPER_NAME(atile_block_load_rowmajor, MM, NN)(global ushort* A, int tM, int M, int K, int m, int k, short8 aData[KK][MM]) { @@ -415,11 +415,11 @@ void HELPER_NAME(atile_block_load_rowmajor, MM, NN)(global ushort* A, int tM, in //if (get_sub_group_local_id() == 0) { // printf("atile block load : %d, %d, %2d: m = %3d, k = %3d, mm = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), m, k, mm, kk, k + kk * tK, m + mm * tM); //} - ushort8 tmp[2][4]; - intel_sub_group_block_read_16b_32r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); + short8 aTemp[2][4]; + intel_sub_group_2d_block_read_16b_32r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), (ushort*)aTemp); for (int tkk = 0; tkk < 2; tkk++) { for (int tmm = 0; tmm < 4; tmm++) { - aData[kk + tkk][mm + tmm] = as_short8(tmp[tkk][tmm]); + aData[kk + tkk][mm + tmm] = aTemp[tkk][tmm]; } } } @@ -427,11 +427,11 @@ void HELPER_NAME(atile_block_load_rowmajor, MM, NN)(global ushort* A, int tM, in } else if (KK % 2 == 0 & MM % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm+=2) { - ushort8 tmp[2][2]; - 
intel_sub_group_block_read_16b_16r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); + short8 aTemp[2][2]; + intel_sub_group_2d_block_read_16b_16r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), (ushort*)aTemp); for (int tkk = 0; tkk < 2; tkk++) { for (int tmm = 0; tmm < 2; tmm++) { - aData[kk + tkk][mm + tmm] = as_short8(tmp[tkk][tmm]); + aData[kk + tkk][mm + tmm] = aTemp[tkk][tmm]; } } } @@ -439,25 +439,28 @@ void HELPER_NAME(atile_block_load_rowmajor, MM, NN)(global ushort* A, int tM, in } else if (KK % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm++) { - short16 aTemp = as_short16(intel_sub_group_block_read_16b_8r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); - aData[kk + 0][mm] = aTemp.lo; - aData[kk + 1][mm] = aTemp.hi; + short8 aTemp[2]; + intel_sub_group_2d_block_read_16b_8r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), (ushort*)aTemp); + aData[kk + 0][mm] = aTemp[0]; + aData[kk + 1][mm] = aTemp[1]; } } } else if (MM % 4 == 0) { for (int kk = 0; kk < KK; kk++) { for (int mm = 0; mm < MM; mm+=4) { - ushort8 tmp[4]; - intel_sub_group_block_read_16b_32r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); + short8 aTemp[4]; + intel_sub_group_2d_block_read_16b_32r16x1c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), (ushort*)aTemp); for (int tmm = 0; tmm < 4; tmm++) { - aData[kk][mm + tmm] = as_short8(tmp[tmm]); + aData[kk][mm + tmm] = aTemp[tmm]; } } } } else { for (int kk = 0; kk < KK; kk++) { for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = as_short8(intel_sub_group_block_read_16b_8r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); + short8 aTemp[1]; + intel_sub_group_2d_block_read_16b_8r16x1c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * 
tM), (ushort*)aTemp); + aData[kk][mm] = aTemp[0]; } } } @@ -471,11 +474,11 @@ void HELPER_NAME(btile_block_load_rowmajor, MM, NN)(global ushort* B, int tN, in //if (get_sub_group_local_id() == 0) { // printf("btile block load: %d, %d, %2d: n = %3d, k = %3d, nn = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), n, k, nn, kk, n + nn * tN, k + kk * tK); //} - int8 tmp[2][2]; - intel_sub_group_block_read_transform_16b_32r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK), tmp); + int8 bTemp[2][2]; + intel_sub_group_2d_block_read_transform_16b_32r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK), (uint*)bTemp); for (int tnn = 0; tnn < 2; tnn++) { for (int tkk = 0; tkk < 2; tkk++) { - bData[nn + tnn][kk + tkk] = tmp[tnn][tkk]; + bData[nn + tnn][kk + tkk] = bTemp[tnn][tkk]; } } } @@ -483,23 +486,27 @@ void HELPER_NAME(btile_block_load_rowmajor, MM, NN)(global ushort* B, int tN, in } else if (NN % 2 == 0) { for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn+=2) { - int16 bTemp = intel_sub_group_block_read_transform_16b_16r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); - bData[nn + 0][kk] = bTemp.lo; - bData[nn + 1][kk] = bTemp.hi; + int8 bTemp[2]; + intel_sub_group_2d_block_read_transform_16b_16r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK), (uint*)bTemp); + bData[nn + 0][kk] = bTemp[0]; + bData[nn + 1][kk] = bTemp[1]; } } } else if (KK % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn++) { - int16 bTemp = intel_sub_group_block_read_transform_16b_32r16c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); - bData[nn][kk + 0] = bTemp.lo; - bData[nn][kk + 1] = bTemp.hi; + int8 bTemp[2]; + intel_sub_group_2d_block_read_transform_16b_32r16x1c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * 
tK), (uint*)bTemp); + bData[nn][kk + 0] = bTemp[0]; + bData[nn][kk + 1] = bTemp[1]; } } } else { for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { - bData[nn][kk] = intel_sub_group_block_read_transform_16b_16r16c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + int8 bTemp[1]; + intel_sub_group_2d_block_read_transform_16b_16r16x1c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK), (uint*)bTemp); + bData[nn][kk] = bTemp[0]; } } } @@ -510,15 +517,18 @@ void HELPER_NAME(btile_block_load_packed, MM, NN)(global ushort* B, int tN, int if (KK % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn++) { - int16 bTemp = as_int16(intel_sub_group_block_read_32b_16r16c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2))); - bData[nn][kk + 0] = bTemp.lo; - bData[nn][kk + 1] = bTemp.hi; + int8 bTemp[2]; + intel_sub_group_2d_block_read_32b_16r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2), (uint*)bTemp); + bData[nn][kk + 0] = bTemp[0]; + bData[nn][kk + 1] = bTemp[1]; } } } else { for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { - bData[nn][kk] = as_int8(intel_sub_group_block_read_32b_8r16c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2))); + int8 bTemp[1]; + intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2), (uint*)bTemp); + bData[nn][kk] = bTemp[0]; } } } @@ -533,39 +543,35 @@ void HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(global ushort* A, int tM //if (get_sub_group_local_id() == 0) { // printf("atile block prefetch: %d, %d, %2d: sg_x = %d, m = %3d, k = %3d, mm = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), sg_index_x, m, k, mm, kk, k + kk * tK, m + mm * tM); //} -#ifdef USE_32C - intel_sub_group_block_prefetch_16b_8r32c(A, K 
* sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); -#else - intel_sub_group_block_prefetch_16b_8r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); -#endif + intel_sub_group_2d_block_prefetch_16b_8r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); } else if (KK % 2 == 0 & MM % 4 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm+=4) { - intel_sub_group_block_prefetch_16b_32r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + intel_sub_group_2d_block_prefetch_16b_32r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); } } } else if (KK % 2 == 0 & MM % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm+=2) { - intel_sub_group_block_prefetch_16b_16r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + intel_sub_group_2d_block_prefetch_16b_16r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); } } } else if (KK % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int mm = 0; mm < MM; mm++) { - intel_sub_group_block_prefetch_16b_8r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + intel_sub_group_2d_block_prefetch_16b_8r16x2c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); } } } else if (MM % 4 == 0) { for (int kk = 0; kk < KK; kk++) { for (int mm = 0; mm < MM; mm+=4) { - intel_sub_group_block_prefetch_16b_32r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + intel_sub_group_2d_block_prefetch_16b_32r16x1c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); } } } else { for (int kk = 0; kk < KK; kk++) { for (int mm = 0; mm < MM; mm++) { - intel_sub_group_block_prefetch_16b_8r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); + 
intel_sub_group_2d_block_prefetch_16b_8r16x1c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); } } } @@ -580,33 +586,29 @@ void HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(global ushort* B, int tN //if (get_sub_group_local_id() == 0) { // printf("btile block prefetch: %d, %d, %2d: sg_y = %d, n = %3d, k = %3d, nn = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), sg_index_y, n, k, nn, kk, n + nn * tN, k + kk * tK); //} -#ifdef USE_32C - intel_sub_group_block_prefetch_16b_16r32c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); -#else - intel_sub_group_block_prefetch_16b_16r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); -#endif + intel_sub_group_2d_block_prefetch_16b_16r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); } else if (KK % 2 == 0 & NN % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn += 2) { - intel_sub_group_block_prefetch_16b_32r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + intel_sub_group_2d_block_prefetch_16b_32r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); } } } else if (NN % 2 == 0) { for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn+=2) { - intel_sub_group_block_prefetch_16b_16r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + intel_sub_group_2d_block_prefetch_16b_16r16x2c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); } } } else if (KK % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn++) { - intel_sub_group_block_prefetch_16b_32r16c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + intel_sub_group_2d_block_prefetch_16b_32r16x1c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * 
tK)); } } } else { for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { - intel_sub_group_block_prefetch_16b_16r16c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); + intel_sub_group_2d_block_prefetch_16b_16r16x1c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); } } } @@ -618,17 +620,17 @@ void HELPER_NAME(btile_block_prefetch_packed, MM, NN)(global ushort* B, int tN, const int sg_index_y = get_sub_group_id() / SGS_PER_WG_X; // index in [0, SGS_PER_WG_Y) const int nn = sg_index_y % 4; // nn(sg_index_y) == 0, 1, 2, 3, 0, 1, 2, 3 const int kk = 0; // kk(sg_index_y) == 0, 0, 0, 0, 0, 0, 0, 0 - intel_sub_group_block_prefetch_32b_16r16c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); + intel_sub_group_2d_block_prefetch_32b_16r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); } else if (KK % 2 == 0) { for (int kk = 0; kk < KK; kk+=2) { for (int nn = 0; nn < NN; nn++) { - intel_sub_group_block_prefetch_32b_16r16c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); + intel_sub_group_2d_block_prefetch_32b_16r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); } } } else { for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { - intel_sub_group_block_prefetch_32b_8r16c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); + intel_sub_group_2d_block_prefetch_32b_8r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); } } } @@ -689,7 +691,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { sum[nn][mm] = activation(sum[nn][mm]); - intel_sub_group_block_write_32b_8r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[nn][mm])); + 
intel_sub_group_2d_block_write_32b_8r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), (uint*)&sum[nn][mm]); } } } @@ -750,9 +752,9 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(gl for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { sum[nn][mm] = activation(sum[nn][mm]); - intel_sub_group_block_write_32b_8r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[nn][mm])); + intel_sub_group_2d_block_write_32b_8r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), (uint*)&sum[nn][mm]); } } } -#endif // cl_intel_subgroup_extended_block_read +#endif // cl_intel_subgroup_2d_block_io diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/99_matrixexperiments/matrix_kernels.cl index 724e80d1..47fa2704 100644 --- a/samples/99_matrixexperiments/matrix_kernels.cl +++ b/samples/99_matrixexperiments/matrix_kernels.cl @@ -380,7 +380,7 @@ kernel void bfloat16_dpas_vnni_m8_n16(global float* C, global ushort* A, global store_c_rowmajor_fp32_8rNc(C, sum, m, n, N); } -#ifdef cl_intel_subgroup_extended_block_read +#ifdef cl_intel_subgroup_2d_block_io __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void bfloat16_dpas_blockread_rowmajor_m1_n16(global float* C, global ushort* A, global ushort* B, int K) @@ -395,13 +395,15 @@ kernel void bfloat16_dpas_blockread_rowmajor_m1_n16(global float* C, global usho float sum = 0; for (int k = 0; k < K; k += tK) { - short aData = as_short(intel_sub_group_block_read_16b_1r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); - int8 bData = as_int8(intel_sub_group_block_read_transform_16b_16r16c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); + short aData; + intel_sub_group_2d_block_read_16b_1r16x1c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m), (ushort*)&aData); + int8 bData; + 
intel_sub_group_2d_block_read_transform_16b_16r16x1c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k), (uint*)&bData); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_sub_group_block_write_32b_1r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum)); + intel_sub_group_2d_block_write_32b_1r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -417,13 +419,15 @@ kernel void bfloat16_dpas_blockread_rowmajor_m2_n16(global float* C, global usho float2 sum = 0; for (int k = 0; k < K; k += tK) { - short2 aData = as_short2(intel_sub_group_block_read_16b_2r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); - int8 bData = as_int8(intel_sub_group_block_read_transform_16b_16r16c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); + short2 aData; + intel_sub_group_2d_block_read_16b_2r16x1c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m), (ushort*)&aData); + int8 bData; + intel_sub_group_2d_block_read_transform_16b_16r16x1c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k), (uint*)&bData); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_sub_group_block_write_32b_2r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum)); + intel_sub_group_2d_block_write_32b_2r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -439,13 +443,15 @@ kernel void bfloat16_dpas_blockread_rowmajor_m4_n16(global float* C, global usho float4 sum = 0; for (int k = 0; k < K; k += tK) { - short4 aData = as_short4(intel_sub_group_block_read_16b_4r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); - int8 bData = as_int8(intel_sub_group_block_read_transform_16b_16r16c(B, N * sizeof(ushort), K, N * 
sizeof(ushort), (int2)(n, k))); + short4 aData; + intel_sub_group_2d_block_read_16b_4r16x1c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m), (ushort*)&aData); + int8 bData; + intel_sub_group_2d_block_read_transform_16b_16r16x1c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k), (uint*)&bData); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_sub_group_block_write_32b_4r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum)); + intel_sub_group_2d_block_write_32b_4r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -461,13 +467,15 @@ kernel void bfloat16_dpas_blockread_rowmajor_m8_n16(global float* C, global usho float8 sum = 0; for (int k = 0; k < K; k += tK) { - short8 aData = as_short8(intel_sub_group_block_read_16b_8r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); - int8 bData = as_int8(intel_sub_group_block_read_transform_16b_16r16c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k))); + short8 aData; + intel_sub_group_2d_block_read_16b_8r16x1c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m), (ushort*)&aData);; + int8 bData; + intel_sub_group_2d_block_read_transform_16b_16r16x1c(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n, k), (uint*)&bData); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_sub_group_block_write_32b_8r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); + intel_sub_group_2d_block_write_32b_8r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -483,13 +491,15 @@ kernel void bfloat16_dpas_blockread_vnni_m1_n16(global float* C, global ushort* float sum = 0; for (int k = 0; k < K; k += tK) { - short aData = 
as_short(intel_sub_group_block_read_16b_1r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); - int8 bData = as_int8(intel_sub_group_block_read_32b_8r16c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); + short aData; + intel_sub_group_2d_block_read_16b_1r16x1c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m), (ushort*)&aData); + int8 bData; + intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2), (uint*)&bData); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_sub_group_block_write_32b_1r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum)); + intel_sub_group_2d_block_write_32b_1r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -505,13 +515,15 @@ kernel void bfloat16_dpas_blockread_vnni_m2_n16(global float* C, global ushort* float2 sum = 0; for (int k = 0; k < K; k += tK) { - short2 aData = as_short2(intel_sub_group_block_read_16b_2r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); - int8 bData = as_int8(intel_sub_group_block_read_32b_8r16c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); + short2 aData; + intel_sub_group_2d_block_read_16b_2r16x1c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m), (ushort*)&aData); + int8 bData; + intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2), (uint*)&bData); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_sub_group_block_write_32b_2r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum)); + intel_sub_group_2d_block_write_32b_2r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -527,13 +539,15 @@ kernel void 
bfloat16_dpas_blockread_vnni_m4_n16(global float* C, global ushort* float4 sum = 0; for (int k = 0; k < K; k += tK) { - short4 aData = as_short4(intel_sub_group_block_read_16b_4r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); - int8 bData = as_int8(intel_sub_group_block_read_32b_8r16c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); + short4 aData; + intel_sub_group_2d_block_read_16b_4r16x1c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m), (ushort*)&aData); + int8 bData; + intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2), (uint*)&bData); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_sub_group_block_write_32b_4r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum)); + intel_sub_group_2d_block_write_32b_4r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -549,16 +563,18 @@ kernel void bfloat16_dpas_blockread_vnni_m8_n16(global float* C, global ushort* float8 sum = 0; for (int k = 0; k < K; k += tK) { - short8 aData = as_short8(intel_sub_group_block_read_16b_8r16c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m))); - int8 bData = as_int8(intel_sub_group_block_read_32b_8r16c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2))); + short8 aData; + intel_sub_group_2d_block_read_16b_8r16x1c(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k, m), (ushort*)&aData); + int8 bData; + intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n, k / 2), (uint*)&bData); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_sub_group_block_write_32b_8r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); + intel_sub_group_2d_block_write_32b_8r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); } 
-#endif // cl_intel_subgroup_extended_block_read +#endif // cl_intel_subgroup_2d_block_io // Tiled matrix multiplication kernels, generated from a template: diff --git a/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl b/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl index 53ca38fc..cccefba7 100644 --- a/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl +++ b/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl @@ -285,7 +285,7 @@ void store_c_rowmajor_fp32_8rNc(global float* C, float8 v, int rowStart, int col #endif // defined(cl_intel_subgroups) -#ifdef cl_intel_subgroup_extended_block_read +#ifdef cl_intel_subgroup_2d_block_io // Note for 2D block reads: // - the tile width and height is encoded into the function name. @@ -405,4 +405,4 @@ void intel_sub_group_block_write_32b_8r16c(__global void* base_address, int widt __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); } -#endif // cl_intel_subgroup_extended_block_read +#endif // cl_intel_subgroup_2d_block_io diff --git a/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl b/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl index 118d382b..b51a9490 100644 --- a/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl +++ b/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl @@ -141,7 +141,7 @@ kernel void MM_KERNEL_NAME(tf32_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float } } -#ifdef cl_intel_subgroup_extended_block_read +#ifdef cl_intel_subgroup_2d_block_io void HELPER_NAME(atile_block_load_rowmajor, MM, NN)(global float* A, int tM, int M, int K, int m, int k, float4 aData[KK][MM]) { @@ -222,4 +222,4 @@ kernel void MM_KERNEL_NAME(tf32_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(gl } } -#endif // cl_intel_subgroup_extended_block_read +#endif // cl_intel_subgroup_2d_block_io diff --git a/samples/99_matrixexperimentstf32/matrix_kernels_tf32.cl 
b/samples/99_matrixexperimentstf32/matrix_kernels_tf32.cl index aa7ce065..dbe60a56 100644 --- a/samples/99_matrixexperimentstf32/matrix_kernels_tf32.cl +++ b/samples/99_matrixexperimentstf32/matrix_kernels_tf32.cl @@ -112,7 +112,7 @@ kernel void tf32_dpas_rowmajor_m8_n16(global float* C, global float* A, global f store_c_rowmajor_fp32_8rNc(C, sum, m, n, N); } -#ifdef cl_intel_subgroup_extended_block_read +#ifdef cl_intel_subgroup_2d_block_io __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void tf32_dpas_blockread_rowmajor_m1_n16(global float* C, global float* A, global float* B, int K) @@ -202,7 +202,7 @@ kernel void tf32_dpas_blockread_rowmajor_m8_n16(global float* C, global float* A intel_sub_group_block_write_32b_8r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); } -#endif // cl_intel_subgroup_extended_block_read +#endif // cl_intel_subgroup_2d_block_io // Tiled matrix multiplication kernels, generated from a template: From badf4c2e519930008e944e14d08351d1af670b10 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Sat, 3 May 2025 21:44:34 -0700 Subject: [PATCH 89/99] switch more block reads to the production versions --- .../matrix_helpers_tf32.cl | 122 ------------------ .../matrix_kernel_tiled_tf32.cl | 6 +- .../matrix_kernels_tf32.cl | 32 +++-- 3 files changed, 23 insertions(+), 137 deletions(-) diff --git a/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl b/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl index cccefba7..25a8bbe9 100644 --- a/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl +++ b/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl @@ -284,125 +284,3 @@ void store_c_rowmajor_fp32_8rNc(global float* C, float8 v, int rowStart, int col } #endif // defined(cl_intel_subgroups) - -#ifdef cl_intel_subgroup_2d_block_io - -// Note for 2D block reads: -// - the tile width and height is encoded into the function name. 
-// - base_address is the byte address. Must be 64B aligned. -// - width is the width of the entire matrix, in bytes. Must be >= 64B. Must be 4B aligned. -// - height is the height of the entire matrix, or equivalently the number of rows. -// - pitch is the number of bytes between rows of the entire matrix. Must be >= 64B. Must be a multiple of 8 bytes. -// - coord is the number of elements (x coord) and row (y coord) to read from. X coord must be multiple 4 for for 1B data and 2 for 2B data. - -// For intrinsics, the pattern is: -// - prefix: __builtin_IB_subgroup_block_read_flat or __builtin_IB_subgroup_block_write_flat -// - operation (optional): _transpose or _transform -// - for no transpose or transform: -// - type / elements size: _u8 or _u16 or _u32 or _u64 -// - number of tile rows: _m32 or _m16 or _m8 or _m4 or _m2 or _m1 -// - tile width: _k64 or _k32 or _k16 or _k8 -// - number of tiles: _v2 or _v1 -// - for transpose: -// - type / element size: _u64 or _u32 -// - number of tile rows: subgroup size (16) -// - tile width: _k4 (for _u64) or _k8 (for _u32) -// - number of tiles: 1 -// - for transform: -// - type / element size: _u16 or _u8 -// - number of tile rows: _k32 (for _u8) or _k16 (for _u16) -// - tile width: subgroup size (16) -// - number of tiles: 1 - -enum LSC_LDCC { - LSC_LDCC_DEFAULT = 0, - LSC_LDCC_L1UC_L3UC = 1, // Override to L1 uncached and L3 uncached - LSC_LDCC_L1UC_L3C = 2, // Override to L1 uncached and L3 cached - LSC_LDCC_L1C_L3UC = 3, // Override to L1 cached and L3 uncached - LSC_LDCC_L1C_L3C = 4, // Override to L1 cached and L3 cached - LSC_LDCC_L1S_L3UC = 5, // Override to L1 streaming load and L3 uncached - LSC_LDCC_L1S_L3C = 6, // Override to L1 streaming load and L3 cached - LSC_LDCC_L1IAR_L3C = 7, // Override to L1 invalidate-after-read, and L3 cached -}; - -// Define block reads, prefetches, and writes. 
These are supported by the hardware but are not in the headers: - -uint __builtin_IB_subgroup_block_read_flat_u32_m1k8v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -uint2 __builtin_IB_subgroup_block_read_flat_u32_m2k8v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -uint4 __builtin_IB_subgroup_block_read_flat_u32_m4k8v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k8v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); - -uint __builtin_IB_subgroup_block_read_flat_u32_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -uint2 __builtin_IB_subgroup_block_read_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -uint4 __builtin_IB_subgroup_block_read_flat_u32_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); - -uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k8v2(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); - -void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint data); -void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint2 data); -void __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint4 data); -void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(long baseoffset, int 
width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint8 data); - -uint intel_sub_group_block_read_32b_1r8c(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ - return __builtin_IB_subgroup_block_read_flat_u32_m1k8v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); -} -uint intel_sub_group_block_read_32b_2r8c(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ - return __builtin_IB_subgroup_block_read_flat_u32_m2k8v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord).lo; -} -uint2 intel_sub_group_block_read_32b_4r8c(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ - return __builtin_IB_subgroup_block_read_flat_u32_m4k8v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord).lo; -} -uint4 intel_sub_group_block_read_32b_8r8c(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ - return __builtin_IB_subgroup_block_read_flat_u32_m8k8v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord).lo; -} - -uint intel_sub_group_block_read_32b_1r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ - return __builtin_IB_subgroup_block_read_flat_u32_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); -} -uint2 intel_sub_group_block_read_32b_2r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ - return __builtin_IB_subgroup_block_read_flat_u32_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); -} -uint4 intel_sub_group_block_read_32b_4r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ - return __builtin_IB_subgroup_block_read_flat_u32_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); -} -uint8 intel_sub_group_block_read_32b_8r16c(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ 
- return __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); -} - -uint8 intel_sub_group_block_read_32b_8r8x2c(const __global void* base_address, int width, int height, int pitch, int2 coord) -{ - return __builtin_IB_subgroup_block_read_flat_u32_m8k8v2(as_long(base_address), width - 1, height - 1, pitch - 1, coord); -} - - -#if !defined(BLOCK_PREFETCH_CACHE_TYPE) -#define BLOCK_PREFETCH_CACHE_TYPE LSC_LDCC_L1C_L3C -#endif - -void intel_sub_group_block_write_32b_1r16c(__global void* base_address, int width, int height, int pitch, int2 coord, uint data) -{ - __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); -} -void intel_sub_group_block_write_32b_2r16c(__global void* base_address, int width, int height, int pitch, int2 coord, uint2 data) -{ - __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); -} -void intel_sub_group_block_write_32b_4r16c(__global void* base_address, int width, int height, int pitch, int2 coord, uint4 data) -{ - __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); -} -void intel_sub_group_block_write_32b_8r16c(__global void* base_address, int width, int height, int pitch, int2 coord, uint8 data) -{ - __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); -} - -#endif // cl_intel_subgroup_2d_block_io diff --git a/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl b/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl index b51a9490..ff94a4c1 100644 --- a/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl +++ b/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl @@ -147,7 +147,7 @@ void HELPER_NAME(atile_block_load_rowmajor, MM, NN)(global float* A, int tM, int { for (int kk = 0; 
kk < KK; kk++) { for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = as_float4(intel_sub_group_block_read_32b_8r8c(A, K * sizeof(float), M, K * sizeof(float), (int2)(k + kk * tK, m + mm * tM))); + intel_sub_group_2d_block_read_32b_8r8x1c(A, K * sizeof(float), M, K * sizeof(float), (int2)(k + kk * tK, m + mm * tM), (uint*)&aData[kk][mm]); } } } @@ -156,7 +156,7 @@ void HELPER_NAME(btile_block_load_rowmajor, MM, NN)(global float* B, int tN, int { for (int kk = 0; kk < KK; kk++) { for (int nn = 0; nn < NN; nn++) { - bData[nn][kk] = as_float8(intel_sub_group_block_read_32b_8r16c(B, N * sizeof(float), K, N * sizeof(float), (int2)(n + nn * tN, k + kk * tK))); + intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(float), K, N * sizeof(float), (int2)(n + nn * tN, k + kk * tK), (uint*)&bData[nn][kk]); } } } @@ -217,7 +217,7 @@ kernel void MM_KERNEL_NAME(tf32_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(gl for (int mm = 0; mm < MM; mm++) { for (int nn = 0; nn < NN; nn++) { sum[nn][mm] = activation(sum[nn][mm]); - intel_sub_group_block_write_32b_8r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[nn][mm])); + intel_sub_group_2d_block_write_32b_8r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), (uint*)&sum[nn][mm]); } } } diff --git a/samples/99_matrixexperimentstf32/matrix_kernels_tf32.cl b/samples/99_matrixexperimentstf32/matrix_kernels_tf32.cl index dbe60a56..953a9a6e 100644 --- a/samples/99_matrixexperimentstf32/matrix_kernels_tf32.cl +++ b/samples/99_matrixexperimentstf32/matrix_kernels_tf32.cl @@ -127,13 +127,15 @@ kernel void tf32_dpas_blockread_rowmajor_m1_n16(global float* C, global float* A float sum = 0; for (int k = 0; k < K; k += tK) { - float aData = as_float(intel_sub_group_block_read_32b_1r8c(A, K * sizeof(float), M, K * sizeof(float), (int2)(k, m))); - float8 bData = as_float8(intel_sub_group_block_read_32b_8r16c(B, N * sizeof(float), K, N * sizeof(float), (int2)(n, k))); + 
float aData; + intel_sub_group_2d_block_read_32b_1r8x1c(A, K * sizeof(float), M, K * sizeof(float), (int2)(k, m), (uint*)&aData); + float8 bData; + intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(float), K, N * sizeof(float), (int2)(n, k), (uint*)&bData); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_sub_group_block_write_32b_1r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint(sum)); + intel_sub_group_2d_block_write_32b_1r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -149,13 +151,15 @@ kernel void tf32_dpas_blockread_rowmajor_m2_n16(global float* C, global float* A float2 sum = 0; for (int k = 0; k < K; k += tK) { - float aData = as_float(intel_sub_group_block_read_32b_2r8c(A, K * sizeof(float), M, K * sizeof(float), (int2)(k, m))); - float8 bData = as_float8(intel_sub_group_block_read_32b_8r16c(B, N * sizeof(float), K, N * sizeof(float), (int2)(n, k))); + float aData; + intel_sub_group_2d_block_read_32b_2r8x1c(A, K * sizeof(float), M, K * sizeof(float), (int2)(k, m), (uint*)&aData); + float8 bData; + intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(float), K, N * sizeof(float), (int2)(n, k), (uint*)&bData); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_sub_group_block_write_32b_2r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint2(sum)); + intel_sub_group_2d_block_write_32b_2r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -171,13 +175,15 @@ kernel void tf32_dpas_blockread_rowmajor_m4_n16(global float* C, global float* A float4 sum = 0; for (int k = 0; k < K; k += tK) { - float2 aData = as_float2(intel_sub_group_block_read_32b_4r8c(A, K * sizeof(float), M, K * sizeof(float), (int2)(k, m))); - float8 
bData = as_float8(intel_sub_group_block_read_32b_8r16c(B, N * sizeof(float), K, N * sizeof(float), (int2)(n, k))); + float2 aData; + intel_sub_group_2d_block_read_32b_4r8x1c(A, K * sizeof(float), M, K * sizeof(float), (int2)(k, m), (uint*)&aData); + float8 bData; + intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(float), K, N * sizeof(float), (int2)(n, k), (uint*)&bData); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_sub_group_block_write_32b_4r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint4(sum)); + intel_sub_group_2d_block_write_32b_4r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); } __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) @@ -193,13 +199,15 @@ kernel void tf32_dpas_blockread_rowmajor_m8_n16(global float* C, global float* A float8 sum = 0; for (int k = 0; k < K; k += tK) { - float4 aData = as_float4(intel_sub_group_block_read_32b_8r8c(A, K * sizeof(float), M, K * sizeof(float), (int2)(k, m))); - float8 bData = as_float8(intel_sub_group_block_read_32b_8r16c(B, N * sizeof(float), K, N * sizeof(float), (int2)(n, k))); + float4 aData; + intel_sub_group_2d_block_read_32b_8r8x1c(A, K * sizeof(float), M, K * sizeof(float), (int2)(k, m), (uint*)&aData); + float8 bData; + intel_sub_group_2d_block_read_32b_8r16x1c(B, N * sizeof(float), K, N * sizeof(float), (int2)(n, k), (uint*)&bData); sum = mat_mul_sg16(aData, bData, sum); } sum = activation(sum); - intel_sub_group_block_write_32b_8r16c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), as_uint8(sum)); + intel_sub_group_2d_block_write_32b_8r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); } #endif // cl_intel_subgroup_2d_block_io From 9988e6d43c8d07832b0b12ca43099b5a2fed6e1c Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Sat, 3 May 2025 21:57:57 -0700 Subject: [PATCH 90/99] integrate i8 matrix multiplication --- 
samples/99_matrixexperimentsi8/CMakeLists.txt | 2 +- samples/99_matrixexperimentsi8/main.cpp | 427 ---------- .../matrix_helpers_i8.cl | 130 --- .../matrix_kernel_tiled_i8.cl | 750 ------------------ .../matrix_kernels_i8.cl | 108 --- 5 files changed, 1 insertion(+), 1416 deletions(-) delete mode 100644 samples/99_matrixexperimentsi8/matrix_kernel_tiled_i8.cl diff --git a/samples/99_matrixexperimentsi8/CMakeLists.txt b/samples/99_matrixexperimentsi8/CMakeLists.txt index b97f9c74..c9493112 100644 --- a/samples/99_matrixexperimentsi8/CMakeLists.txt +++ b/samples/99_matrixexperimentsi8/CMakeLists.txt @@ -8,4 +8,4 @@ add_opencl_sample( TARGET matrixexperimentsi8 VERSION 200 # for clSetKernelExecInfo SOURCES main.cpp - KERNELS matrix_helpers_i8.cl matrix_kernels_i8.cl matrix_kernel_tiled_i8.cl) + KERNELS matrix_helpers_i8.cl matrix_kernels_i8.cl) diff --git a/samples/99_matrixexperimentsi8/main.cpp b/samples/99_matrixexperimentsi8/main.cpp index 04520bfc..5741c455 100644 --- a/samples/99_matrixexperimentsi8/main.cpp +++ b/samples/99_matrixexperimentsi8/main.cpp @@ -142,23 +142,6 @@ static void compute_reference( } } -template -static void compute_reference_TN( - std::vector& C, - const std::vector& A, const std::vector& B, - size_t M, size_t N, size_t K) -{ - for (size_t m = 0; m < M; m++) { - for (size_t n = 0; n < N; n++) { - DstT sum = 0; - for (size_t k = 0; k < K; k++) { - sum = A[k * K + m] * B[k * N + n] + sum; - } - C[m * N + n] = sum; - } - } -} - template void check_results( size_t M, @@ -283,62 +266,6 @@ static void i8_dpas_rowmajor( } } -template -static void i8_dpas_rowmajor_tiled( - cl::Context& context, cl::Program& program, cl::CommandQueue& queue, - cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, - size_t M, size_t N, size_t K, - const std::vector& C_ref) -{ - printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout); - - std::string kernelName = "i8_dpas_rowmajor_tiled"; - kernelName += "_m" + std::to_string(tM); - 
kernelName += "_n" + std::to_string(tN); - kernelName += "_" + std::to_string(MM); - kernelName += "x" + std::to_string(NN); - cl::Kernel kernel{program, kernelName.c_str()}; - if (kernel() == nullptr) { - printf("unsupported.\n"); - } else if (tM * MM > M) { - printf("M is too small.\n"); - } else if (tN * NN > N) { - printf("N is too small.\n"); - } else { - kernel.setArg(0, C); - kernel.setArg(1, A); - kernel.setArg(2, B); - kernel.setArg(3, static_cast(K)); - - if (!skipinit) { - queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); - } - - float best = 999.0f; - for (int test = 0; test < testIterations; test++) { - cl::Event event; - auto start = test_clock::now(); - queue.enqueueNDRangeKernel(kernel, cl::NullRange, - cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, nullptr, &event); - queue.finish(); - auto end = test_clock::now(); - std::chrono::duration sw_time = end - start; - auto elapsed = wallclock ? sw_time.count() : hw_time(event); - best = std::min(best, elapsed); - } - auto gops = 2.0 * M * N * K / best / 1e9; - printf("Best in %f seconds (%f gops)\n", best, gops); - - if (validate) { - printf("Checking results... 
"); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); - check_results(M, N, C_check, C_ref); - printf(" done!\n"); - } - } -} - template static void i8_dpas_vnni( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, @@ -389,62 +316,6 @@ static void i8_dpas_vnni( } } -template -static void i8_dpas_vnni_tiled( - cl::Context& context, cl::Program& program, cl::CommandQueue& queue, - cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, - size_t M, size_t N, size_t K, - const std::vector& C_ref) -{ - printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout); - - std::string kernelName = "i8_dpas_vnni_tiled"; - kernelName += "_m" + std::to_string(tM); - kernelName += "_n" + std::to_string(tN); - kernelName += "_" + std::to_string(MM); - kernelName += "x" + std::to_string(NN); - cl::Kernel kernel{program, kernelName.c_str()}; - if (kernel() == nullptr) { - printf("unsupported.\n"); - } else if (tM * MM > M) { - printf("M is too small.\n"); - } else if (tN * NN > N) { - printf("N is too small.\n"); - } else { - kernel.setArg(0, C); - kernel.setArg(1, A); - kernel.setArg(2, B); - kernel.setArg(3, static_cast(K)); - - if (!skipinit) { - queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); - } - - float best = 999.0f; - for (int test = 0; test < testIterations; test++) { - cl::Event event; - auto start = test_clock::now(); - queue.enqueueNDRangeKernel(kernel, cl::NullRange, - cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, nullptr, &event); - queue.finish(); - auto end = test_clock::now(); - std::chrono::duration sw_time = end - start; - auto elapsed = wallclock ? sw_time.count() : hw_time(event); - best = std::min(best, elapsed); - } - auto gops = 2.0 * M * N * K / best / 1e9; - printf("Best in %f seconds (%f gops)\n", best, gops); - - if (validate) { - printf("Checking results... 
"); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); - check_results(M, N, C_check, C_ref); - printf(" done!\n"); - } - } -} - template static void i8_dpas_blockread_rowmajor( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, @@ -500,67 +371,6 @@ static void i8_dpas_blockread_rowmajor( } } -template -static void i8_dpas_blockread_rowmajor_tiled( - cl::Context& context, cl::Program& program, cl::CommandQueue& queue, - cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, - size_t M, size_t N, size_t K, - const std::vector& C_ref) -{ - printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout); - - std::string kernelName = "i8_dpas_blockread_rowmajor_tiled"; - kernelName += "_m" + std::to_string(tM); - kernelName += "_n" + std::to_string(tN); - kernelName += "_" + std::to_string(MM); - kernelName += "x" + std::to_string(NN); - cl::Kernel kernel{program, kernelName.c_str()}; - if (kernel() == nullptr) { - printf("unsupported.\n"); - } else if (tM * MM > M) { - printf("M is too small.\n"); - } else if (tN * NN > N) { - printf("N is too small.\n"); - } else if (K < 64 || N < 64) { - printf("matrix pitch for block reads must be >= 64 bytes.\n"); - } else { - kernel.setArg(0, C); - kernel.setArg(1, A); - kernel.setArg(2, B); - kernel.setArg(3, static_cast(K)); - if (roundRobin) { - setRoundRobin(kernel); - } - - if (!skipinit) { - queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); - } - - float best = 999.0f; - for (int test = 0; test < testIterations; test++) { - cl::Event event; - auto start = test_clock::now(); - queue.enqueueNDRangeKernel(kernel, cl::NullRange, - cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, nullptr, &event); - queue.finish(); - auto end = test_clock::now(); - std::chrono::duration sw_time = end - start; - auto elapsed = wallclock ? 
sw_time.count() : hw_time(event); - best = std::min(best, elapsed); - } - auto gops = 2.0 * M * N * K / best / 1e9; - printf("Best in %f seconds (%f gops)\n", best, gops); - - if (validate) { - printf("Checking results... "); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); - check_results(M, N, C_check, C_ref); - printf(" done!\n"); - } - } -} - template static void i8_dpas_blockread_vnni( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, @@ -616,168 +426,6 @@ static void i8_dpas_blockread_vnni( } } -template -static void i8_dpas_blockread_vnni_tiled( - cl::Context& context, cl::Program& program, cl::CommandQueue& queue, - cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, - size_t M, size_t N, size_t K, - const std::vector& C_ref) -{ - printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, MM, NN, M, N, K).c_str()); fflush(stdout); - - std::string kernelName = "i8_dpas_blockread_vnni_tiled"; - kernelName += "_m" + std::to_string(tM); - kernelName += "_n" + std::to_string(tN); - kernelName += "_" + std::to_string(MM); - kernelName += "x" + std::to_string(NN); - cl::Kernel kernel{program, kernelName.c_str()}; - if (kernel() == nullptr) { - printf("unsupported.\n"); - } else if (tM * MM > M) { - printf("M is too small.\n"); - } else if (tN * NN > N) { - printf("N is too small.\n"); - } else if (K < 64 || N < 64/4) { - printf("matrix pitch for block reads must be >= 64 bytes.\n"); - } else { - kernel.setArg(0, C); - kernel.setArg(1, A); - kernel.setArg(2, B); - kernel.setArg(3, static_cast(K)); - if (roundRobin) { - setRoundRobin(kernel); - } - - if (!skipinit) { - queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); - } - - float best = 999.0f; - for (int test = 0; test < testIterations; test++) { - cl::Event event; - auto start = test_clock::now(); - queue.enqueueNDRangeKernel(kernel, cl::NullRange, - cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, 
nullptr, &event); - queue.finish(); - auto end = test_clock::now(); - std::chrono::duration sw_time = end - start; - auto elapsed = wallclock ? sw_time.count() : hw_time(event); - best = std::min(best, elapsed); - } - auto gops = 2.0 * M * N * K / best / 1e9; - printf("Best in %f seconds (%f gops)\n", best, gops); - - if (validate) { - printf("Checking results... "); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); - check_results(M, N, C_check, C_ref); - printf(" done!\n"); - } - } -} - -static void i8_naive_TN( - cl::Context& context, cl::Program& program, cl::CommandQueue& queue, - cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, - size_t M, size_t N, size_t K, - const std::vector& C_ref) -{ - printf("%80s: ", makeTestName(__FUNCTION__, M, N, K).c_str()); fflush(stdout); - - cl::Kernel kernel{program, "i8_naive_TN"}; - if (kernel() == nullptr) { - printf("unsupported.\n"); - } else { - kernel.setArg(0, C); - kernel.setArg(1, A); - kernel.setArg(2, B); - kernel.setArg(3, static_cast(K)); - - if (!skipinit) { - queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); - } - - float best = 999.0f; - for (int test = 0; test < testIterations; test++) { - cl::Event event; - auto start = test_clock::now(); - queue.enqueueNDRangeKernel(kernel, cl::NullRange, - cl::NDRange{N, M}, cl::NullRange, nullptr, &event); - queue.finish(); - auto end = test_clock::now(); - std::chrono::duration sw_time = end - start; - auto elapsed = wallclock ? sw_time.count() : hw_time(event); - best = std::min(best, elapsed); - } - auto gops = 2.0 * M * N * K / best / 1e9; - printf("Best in %f seconds (%f gops)\n", best, gops); - - if (validate) { - printf("Checking results... 
"); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); - check_results(M, N, C_check, C_ref); - printf(" done!\n"); - } - } -} - -template -static void i8_dpas_blockread_rowmajor_TN( - cl::Context& context, cl::Program& program, cl::CommandQueue& queue, - cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, - size_t M, size_t N, size_t K, - const std::vector& C_ref) -{ - printf("%80s: ", makeTestName(__FUNCTION__, tM, tN, M, N, K).c_str()); fflush(stdout); - - std::string kernelName = "i8_dpas_blockread_rowmajor_TN"; - kernelName += "_m" + std::to_string(tM); - kernelName += "_n" + std::to_string(tN); - cl::Kernel kernel{program, kernelName.c_str()}; - if (kernel() == nullptr) { - printf("unsupported.\n"); - } else if (K < 64 || N < 64/4) { - printf("matrix pitch for block reads must be >= 64 bytes.\n"); - } else { - kernel.setArg(0, C); - kernel.setArg(1, A); - kernel.setArg(2, B); - kernel.setArg(3, static_cast(K)); - if (roundRobin) { - setRoundRobin(kernel); - } - - if (!skipinit) { - queue.enqueueFillBuffer(C, 0, 0, C_ref.size() * sizeof(C_ref[0])); - } - - float best = 999.0f; - for (int test = 0; test < testIterations; test++) { - cl::Event event; - auto start = test_clock::now(); - queue.enqueueNDRangeKernel(kernel, cl::NullRange, - cl::NDRange{N, M/tM}, cl::NullRange, nullptr, &event); - queue.finish(); - auto end = test_clock::now(); - std::chrono::duration sw_time = end - start; - auto elapsed = wallclock ? sw_time.count() : hw_time(event); - best = std::min(best, elapsed); - } - auto gops = 2.0 * M * N * K / best / 1e9; - printf("Best in %f seconds (%f gops)\n", best, gops); - - if (validate) { - printf("Checking results... 
"); fflush(stdout); - std::vector C_check(C_ref.size()); - queue.enqueueReadBuffer(C, CL_TRUE, 0, C_check.size() * sizeof(C_check[0]), C_check.data()); - check_results(M, N, C_check, C_ref); - printf(" done!\n"); - } - } -} - int main(int argc, char** argv) { int platformIndex = 0; @@ -902,7 +550,6 @@ int main(int argc, char** argv) std::vector Bvnni_vec(K * N); std::vector C_ref(M * N); - std::vector C_TN_ref(M * N); printf("Initializing source matrices...\n"); fill_matrix(A_vec, M, K); @@ -913,8 +560,6 @@ int main(int argc, char** argv) if (validate) { printf("Computing reference...\n"); compute_reference(C_ref, A_vec, B_vec, M, N, K); - printf("Computing transposed reference...\n"); - compute_reference_TN(C_TN_ref, A_vec, B_vec, M, N, K); } printf("Creating source buffers...\n"); @@ -936,33 +581,6 @@ int main(int argc, char** argv) i8_dpas_rowmajor<8, 8>(context, program, queue, C, A, B, M, N, K, C_ref); } - if (mask & 0x4) { - i8_dpas_rowmajor_tiled<8, 8, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); - i8_dpas_rowmajor_tiled<8, 8, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); - i8_dpas_rowmajor_tiled<8, 8, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - i8_dpas_rowmajor_tiled<8, 8, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - i8_dpas_rowmajor_tiled<8, 8, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - i8_dpas_rowmajor_tiled<8, 8, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); - i8_dpas_rowmajor_tiled<8, 8, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); - } - - if (mask & 0x8) { - i8_dpas_vnni<1, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - i8_dpas_vnni<2, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - i8_dpas_vnni<4, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - i8_dpas_vnni<8, 8>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - } - - if (mask & 0x10) { - i8_dpas_vnni_tiled<8, 8, 1, 1>(context, program, queue, C, A, Bvnni, M, N, K, 
C_ref); - i8_dpas_vnni_tiled<8, 8, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - i8_dpas_vnni_tiled<8, 8, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - i8_dpas_vnni_tiled<8, 8, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - i8_dpas_vnni_tiled<8, 8, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - i8_dpas_vnni_tiled<8, 8, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - i8_dpas_vnni_tiled<8, 8, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - } - if (mask & 0x20) { i8_dpas_rowmajor<1, 16>(context, program, queue, C, A, B, M, N, K, C_ref); i8_dpas_rowmajor<2, 16>(context, program, queue, C, A, B, M, N, K, C_ref); @@ -970,16 +588,6 @@ int main(int argc, char** argv) i8_dpas_rowmajor<8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); } - if (mask & 0x40) { - i8_dpas_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); - i8_dpas_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); - i8_dpas_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - i8_dpas_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - i8_dpas_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - i8_dpas_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); - i8_dpas_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); - } - if (mask & 0x80) { i8_dpas_vnni<1, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); i8_dpas_vnni<2, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); @@ -987,16 +595,6 @@ int main(int argc, char** argv) i8_dpas_vnni<8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); } - if (mask & 0x100) { - i8_dpas_vnni_tiled<8, 16, 1, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - i8_dpas_vnni_tiled<8, 16, 2, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - 
i8_dpas_vnni_tiled<8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - i8_dpas_vnni_tiled<8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - i8_dpas_vnni_tiled<8, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - i8_dpas_vnni_tiled<8, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - i8_dpas_vnni_tiled<8, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - } - if (mask & 0x200) { i8_dpas_blockread_rowmajor<1, 16>(context, program, queue, C, A, B, M, N, K, C_ref); i8_dpas_blockread_rowmajor<2, 16>(context, program, queue, C, A, B, M, N, K, C_ref); @@ -1004,16 +602,6 @@ int main(int argc, char** argv) i8_dpas_blockread_rowmajor<8, 16>(context, program, queue, C, A, B, M, N, K, C_ref); } - if (mask & 0x400) { - i8_dpas_blockread_rowmajor_tiled<8, 16, 1, 1>(context, program, queue, C, A, B, M, N, K, C_ref); - i8_dpas_blockread_rowmajor_tiled<8, 16, 2, 1>(context, program, queue, C, A, B, M, N, K, C_ref); - i8_dpas_blockread_rowmajor_tiled<8, 16, 1, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - i8_dpas_blockread_rowmajor_tiled<8, 16, 2, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - i8_dpas_blockread_rowmajor_tiled<8, 16, 4, 2>(context, program, queue, C, A, B, M, N, K, C_ref); - i8_dpas_blockread_rowmajor_tiled<8, 16, 2, 4>(context, program, queue, C, A, B, M, N, K, C_ref); - i8_dpas_blockread_rowmajor_tiled<8, 16, 4, 4>(context, program, queue, C, A, B, M, N, K, C_ref); - } - if (mask & 0x800) { i8_dpas_blockread_vnni<1, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); i8_dpas_blockread_vnni<2, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); @@ -1021,21 +609,6 @@ int main(int argc, char** argv) i8_dpas_blockread_vnni<8, 16>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); } - if (mask & 0x1000) { - i8_dpas_blockread_vnni_tiled<8, 16, 1, 1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - i8_dpas_blockread_vnni_tiled<8, 16, 2, 
1>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - i8_dpas_blockread_vnni_tiled<8, 16, 1, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - i8_dpas_blockread_vnni_tiled<8, 16, 2, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - i8_dpas_blockread_vnni_tiled<8, 16, 4, 2>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - i8_dpas_blockread_vnni_tiled<8, 16, 2, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - i8_dpas_blockread_vnni_tiled<8, 16, 4, 4>(context, program, queue, C, A, Bvnni, M, N, K, C_ref); - } - - if (mask & 0x2000) { - //i8_naive_TN(context, program, queue, C, A, B, M, N, K, C_TN_ref); - i8_dpas_blockread_rowmajor_TN<4, 16>(context, program, queue, C, A, B, M, N, K, C_TN_ref); - } - printf("Done.\n"); return 0; diff --git a/samples/99_matrixexperimentsi8/matrix_helpers_i8.cl b/samples/99_matrixexperimentsi8/matrix_helpers_i8.cl index d7231d95..d380ca2e 100644 --- a/samples/99_matrixexperimentsi8/matrix_helpers_i8.cl +++ b/samples/99_matrixexperimentsi8/matrix_helpers_i8.cl @@ -629,133 +629,3 @@ void store_c_rowmajor_int32_m8_nx(global int* C, int8 v, int rowStart, int colSt } #endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) - -#ifdef cl_intel_subgroup_extended_block_read - -// Note for 2D block reads: -// - the tile width and height is encoded into the function name. -// - base_address is the byte address. Must be 64B aligned. -// - width is the width of the entire matrix, in bytes. Must be >= 64B. Must be 4B aligned. -// - height is the height of the entire matrix, or equivalently the number of rows. -// - pitch is the number of bytes between rows of the entire matrix. Must be >= 64B. Must be a multiple of 8 bytes. -// - coord is the number of elements (x coord) and row (y coord) to read from. X coord must be multiple 4 for for 1B data and 2 for 2B data. 
- -// Built-in functions are: - -// #ifdef cl_intel_subgroup_extended_block_read -// ushort2 intel_subgroup_block_read_u8_m1k32v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// ushort4 intel_subgroup_block_read_u8_m2k32v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// ushort8 intel_subgroup_block_read_u8_m4k32v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// ushort16 intel_subgroup_block_read_u8_m8k32v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// ushort2 intel_subgroup_block_read_u16_m1k16v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// ushort4 intel_subgroup_block_read_u16_m2k16v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// ushort8 intel_subgroup_block_read_u16_m4k16v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// ushort16 intel_subgroup_block_read_u16_m8k16v2(__global void *base_address, int width, int height, int pitch, int2 coord); -// uint8 intel_subgroup_block_read_transform_u8_k32(__global void *base_address, int width, int height, int pitch, int2 coord); -// uint8 intel_subgroup_block_read_transform_u16_k16(__global void *base_address, int width, int height, int pitch, int2 coord); -// uint8 intel_subgroup_block_read_transpose_u32_k8(__global void *base_address, int width, int height, int pitch, int2 coord); -// ulong4 intel_subgroup_block_read_transpose_u64_k4(__global void *base_address, int width, int height, int pitch, int2 coord); -// #endif //defined(cl_intel_subgroup_extended_block_read) - - -// For intrinsics, the pattern is: -// - prefix: __builtin_IB_subgroup_block_read_flat or __builtin_IB_subgroup_block_write_flat -// - operation (optional): _transpose or _transform -// - for no transpose or transform: -// - type / elements size: _u8 or _u16 or _u32 or _u64 -// - number of tile rows: _m32 or _m16 or _m8 or _m4 
or _m2 or _m1 -// - tile width: _k64 or _k32 or _k16 or _k8 -// - number of tiles: _v2 or _v1 -// - for transpose: -// - type / element size: _u64 or _u32 -// - number of tile rows: subgroup size (16) -// - tile width: _k4 (for _u64) or _k8 (for _u32) -// - number of tiles: 1 -// - for transform: -// - type / element size: _u16 or _u8 -// - number of tile rows: _k32 (for _u8) or _k16 (for _u16) -// - tile width: subgroup size (16) -// - number of tiles: 1 - -enum LSC_LDCC { - LSC_LDCC_DEFAULT = 0, - LSC_LDCC_L1UC_L3UC = 1, // Override to L1 uncached and L3 uncached - LSC_LDCC_L1UC_L3C = 2, // Override to L1 uncached and L3 cached - LSC_LDCC_L1C_L3UC = 3, // Override to L1 cached and L3 uncached - LSC_LDCC_L1C_L3C = 4, // Override to L1 cached and L3 cached - LSC_LDCC_L1S_L3UC = 5, // Override to L1 streaming load and L3 uncached - LSC_LDCC_L1S_L3C = 6, // Override to L1 streaming load and L3 cached - LSC_LDCC_L1IAR_L3C = 7, // Override to L1 invalidate-after-read, and L3 cached -}; - -// Define block reads, prefetches, and writes. 
These are supported by the hardware but are not in the headers: - -ushort __builtin_IB_subgroup_block_read_flat_u8_m1k32v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -ushort2 __builtin_IB_subgroup_block_read_flat_u8_m2k32v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -ushort4 __builtin_IB_subgroup_block_read_flat_u8_m4k32v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -ushort8 __builtin_IB_subgroup_block_read_flat_u8_m8k32v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); - -uint8 __builtin_IB_subgroup_block_read_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); - -ushort intel_subgroup_block_read_u8_m1k32(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ - return __builtin_IB_subgroup_block_read_flat_u8_m1k32v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); -} -ushort2 intel_subgroup_block_read_u8_m2k32(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ - return __builtin_IB_subgroup_block_read_flat_u8_m2k32v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); -} -ushort4 intel_subgroup_block_read_u8_m4k32(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ - return __builtin_IB_subgroup_block_read_flat_u8_m4k32v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); -} -ushort8 intel_subgroup_block_read_u8_m8k32(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ - return __builtin_IB_subgroup_block_read_flat_u8_m8k32v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); -} - -uint8 intel_subgroup_block_read_u32_m8k16(const __global void *base_address, int width, int height, int pitch, int2 coord) -{ - return 
__builtin_IB_subgroup_block_read_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); -} - - -void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint data); -void __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint2 data); -void __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint4 data); -void __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord, uint8 data); - -void intel_subgroup_block_write_u32_m1k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint data) -{ - __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); -} -void intel_subgroup_block_write_u32_m2k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint2 data) -{ - __builtin_IB_subgroup_block_write_flat_u32_m2k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); -} -void intel_subgroup_block_write_u32_m4k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint4 data) -{ - __builtin_IB_subgroup_block_write_flat_u32_m4k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); -} -void intel_subgroup_block_write_u32_m8k16(__global void* base_address, int width, int height, int pitch, int2 coord, uint8 data) -{ - __builtin_IB_subgroup_block_write_flat_u32_m8k16v1(as_long(base_address), width - 1, height - 1, pitch - 1, coord, data); -} - -uint __builtin_IB_subgroup_block_read_flat_transpose_u32_k1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); -uint2 
__builtin_IB_subgroup_block_read_flat_transpose_u32_m32k1(long baseoffset, int width_minus_one, int height_minus_one, int pitch_minus_one, int2 coord); - -void intel_sub_group_2d_block_read_transpose_32b_16r1x1c(global void* base_address, int width, int height, int pitch, int2 coord, private uint* destination) -{ - uint temp = __builtin_IB_subgroup_block_read_flat_transpose_u32_k1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); - destination[0] = temp; -} - -void intel_sub_group_2d_block_read_transpose_32b_32r1x1c(global void* base_address, int width, int height, int pitch, int2 coord, private uint* destination) -{ - uint2 temp = __builtin_IB_subgroup_block_read_flat_transpose_u32_m32k1(as_long(base_address), width - 1, height - 1, pitch - 1, coord); - destination[0] = temp.s0; - destination[1] = temp.s1; -} - -#endif // cl_intel_subgroup_extended_block_read diff --git a/samples/99_matrixexperimentsi8/matrix_kernel_tiled_i8.cl b/samples/99_matrixexperimentsi8/matrix_kernel_tiled_i8.cl deleted file mode 100644 index 9d9225eb..00000000 --- a/samples/99_matrixexperimentsi8/matrix_kernel_tiled_i8.cl +++ /dev/null @@ -1,750 +0,0 @@ -#if !defined(tK) -#error "tK is undefined! This should be defined as the K dimension of the matrix tiles, which is dependent on the element type, likely 16 or 32." -#endif - -#if !defined(MM) -#error "MM is undefined! This should be defined as the number of matrix tiles in the M dimension." -#endif - -#if !defined(NN) -#error "NN is undefined! This should be defined as the number of matrix tiles in the N dimension." -#endif - -#if !defined(KK) -#define KK 1 -#endif - -#if !defined(cl_intel_split_work_group_barrier) || defined(NO_SPLIT_BARRIERS) -#if !defined(cl_intel_split_work_group_barrier) -#warning "Unexpected: cl_intel_split_work_group_barrier is not supported?" 
-#endif -#define split_barrier_arrive() -#define split_barrier_wait() -#else -#define split_barrier_arrive() intel_work_group_barrier_arrive(0) -#define split_barrier_wait() intel_work_group_barrier_wait(0) -#endif - -#define MM_KERNEL_NAMEX(PREFIX, tM, tN, MM, NN) PREFIX ## _m ## tM ## _n ## tN ## _ ## MM ## x ## NN -#define MM_KERNEL_NAME(PREFIX, tM, tN, MM, NN) MM_KERNEL_NAMEX(PREFIX, tM, tN, MM, NN) - -#define HELPER_NAMEX(PREFIX, MM, NN) PREFIX ## _m ## MM ## _n ## NN -#define HELPER_NAME(PREFIX, MM, NN) HELPER_NAMEX(PREFIX, MM, NN) - -#if !defined(SGS_PER_WG_X) -#define SGS_PER_WG_X 1 -#endif - -#if !defined(SGS_PER_WG_Y) -#define SGS_PER_WG_Y 4 -#endif - -#if !defined(PREFETCH_DISTANCE) -#define PREFETCH_DISTANCE 1 -#endif - -void HELPER_NAME(btile_load_rowmajor, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[NN][KK]) -{ - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - bData[nn][kk] = load_b_rowmajor_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); - } - } -} - -void HELPER_NAME(btile_load_vnni, MM, NN)(global ushort* B, int tN, int N, int k, int n, int8 bData[NN][KK]) -{ - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - bData[nn][kk] = load_b_vnni_d16_k16_nx(B, k + kk * tK, n + nn * tN, N); - } - } -} - -#if HAS_SIMD8 - -void HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(global ushort* A, int tM, int K, int m, int prefetch_k) -{ - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm++) { - prefetch_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, prefetch_k + kk * tK, K); - } - } -} - -void HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(global ushort* B, int tN, int N, int prefetch_k, int n) -{ - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn+=4) { - prefetch_b_rowmajor_d16_k16_n8v4_sg8(B, prefetch_k + kk * tK, n + nn * tN, N); - } - } -} - -void HELPER_NAME(btile_prefetch_vnni_sg8, MM, NN)(global ushort* B, int tN, int N, int prefetch_k, int n) -{ - for (int kk = 
0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_vnni_d16_k16_n8v2_sg8(B, prefetch_k + kk * tK, n + nn * tN, N); - } - } -} - -void HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(global ushort* A, int tM, int K, int m, int k, int8 aData[KK][MM]) -{ - if (KK % 2 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm++) { - int16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg8(A, m + mm * tM, k + kk * tK, K); - aData[kk + 0][mm] = aTemp.lo; - aData[kk + 1][mm] = aTemp.hi; - } - } - } else { - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg8(A, m + mm * tM, k + kk * tK, K); - } - } - } -} - -__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) -kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 8, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) -{ - __builtin_assume(K > 0); // Always at least one K iteration. - const int tM = 8; - const int tN = 8; - const int N = get_global_size(0) * NN; - const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); - const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); - - // Initial prefetch: - int prefetch_k = 0; - for (int p = 0; p < PREFETCH_DISTANCE; p++) { - HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(A, tM, K, m, prefetch_k); - HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(B, tN, N, prefetch_k, n); - prefetch_k += tK * KK; - } - - float8 sum[NN][MM]; - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[nn][mm] = 0; - } - } - - split_barrier_arrive(); - - for (int k = 0; k < K; k += tK * KK) { - // Next prefetch: - // TODO: skip prefetch on the last iterations. 
- HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(A, tM, K, m, prefetch_k); - HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(B, tN, N, prefetch_k, n); - prefetch_k += tK * KK; - - int8 aData[KK][MM]; - HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(A, tM, K, m, k, aData); - - int8 bData[NN][KK]; - HELPER_NAME(btile_load_rowmajor, MM, NN)(B, tN, N, k, n, bData); - - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[nn][mm] = mat_mul_sg8(aData[kk][mm], bData[nn][kk], sum[nn][mm]); - } - } - } - - split_barrier_wait(); - split_barrier_arrive(); - } - - split_barrier_wait(); - - for (int nn = 0; nn < NN; nn++) { - for (int mm = 0; mm < MM; mm++) { - sum[nn][mm] = activation(sum[nn][mm]); - store_c_rowmajor_fp32_m8_nx(C, sum[nn][mm], m + mm * tM, n + nn * tN, N); - } - } -} - -__attribute__((intel_reqd_sub_group_size(8))) __attribute__((reqd_work_group_size(8 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) -kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) -{ - __builtin_assume(K > 0); // Always at least one K iteration. - const int tM = 8; - const int tN = 8; - const int N = get_global_size(0) * NN; - const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); - const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); - - // Initial prefetch: - int prefetch_k = 0; - for (int p = 0; p < PREFETCH_DISTANCE; p++) { - HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(A, tM, K, m, prefetch_k); - HELPER_NAME(btile_prefetch_vnni_sg8, MM, NN)(B, tN, N, prefetch_k, n); - prefetch_k += tK * KK; - } - - float8 sum[NN][MM]; - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[nn][mm] = 0; - } - } - - split_barrier_arrive(); - - for (int k = 0; k < K; k += tK * KK) { - // Next prefetch: - // TODO: skip prefetch on the last iterations. 
- HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(A, tM, K, m, prefetch_k); - HELPER_NAME(btile_prefetch_rowmajor_sg8, MM, NN)(B, tN, N, prefetch_k, n); - prefetch_k += tK * KK; - - int8 aData[KK][MM]; - HELPER_NAME(atile_load_rowmajor_sg8, MM, NN)(A, tM, K, m, k, aData); - - int8 bData[NN][KK]; - HELPER_NAME(btile_load_vnni, MM, NN)(B, tN, N, k, n, bData); - - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - for (int mm = 0; mm < MM; mm++) { - sum[nn][mm] = mat_mul_sg8(aData[kk][mm], bData[nn][kk], sum[nn][mm]); - } - } - } - - split_barrier_wait(); - split_barrier_arrive(); - } - - split_barrier_wait(); - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[nn][mm] = activation(sum[nn][mm]); - store_c_rowmajor_fp32_m8_nx(C, sum[nn][mm], m + mm * tM, n + nn * tN, N); - } - } -} - -#endif // HAS_SIMD8 - -void HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(global ushort* A, int tM, int K, int m, int prefetch_k) -{ - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm+=2) { - prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(A, m + mm * tM, prefetch_k + kk * tK, K); - } - } -} - -void HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(global ushort* B, int tN, int N, int prefetch_k, int n) -{ - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn+=2) { - prefetch_b_rowmajor_d16_k16_n16v2_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); - } - } -} - -void HELPER_NAME(btile_prefetch_vnni, MM, NN)(global ushort* B, int tN, int N, int prefetch_k, int n) -{ - for (int kk = 0; kk < KK; kk+=2) { - for (int nn = 0; nn < NN; nn++) { - prefetch_b_vnni_d16_k16v2_n16_sg16(B, prefetch_k + kk * tK, n + nn * tN, N); - } - } -} - -void HELPER_NAME(atile_load_rowmajor, MM, NN)(global ushort* A, int tM, int K, int m, int k, short8 aData[KK][MM]) -{ - if (KK % 2 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm++) { - short16 aTemp = load_a_rowmajor_d16_m8_k16v2_sg16(A, m + mm * tM, k + kk * tK, K); - 
aData[kk + 0][mm] = aTemp.lo; - aData[kk + 1][mm] = aTemp.hi; - } - } - } else { - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = load_a_rowmajor_d16_m8_k16_sg16(A, m + mm * tM, k + kk * tK, K); - } - } - } -} - -__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) -kernel void MM_KERNEL_NAME(bfloat16_dpas_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) -{ - __builtin_assume(K > 0); // Always at least one K iteration. - const int tM = 8; - const int tN = 16; - const int N = get_global_size(0) * NN; - const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); - const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); - - // Initial prefetch: - int prefetch_k = 0; - for (int p = 0; p < PREFETCH_DISTANCE; p++) { - HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); - HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n); - prefetch_k += tK * KK; - } - - float8 sum[NN][MM]; - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[nn][mm] = 0; - } - } - - split_barrier_arrive(); - - for (int k = 0; k < K; k += tK * KK) { - // Next prefetch: - // TODO: skip prefetch on the last iterations. 
- HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); - HELPER_NAME(btile_prefetch_rowmajor, MM, NN)(B, tN, N, prefetch_k, n); - prefetch_k += tK * KK; - - short8 aData[KK][MM]; - HELPER_NAME(atile_load_rowmajor, MM, NN)(A, tM, K, m, k, aData); - - int8 bData[NN][KK]; - HELPER_NAME(btile_load_rowmajor, MM, NN)(B, tN, N, k, n, bData); - - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - for (int mm = 0; mm < MM; mm++) { - sum[nn][mm] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[nn][mm]); - } - } - } - - split_barrier_wait(); - split_barrier_arrive(); - } - - split_barrier_wait(); - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[nn][mm] = activation(sum[nn][mm]); - store_c_rowmajor_fp32_m8_nx(C, sum[nn][mm], m + mm * tM, n + nn * tN, N); - } - } -} - -__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) -kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 16, MM, NN)(global float* C, global_aligned_ushort_ptr A, global_aligned_ushort_ptr B, int K) -{ - __builtin_assume(K > 0); // Always at least one K iteration. - const int tM = 8; - const int tN = 16; - const int N = get_global_size(0) * NN; - const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); - const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); - - // Initial prefetch: - int prefetch_k = 0; - for (int p = 0; p < PREFETCH_DISTANCE; p++) { - HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); - HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n); - prefetch_k += tK * KK; - } - - float8 sum[NN][MM]; - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[nn][mm] = 0; - } - } - - split_barrier_arrive(); - - for (int k = 0; k < K; k += tK * KK) { - // Next prefetch: - // TODO: skip prefetch on the last iterations. 
- HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(A, tM, K, m, prefetch_k); - HELPER_NAME(btile_prefetch_vnni, MM, NN)(B, tN, N, prefetch_k, n); - prefetch_k += tK * KK; - - short8 aData[KK][MM]; - HELPER_NAME(atile_load_rowmajor, MM, NN)(A, tM, K, m, k, aData); - - int8 bData[NN][KK]; - HELPER_NAME(btile_load_vnni, MM, NN)(B, tN, N, k, n, bData); - - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - for (int mm = 0; mm < MM; mm++) { - sum[nn][mm] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[nn][mm]); - } - } - } - - split_barrier_wait(); - split_barrier_arrive(); - } - - split_barrier_wait(); - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[nn][mm] = activation(sum[nn][mm]); - store_c_rowmajor_fp32_m8_nx(C, sum[nn][mm], m + mm * tM, n + nn * tN, N); - } - } -} - -#ifdef cl_intel_subgroup_extended_block_read - -void HELPER_NAME(atile_block_load_rowmajor, MM, NN)(global ushort* A, int tM, int M, int K, int m, int k, short8 aData[KK][MM]) -{ - if (KK % 2 == 0 & MM % 4 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm+=4) { - //if (get_sub_group_local_id() == 0) { - // printf("atile block load : %d, %d, %2d: m = %3d, k = %3d, mm = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), m, k, mm, kk, k + kk * tK, m + mm * tM); - //} - ushort8 tmp[2][4]; - intel_subgroup_block_read_u16_m32k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); - for (int tkk = 0; tkk < 2; tkk++) { - for (int tmm = 0; tmm < 4; tmm++) { - aData[kk + tkk][mm + tmm] = as_short8(tmp[tkk][tmm]); - } - } - } - } - } else if (KK % 2 == 0 & MM % 2 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm+=2) { - ushort8 tmp[2][2]; - intel_subgroup_block_read_u16_m16k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); - for (int tkk = 0; tkk < 2; tkk++) { - for (int tmm = 0; tmm < 2; 
tmm++) { - aData[kk + tkk][mm + tmm] = as_short8(tmp[tkk][tmm]); - } - } - } - } - } else if (KK % 2 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm++) { - short16 aTemp = as_short16(intel_subgroup_block_read_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); - aData[kk + 0][mm] = aTemp.lo; - aData[kk + 1][mm] = aTemp.hi; - } - } - } else if (MM % 4 == 0) { - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm+=4) { - ushort8 tmp[4]; - intel_subgroup_block_read_u16_m32k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM), tmp); - for (int tmm = 0; tmm < 4; tmm++) { - aData[kk][mm + tmm] = as_short8(tmp[tmm]); - } - } - } - } else { - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - aData[kk][mm] = as_short8(intel_subgroup_block_read_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM))); - } - } - } -} - -void HELPER_NAME(btile_block_load_rowmajor, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n, int8 bData[NN][KK]) -{ - if (KK % 2 == 0 & NN % 2 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int nn = 0; nn < NN; nn+=2) { - //if (get_sub_group_local_id() == 0) { - // printf("btile block load: %d, %d, %2d: n = %3d, k = %3d, nn = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), n, k, nn, kk, n + nn * tN, k + kk * tK); - //} - int8 tmp[2][2]; - intel_subgroup_block_read_transform_u16_k32n16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK), tmp); - for (int tnn = 0; tnn < 2; tnn++) { - for (int tkk = 0; tkk < 2; tkk++) { - bData[nn + tnn][kk + tkk] = tmp[tnn][tkk]; - } - } - } - } - } else if (NN % 2 == 0) { - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn+=2) { - int16 bTemp = intel_subgroup_block_read_transform_u16_k16n16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn 
* tN, k + kk * tK)); - bData[nn + 0][kk] = bTemp.lo; - bData[nn + 1][kk] = bTemp.hi; - } - } - } else if (KK % 2 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int nn = 0; nn < NN; nn++) { - int16 bTemp = intel_subgroup_block_read_transform_u16_k32n16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); - bData[nn][kk + 0] = bTemp.lo; - bData[nn][kk + 1] = bTemp.hi; - } - } - } else { - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - bData[nn][kk] = intel_subgroup_block_read_transform_u16_k16n16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); - } - } - } -} - -void HELPER_NAME(btile_block_load_vnni, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n, int8 bData[NN][KK]) -{ - if (KK % 2 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int nn = 0; nn < NN; nn++) { - int16 bTemp = as_int16(intel_subgroup_block_read_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2))); - bData[nn][kk + 0] = bTemp.lo; - bData[nn][kk + 1] = bTemp.hi; - } - } - } else { - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - bData[nn][kk] = as_int8(intel_subgroup_block_read_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2))); - } - } - } -} - -void HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(global ushort* A, int tM, int M, int K, int m, int k) -{ - if (KK == 2 & MM == 4 & SGS_PER_WG_X >= 4) { - const int sg_index_x = get_sub_group_id() % SGS_PER_WG_X; // index in [0, SGS_PER_WG_X) - const int kk = 0; - const int mm = sg_index_x % 4; - //if (get_sub_group_local_id() == 0) { - // printf("atile block prefetch: %d, %d, %2d: sg_x = %d, m = %3d, k = %3d, mm = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), sg_index_x, m, k, mm, kk, k + kk * tK, m + mm * tM); - //} - intel_subgroup_block_prefetch_u16_m8k16v2(A, K * sizeof(ushort), M, K * 
sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); - } else if (KK % 2 == 0 & MM % 4 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm+=4) { - intel_subgroup_block_prefetch_u16_m32k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); - } - } - } else if (KK % 2 == 0 & MM % 2 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm+=2) { - intel_subgroup_block_prefetch_u16_m16k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); - } - } - } else if (KK % 2 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int mm = 0; mm < MM; mm++) { - intel_subgroup_block_prefetch_u16_m8k16v2(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); - } - } - } else if (MM % 4 == 0) { - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm+=4) { - intel_subgroup_block_prefetch_u16_m32k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); - } - } - } else { - for (int kk = 0; kk < KK; kk++) { - for (int mm = 0; mm < MM; mm++) { - intel_subgroup_block_prefetch_u16_m8k16(A, K * sizeof(ushort), M, K * sizeof(ushort), (int2)(k + kk * tK, m + mm * tM)); - } - } - } -} - -void HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n) -{ - if (KK == 2 & NN == 4 & SGS_PER_WG_Y >= 4) { - const int sg_index_y = get_sub_group_id() / SGS_PER_WG_X; // index in [0, SGS_PER_WG_Y) - const int nn = sg_index_y % 2 * 2; // nn(sg_index_y) == 0, 2, 0, 2, 0, 2, 0, 2, ... - const int kk = sg_index_y / 2 % 2; // kk(sg_index_y) == 0, 0, 1, 1, 0, 0, 1, 1, ... 
- //if (get_sub_group_local_id() == 0) { - // printf("btile block prefetch: %d, %d, %2d: sg_y = %d, n = %3d, k = %3d, nn = %2d, kk = %2d, coord = %3d, %3d\n", (int)get_group_id(1), (int)get_group_id(0), get_sub_group_id(), sg_index_y, n, k, nn, kk, n + nn * tN, k + kk * tK); - //} - intel_subgroup_block_prefetch_u16_m16k16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); - } else if (KK % 2 == 0 & NN % 2 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int nn = 0; nn < NN; nn += 2) { - intel_subgroup_block_prefetch_u16_m32k16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); - } - } - } else if (NN % 2 == 0) { - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn+=2) { - intel_subgroup_block_prefetch_u16_m16k16v2(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); - } - } - } else if (KK % 2 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int nn = 0; nn < NN; nn++) { - intel_subgroup_block_prefetch_u16_m32k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); - } - } - } else { - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - intel_subgroup_block_prefetch_u16_m16k16(B, N * sizeof(ushort), K, N * sizeof(ushort), (int2)(n + nn * tN, k + kk * tK)); - } - } - } -} - -void HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(global ushort* B, int tN, int K, int N, int k, int n) -{ - if (KK == 2 & NN == 4 & SGS_PER_WG_Y >= 4) { - const int sg_index_y = get_sub_group_id() / SGS_PER_WG_X; // index in [0, SGS_PER_WG_Y) - const int nn = sg_index_y % 4; // nn(sg_index_y) == 0, 1, 2, 3, 0, 1, 2, 3 - const int kk = 0; // kk(sg_index_y) == 0, 0, 0, 0, 0, 0, 0, 0 - intel_subgroup_block_prefetch_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); - } else if (KK % 2 == 0) { - for (int kk = 0; kk < KK; kk+=2) { - for (int nn = 0; nn < NN; nn++) { - 
intel_subgroup_block_prefetch_u32_m16k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); - } - } - } else { - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - intel_subgroup_block_prefetch_u32_m8k16(B, N * sizeof(uint), K, N * sizeof(uint), (int2)(n + nn * tN, (k + kk * tK) / 2)); - } - } - } -} - -__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) -kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_rowmajor_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) -{ - __builtin_assume(K > 0); // Always at least one K iteration. - const int tM = 8; - const int tN = 16; - const int M = get_global_size(1) * tM * MM; - const int N = get_global_size(0) * NN; - const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); - const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); - - int prefetch_k = 0; - for (int p = 0; p < PREFETCH_DISTANCE; p++) { - HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(B, tN, K, N, prefetch_k, n); - HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); - prefetch_k += tK * KK; - } - - float8 sum[NN][MM]; - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[nn][mm] = 0; - } - } - - split_barrier_arrive(); - - for (int k = 0; k < K; k += tK * KK) { - int8 bData[NN][KK]; - HELPER_NAME(btile_block_load_rowmajor, MM, NN)(B, tN, K, N, k, n, bData); - - short8 aData[KK][MM]; - HELPER_NAME(atile_block_load_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); - - HELPER_NAME(btile_block_prefetch_rowmajor, MM, NN)(B, tN, K, N, prefetch_k, n); - HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); - prefetch_k += tK * KK; - - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - for (int mm = 0; mm < MM; mm++) { - sum[nn][mm] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[nn][mm]); - } - } - } - - 
split_barrier_wait(); - split_barrier_arrive(); - } - - split_barrier_wait(); - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[nn][mm] = activation(sum[nn][mm]); - intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[nn][mm])); - } - } -} - -__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16 * SGS_PER_WG_X, SGS_PER_WG_Y, 1))) -kernel void MM_KERNEL_NAME(bfloat16_dpas_blockread_vnni_tiled, 8, 16, MM, NN)(global float* C, global ushort* A, global ushort* B, int K) -{ - __builtin_assume(K > 0); // Always at least one K iteration. - const int tM = 8; - const int tN = 16; - const int M = get_global_size(1) * tM * MM; - const int N = get_global_size(0) * NN; - const int m = compute_m(SGS_PER_WG_X, SGS_PER_WG_Y, tM, MM); - const int n = compute_n(SGS_PER_WG_X, SGS_PER_WG_Y, tN, NN); - - int prefetch_k = 0; - for (int p = 0; p < PREFETCH_DISTANCE; p++) { - HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(B, tN, K, N, prefetch_k, n); - HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); - prefetch_k += tK * KK; - } - - float8 sum[NN][MM]; - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[nn][mm] = 0; - } - } - - split_barrier_arrive(); - - for (int k = 0; k < K; k += tK * KK) { - int8 bData[NN][KK]; - HELPER_NAME(btile_block_load_vnni, MM, NN)(B, tN, K, N, k, n, bData); - - short8 aData[KK][MM]; - HELPER_NAME(atile_block_load_rowmajor, MM, NN)(A, tM, M, K, m, k, aData); - - // TODO: skip prefetch on the last iterations. 
- HELPER_NAME(btile_block_prefetch_vnni, MM, NN)(B, tN, K, N, prefetch_k, n); - HELPER_NAME(atile_block_prefetch_rowmajor, MM, NN)(A, tM, M, K, m, prefetch_k); - prefetch_k += tK * KK; - - for (int kk = 0; kk < KK; kk++) { - for (int nn = 0; nn < NN; nn++) { - for (int mm = 0; mm < MM; mm++) { - sum[nn][mm] = mat_mul_sg16(aData[kk][mm], bData[nn][kk], sum[nn][mm]); - } - } - } - - split_barrier_wait(); - split_barrier_arrive(); - } - - split_barrier_wait(); - - for (int mm = 0; mm < MM; mm++) { - for (int nn = 0; nn < NN; nn++) { - sum[nn][mm] = activation(sum[nn][mm]); - intel_subgroup_block_write_u32_m8k16(C, N * sizeof(float), M, N * sizeof(float), (int2)(n + nn * tN, m + mm * tM), as_uint8(sum[nn][mm])); - } - } -} - -#endif // cl_intel_subgroup_extended_block_read diff --git a/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl b/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl index 0a707bc3..80938417 100644 --- a/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl +++ b/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl @@ -27,21 +27,6 @@ kernel void i8_naive(global int* C, global char* A, global char* B, int K) C[m * N + n] = sum; } -kernel void i8_naive_TN(global int* C, global char* A, global char* B, int K) -{ - const int N = get_global_size(0); - const int m = get_global_id(1); - const int n = get_global_id(0); - - int sum = 0; - for (int k = 0; k < K; k++) { - sum = A[k * K + m] * B[k * N + n] + sum; - } - - sum = activation(sum); - C[m * N + n] = sum; -} - // For all i8 kernels tK == 32: #define tK 32 @@ -589,101 +574,8 @@ kernel void i8_dpas_blockread_vnni_m8_n16(global int* C, global char* A, global intel_sub_group_2d_block_write_32b_8r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); } -__attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) -kernel void i8_dpas_blockread_rowmajor_TN_m4_n16(global int* C, global char* A, global char* B, int K) -{ - __builtin_assume(K > 0); // 
Always at least one K iteration. - const int tM = 4; - const int tN = 16; - const int M = get_global_size(1) * tM; - const int N = get_global_size(0); - const int m = get_group_id(1) * tM; - const int n = get_group_id(0) * tN; - const int sglid = get_sub_group_local_id(); - - int4 sum = 0; - for (int k = 0; k < K; k += tK) { - int2 readData; - intel_sub_group_2d_block_read_transpose_32b_32r1x1c(A, M * sizeof(char), K, M * sizeof(char), (int2)(m / 4, k), (uint*)&readData); - - // Note: after the transpose block read: - // readData.s0 contains rows 0-15 - // readData.s1 contains rows 16-31 - // So, WI0 has rows 0 and 16, WI1 has rows 1 and 17, etc. - // We want WI0 to have rows 0 and 1, WI1 to have rows 2 and 3, etc. - int shuffleIndex = sglid * 2 % 16; - int loData0 = sub_group_shuffle(readData.s0, shuffleIndex); - int hiData0 = sub_group_shuffle(readData.s1, shuffleIndex); - int shuffledData0 = (sglid < 8) ? loData0 : hiData0; - int loData1 = sub_group_shuffle(readData.s0, shuffleIndex + 1); - int hiData1 = sub_group_shuffle(readData.s1, shuffleIndex + 1); - int shuffledData1 = (sglid < 8) ? 
loData1 : hiData1; - - short4 aData; - aData.s0 = as_short((char2)(as_char4(shuffledData0).s0, as_char4(shuffledData1).s0)); - aData.s1 = as_short((char2)(as_char4(shuffledData0).s1, as_char4(shuffledData1).s1)); - aData.s2 = as_short((char2)(as_char4(shuffledData0).s2, as_char4(shuffledData1).s2)); - aData.s3 = as_short((char2)(as_char4(shuffledData0).s3, as_char4(shuffledData1).s3)); - - int8 bData; - intel_sub_group_2d_block_read_transform_8b_32r16x1c(B, N * sizeof(char), K, N * sizeof(char), (int2)(n, k), (uint*)&bData); - sum = mat_mul_sg16(aData, bData, sum); - } - - sum = activation(sum); - intel_sub_group_2d_block_write_32b_4r16x1c(C, N * sizeof(float), M, N * sizeof(float), (int2)(n, m), (uint*)&sum); -} - #endif // cl_intel_subgroup_2d_block_io -#if 0 // disable the tiled cases for now - -// Tiled matrix multiplication kernels, generated from a template: - -#define MM 1 -#define NN 1 -#include "matrix_kernel_tiled_i8.cl" -#undef MM -#undef NN - -#define MM 2 -#define NN 1 -#include "matrix_kernel_tiled_i8.cl" -#undef MM -#undef NN - -#define MM 1 -#define NN 2 -#include "matrix_kernel_tiled_i8.cl" -#undef MM -#undef NN - -#define MM 2 -#define NN 2 -#include "matrix_kernel_tiled_i8.cl" -#undef MM -#undef NN - -#define MM 4 -#define NN 2 -#include "matrix_kernel_tiled_i8.cl" -#undef MM -#undef NN - -#define MM 2 -#define NN 4 -#include "matrix_kernel_tiled_i8.cl" -#undef MM -#undef NN - -#define MM 4 -#define NN 4 -#include "matrix_kernel_tiled_i8.cl" -#undef MM -#undef NN - -#endif // disable the tiled cases for now - #endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size) #undef tK From d680ff14e48f5b7725600e02aab1ea517061cdc8 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Sat, 3 May 2025 22:09:00 -0700 Subject: [PATCH 91/99] switch to final directories and sample names --- .../20_matrixexperiments-bf16/CMakeLists.txt | 11 ++++++++++ .../main.cpp | 4 ++-- .../matrix_helpers_bf16.cl} | 6 
+++++ .../matrix_kernel_tiled_bf16.cl} | 6 +++++ .../matrix_kernels_bf16.cl} | 22 ++++++++++++------- .../CMakeLists.txt | 6 ++--- .../main.cpp | 2 +- .../matrix_helpers_i8.cl | 6 +++++ .../matrix_kernels_i8.cl | 6 +++++ .../CMakeLists.txt | 6 ++--- .../main.cpp | 0 .../matrix_helpers_tf32.cl | 6 +++++ .../matrix_kernel_tiled_tf32.cl | 6 +++++ .../matrix_kernels_tf32.cl | 6 +++++ samples/99_matrixexperiments/CMakeLists.txt | 11 ---------- samples/CMakeLists.txt | 7 +++--- 16 files changed, 80 insertions(+), 31 deletions(-) create mode 100644 samples/20_matrixexperiments-bf16/CMakeLists.txt rename samples/{99_matrixexperiments => 20_matrixexperiments-bf16}/main.cpp (99%) rename samples/{99_matrixexperiments/matrix_helpers.cl => 20_matrixexperiments-bf16/matrix_helpers_bf16.cl} (99%) rename samples/{99_matrixexperiments/matrix_kernel_tiled.cl => 20_matrixexperiments-bf16/matrix_kernel_tiled_bf16.cl} (99%) rename samples/{99_matrixexperiments/matrix_kernels.cl => 20_matrixexperiments-bf16/matrix_kernels_bf16.cl} (98%) rename samples/{99_matrixexperimentsi8 => 20_matrixexperiments-i8}/CMakeLists.txt (67%) rename samples/{99_matrixexperimentsi8 => 20_matrixexperiments-i8}/main.cpp (99%) rename samples/{99_matrixexperimentsi8 => 20_matrixexperiments-i8}/matrix_helpers_i8.cl (99%) rename samples/{99_matrixexperimentsi8 => 20_matrixexperiments-i8}/matrix_kernels_i8.cl (99%) rename samples/{99_matrixexperimentstf32 => 20_matrixexperiments-tf32}/CMakeLists.txt (70%) rename samples/{99_matrixexperimentstf32 => 20_matrixexperiments-tf32}/main.cpp (100%) rename samples/{99_matrixexperimentstf32 => 20_matrixexperiments-tf32}/matrix_helpers_tf32.cl (99%) rename samples/{99_matrixexperimentstf32 => 20_matrixexperiments-tf32}/matrix_kernel_tiled_tf32.cl (98%) rename samples/{99_matrixexperimentstf32 => 20_matrixexperiments-tf32}/matrix_kernels_tf32.cl (99%) delete mode 100644 samples/99_matrixexperiments/CMakeLists.txt diff --git a/samples/20_matrixexperiments-bf16/CMakeLists.txt 
b/samples/20_matrixexperiments-bf16/CMakeLists.txt new file mode 100644 index 00000000..0acf9ca0 --- /dev/null +++ b/samples/20_matrixexperiments-bf16/CMakeLists.txt @@ -0,0 +1,11 @@ +# Copyright (c) 2024-2025 Ben Ashbaugh +# +# SPDX-License-Identifier: MIT + +add_opencl_sample( + TEST + NUMBER 20 + TARGET matrixexperiments-bf16 + VERSION 200 # for clSetKernelExecInfo + SOURCES main.cpp + KERNELS matrix_helpers_bf16.cl matrix_kernels_bf16.cl matrix_kernel_tiled_bf16.cl) diff --git a/samples/99_matrixexperiments/main.cpp b/samples/20_matrixexperiments-bf16/main.cpp similarity index 99% rename from samples/99_matrixexperiments/main.cpp rename to samples/20_matrixexperiments-bf16/main.cpp index fae01b3f..2aa3e74f 100644 --- a/samples/99_matrixexperiments/main.cpp +++ b/samples/20_matrixexperiments-bf16/main.cpp @@ -1,5 +1,5 @@ /* -// Copyright (c) 2019-2024 Ben Ashbaugh +// Copyright (c) 2024-2025 Ben Ashbaugh // // SPDX-License-Identifier: MIT */ @@ -662,7 +662,7 @@ int main(int argc, char** argv) int platformIndex = 0; int deviceIndex = 0; - std::string fileName("matrix_kernels.cl"); + std::string fileName("matrix_kernels_bf16.cl"); std::string buildOptions; size_t matrixSize = 512; diff --git a/samples/99_matrixexperiments/matrix_helpers.cl b/samples/20_matrixexperiments-bf16/matrix_helpers_bf16.cl similarity index 99% rename from samples/99_matrixexperiments/matrix_helpers.cl rename to samples/20_matrixexperiments-bf16/matrix_helpers_bf16.cl index d8580aa7..49961017 100644 --- a/samples/99_matrixexperiments/matrix_helpers.cl +++ b/samples/20_matrixexperiments-bf16/matrix_helpers_bf16.cl @@ -1,3 +1,9 @@ +/* +// Copyright (c) 2024-2025 Ben Ashbaugh +// +// SPDX-License-Identifier: MIT +*/ + float bf16_to_fp32(ushort u) { #if defined(cl_intel_bfloat16_conversions) diff --git a/samples/99_matrixexperiments/matrix_kernel_tiled.cl b/samples/20_matrixexperiments-bf16/matrix_kernel_tiled_bf16.cl similarity index 99% rename from 
samples/99_matrixexperiments/matrix_kernel_tiled.cl rename to samples/20_matrixexperiments-bf16/matrix_kernel_tiled_bf16.cl index 9007c3b9..ead10880 100644 --- a/samples/99_matrixexperiments/matrix_kernel_tiled.cl +++ b/samples/20_matrixexperiments-bf16/matrix_kernel_tiled_bf16.cl @@ -1,3 +1,9 @@ +/* +// Copyright (c) 2024-2025 Ben Ashbaugh +// +// SPDX-License-Identifier: MIT +*/ + #if !defined(tK) #error "tK is undefined! This should be defined as the K dimension of the matrix tiles, which is dependent on the elemement type, likely 16 or 32." #endif diff --git a/samples/99_matrixexperiments/matrix_kernels.cl b/samples/20_matrixexperiments-bf16/matrix_kernels_bf16.cl similarity index 98% rename from samples/99_matrixexperiments/matrix_kernels.cl rename to samples/20_matrixexperiments-bf16/matrix_kernels_bf16.cl index 47fa2704..2e2c46a9 100644 --- a/samples/99_matrixexperiments/matrix_kernels.cl +++ b/samples/20_matrixexperiments-bf16/matrix_kernels_bf16.cl @@ -1,4 +1,10 @@ -#include "matrix_helpers.cl" +/* +// Copyright (c) 2024-2025 Ben Ashbaugh +// +// SPDX-License-Identifier: MIT +*/ + +#include "matrix_helpers_bf16.cl" #if EMULATE_tN8 #define mat_mul_sg8 emu_sub_group_bf16_bf16_matrix_mad_k16 @@ -580,43 +586,43 @@ kernel void bfloat16_dpas_blockread_vnni_m8_n16(global float* C, global ushort* #define MM 1 #define NN 1 -#include "matrix_kernel_tiled.cl" +#include "matrix_kernel_tiled_bf16.cl" #undef MM #undef NN #define MM 2 #define NN 1 -#include "matrix_kernel_tiled.cl" +#include "matrix_kernel_tiled_bf16.cl" #undef MM #undef NN #define MM 1 #define NN 2 -#include "matrix_kernel_tiled.cl" +#include "matrix_kernel_tiled_bf16.cl" #undef MM #undef NN #define MM 2 #define NN 2 -#include "matrix_kernel_tiled.cl" +#include "matrix_kernel_tiled_bf16.cl" #undef MM #undef NN #define MM 4 #define NN 2 -#include "matrix_kernel_tiled.cl" +#include "matrix_kernel_tiled_bf16.cl" #undef MM #undef NN #define MM 2 #define NN 4 -#include "matrix_kernel_tiled.cl" +#include 
"matrix_kernel_tiled_bf16.cl" #undef MM #undef NN #define MM 4 #define NN 4 -#include "matrix_kernel_tiled.cl" +#include "matrix_kernel_tiled_bf16.cl" #undef MM #undef NN diff --git a/samples/99_matrixexperimentsi8/CMakeLists.txt b/samples/20_matrixexperiments-i8/CMakeLists.txt similarity index 67% rename from samples/99_matrixexperimentsi8/CMakeLists.txt rename to samples/20_matrixexperiments-i8/CMakeLists.txt index c9493112..a8acd899 100644 --- a/samples/99_matrixexperimentsi8/CMakeLists.txt +++ b/samples/20_matrixexperiments-i8/CMakeLists.txt @@ -1,11 +1,11 @@ -# Copyright (c) 2019-2024 Ben Ashbaugh +# Copyright (c) 2024-2025 Ben Ashbaugh # # SPDX-License-Identifier: MIT add_opencl_sample( TEST - NUMBER 99 - TARGET matrixexperimentsi8 + NUMBER 20 + TARGET matrixexperiments-i8 VERSION 200 # for clSetKernelExecInfo SOURCES main.cpp KERNELS matrix_helpers_i8.cl matrix_kernels_i8.cl) diff --git a/samples/99_matrixexperimentsi8/main.cpp b/samples/20_matrixexperiments-i8/main.cpp similarity index 99% rename from samples/99_matrixexperimentsi8/main.cpp rename to samples/20_matrixexperiments-i8/main.cpp index 5741c455..6a97cae3 100644 --- a/samples/99_matrixexperimentsi8/main.cpp +++ b/samples/20_matrixexperiments-i8/main.cpp @@ -1,5 +1,5 @@ /* -// Copyright (c) 2019-2024 Ben Ashbaugh +// Copyright (c) 2024-2025 Ben Ashbaugh // // SPDX-License-Identifier: MIT */ diff --git a/samples/99_matrixexperimentsi8/matrix_helpers_i8.cl b/samples/20_matrixexperiments-i8/matrix_helpers_i8.cl similarity index 99% rename from samples/99_matrixexperimentsi8/matrix_helpers_i8.cl rename to samples/20_matrixexperiments-i8/matrix_helpers_i8.cl index d380ca2e..27ef25cf 100644 --- a/samples/99_matrixexperimentsi8/matrix_helpers_i8.cl +++ b/samples/20_matrixexperiments-i8/matrix_helpers_i8.cl @@ -1,3 +1,9 @@ +/* +// Copyright (c) 2024-2025 Ben Ashbaugh +// +// SPDX-License-Identifier: MIT +*/ + __attribute__((overloadable)) int activation(int i) { diff --git 
a/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl b/samples/20_matrixexperiments-i8/matrix_kernels_i8.cl similarity index 99% rename from samples/99_matrixexperimentsi8/matrix_kernels_i8.cl rename to samples/20_matrixexperiments-i8/matrix_kernels_i8.cl index 80938417..10100ded 100644 --- a/samples/99_matrixexperimentsi8/matrix_kernels_i8.cl +++ b/samples/20_matrixexperiments-i8/matrix_kernels_i8.cl @@ -1,3 +1,9 @@ +/* +// Copyright (c) 2024-2025 Ben Ashbaugh +// +// SPDX-License-Identifier: MIT +*/ + #include "matrix_helpers_i8.cl" #if EMULATE_tN8 diff --git a/samples/99_matrixexperimentstf32/CMakeLists.txt b/samples/20_matrixexperiments-tf32/CMakeLists.txt similarity index 70% rename from samples/99_matrixexperimentstf32/CMakeLists.txt rename to samples/20_matrixexperiments-tf32/CMakeLists.txt index 5987f780..0329e267 100644 --- a/samples/99_matrixexperimentstf32/CMakeLists.txt +++ b/samples/20_matrixexperiments-tf32/CMakeLists.txt @@ -1,11 +1,11 @@ -# Copyright (c) 2019-2024 Ben Ashbaugh +# Copyright (c) 2024-2025 Ben Ashbaugh # # SPDX-License-Identifier: MIT add_opencl_sample( TEST - NUMBER 99 - TARGET matrixexperimentstf32 + NUMBER 20 + TARGET matrixexperiments-tf32 VERSION 200 # for clSetKernelExecInfo SOURCES main.cpp KERNELS matrix_helpers_tf32.cl matrix_kernels_tf32.cl matrix_kernel_tiled_tf32.cl) diff --git a/samples/99_matrixexperimentstf32/main.cpp b/samples/20_matrixexperiments-tf32/main.cpp similarity index 100% rename from samples/99_matrixexperimentstf32/main.cpp rename to samples/20_matrixexperiments-tf32/main.cpp diff --git a/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl b/samples/20_matrixexperiments-tf32/matrix_helpers_tf32.cl similarity index 99% rename from samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl rename to samples/20_matrixexperiments-tf32/matrix_helpers_tf32.cl index 25a8bbe9..bf81a4b3 100644 --- a/samples/99_matrixexperimentstf32/matrix_helpers_tf32.cl +++ 
b/samples/20_matrixexperiments-tf32/matrix_helpers_tf32.cl @@ -1,3 +1,9 @@ +/* +// Copyright (c) 2024-2025 Ben Ashbaugh +// +// SPDX-License-Identifier: MIT +*/ + __attribute__((overloadable)) float activation(float f) { diff --git a/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl b/samples/20_matrixexperiments-tf32/matrix_kernel_tiled_tf32.cl similarity index 98% rename from samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl rename to samples/20_matrixexperiments-tf32/matrix_kernel_tiled_tf32.cl index ff94a4c1..c93e717e 100644 --- a/samples/99_matrixexperimentstf32/matrix_kernel_tiled_tf32.cl +++ b/samples/20_matrixexperiments-tf32/matrix_kernel_tiled_tf32.cl @@ -1,3 +1,9 @@ +/* +// Copyright (c) 2024-2025 Ben Ashbaugh +// +// SPDX-License-Identifier: MIT +*/ + #if !defined(tK) #error "tK is undefined! This should be defined as the K dimension of the matrix tiles, which is dependent on the elemement type, likely 16 or 32." #endif diff --git a/samples/99_matrixexperimentstf32/matrix_kernels_tf32.cl b/samples/20_matrixexperiments-tf32/matrix_kernels_tf32.cl similarity index 99% rename from samples/99_matrixexperimentstf32/matrix_kernels_tf32.cl rename to samples/20_matrixexperiments-tf32/matrix_kernels_tf32.cl index 953a9a6e..ea4526db 100644 --- a/samples/99_matrixexperimentstf32/matrix_kernels_tf32.cl +++ b/samples/20_matrixexperiments-tf32/matrix_kernels_tf32.cl @@ -1,3 +1,9 @@ +/* +// Copyright (c) 2024-2025 Ben Ashbaugh +// +// SPDX-License-Identifier: MIT +*/ + #include "matrix_helpers_tf32.cl" #if EMULATE_tN16 diff --git a/samples/99_matrixexperiments/CMakeLists.txt b/samples/99_matrixexperiments/CMakeLists.txt deleted file mode 100644 index 9fe36d84..00000000 --- a/samples/99_matrixexperiments/CMakeLists.txt +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright (c) 2019-2024 Ben Ashbaugh -# -# SPDX-License-Identifier: MIT - -add_opencl_sample( - TEST - NUMBER 99 - TARGET matrixexperiments - VERSION 200 # for clSetKernelExecInfo - SOURCES main.cpp 
- KERNELS matrix_helpers.cl matrix_kernels.cl matrix_kernel_tiled.cl) diff --git a/samples/CMakeLists.txt b/samples/CMakeLists.txt index 20db83a0..35f101be 100644 --- a/samples/CMakeLists.txt +++ b/samples/CMakeLists.txt @@ -78,6 +78,10 @@ add_subdirectory( 06_ndrangekernelfromfile ) add_subdirectory( 10_queueexperiments ) add_subdirectory( 16_floatatomics ) +add_subdirectory( 20_matrixexperiments-bf16 ) +add_subdirectory( 20_matrixexperiments-i8 ) +add_subdirectory( 20_matrixexperiments-tf32 ) + set(BUILD_EXTENSION_SAMPLES TRUE) if(NOT TARGET OpenCLExt) message(STATUS "Skipping Extension Samples - OpenCL Extension Loader is not found.") @@ -93,6 +97,3 @@ if(BUILD_EXTENSION_SAMPLES) add_subdirectory( 15_mutablecommandbufferasserts ) endif() -add_subdirectory( 99_matrixexperiments ) -add_subdirectory( 99_matrixexperimentsi8 ) -add_subdirectory( 99_matrixexperimentstf32 ) From af28503d6ea85475b7cc56052a73ce36356c315d Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Sun, 22 Feb 2026 21:11:32 -0800 Subject: [PATCH 92/99] update copyright, add README --- .../20_matrixexperiments-bf16/CMakeLists.txt | 2 +- samples/20_matrixexperiments-bf16/README.md | 60 +++++++++++++++++++ samples/20_matrixexperiments-bf16/main.cpp | 2 +- .../matrix_helpers_bf16.cl | 2 +- .../matrix_kernel_tiled_bf16.cl | 2 +- .../matrix_kernels_bf16.cl | 2 +- .../20_matrixexperiments-i8/CMakeLists.txt | 2 +- samples/20_matrixexperiments-i8/README.md | 60 +++++++++++++++++++ samples/20_matrixexperiments-i8/main.cpp | 2 +- .../matrix_helpers_i8.cl | 4 +- .../matrix_kernels_i8.cl | 2 +- .../20_matrixexperiments-tf32/CMakeLists.txt | 2 +- samples/20_matrixexperiments-tf32/README.md | 58 ++++++++++++++++++ samples/20_matrixexperiments-tf32/main.cpp | 2 +- .../matrix_helpers_tf32.cl | 2 +- .../matrix_kernel_tiled_tf32.cl | 2 +- .../matrix_kernels_tf32.cl | 2 +- 17 files changed, 193 insertions(+), 15 deletions(-) create mode 100644 samples/20_matrixexperiments-bf16/README.md create mode 100644 
samples/20_matrixexperiments-i8/README.md create mode 100644 samples/20_matrixexperiments-tf32/README.md diff --git a/samples/20_matrixexperiments-bf16/CMakeLists.txt b/samples/20_matrixexperiments-bf16/CMakeLists.txt index 0acf9ca0..0c08c2bc 100644 --- a/samples/20_matrixexperiments-bf16/CMakeLists.txt +++ b/samples/20_matrixexperiments-bf16/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2025 Ben Ashbaugh +# Copyright (c) 2024-2026 Ben Ashbaugh # # SPDX-License-Identifier: MIT diff --git a/samples/20_matrixexperiments-bf16/README.md b/samples/20_matrixexperiments-bf16/README.md new file mode 100644 index 00000000..793bba22 --- /dev/null +++ b/samples/20_matrixexperiments-bf16/README.md @@ -0,0 +1,60 @@ +# matrixexperiments-bf16 + +## Sample Purpose + +This sample demonstrates various techniques to perform a large matrix multiplication where the matrix elements contain 16-bit `bfloat16` data. +The sample includes many different implementations: + +1. The "naive" implementation is a very simple implementation. +It is not very fast, but it is easy to understand, and it has no extension dependencies so it will run on many devices. +2. The "dpas" kernels use sub-group extensions to improve performance. +On some devices, they will also use specialized matrix multiplication extensions to further improve performance. +Because these kernels require certain extensions or a specific sub-group size, they may not run on all devices. +3. The "dpas blockread" kernels use additional sub-group extensions to further improve performance. + +Most of the optimized kernels operate on fixed size tiles of matrix data. +For some of these kernels, parameters such as the number of matrix tiles per-sub-group or the number of sub-groups per work-group may be modified via program build options. +Experiment with different options to see what performs the best! 
+ +A good place to start for some devices is: + +```sh +./matrixexperiments-bf16 -m4096 --options="-DSGS_PER_WG_X=4 -DSGS_PER_WG_Y=8 -DKK=2 -cl-intel-256-GRF-per-thread" --zero +``` + +## Key APIs and Concepts + +This sample will optionally use the following OpenCL extensions: + +* cl_intel_bfloat16_conversions +* cl_intel_required_subgroup_size +* cl_intel_split_work_group_barrier +* cl_intel_subgroup_2d_block_io +* cl_intel_subgroup_matrix_multiply_accumulate +* cl_intel_subgroups +* cl_intel_subgroups_short + +## Command Line Options + +| Option | Default Value | Description | +|:--|:-:|:--| +| `-p ` | 0 | Specify the index of the OpenCL platform to execute the sample on. +| `-d ` | 0 | Specify the index of the OpenCL device in the platform to execute the sample on. +| `--file ` | `matrix_kernels_bf16.cl` | Specify the name of the file with the OpenCL kernel source. +| `--options ` | None | Specify optional program build options. +| `--matrixsize ` | 512 | Specify the dimensions of the matrix. +| `--iterations ` | 16 | Specify the number of iterations for performance testing. +| `--validate` | n/a | Validate results for correctness. +| `--zero` | n/a | Initialize all matrices to zero. +| `--identity` | n/a | Initialize all matrices to one. +| `--fixed` | n/a | Initialize all matrices to values computed from the matrix row and column. +| `--emulate` | n/a | Do not use specialized matrix multiplication extensions. +| `--wallclock` | n/a | Measure performance using wallclock time instead of event profiling. +| `--skipinit` | n/a | Skip initialization of source matrices. +| `--roundrobin` | n/a | Use round robin thread scheduling. +| `--threshold ` | 0.01 | Set the threshold used when validating results. +| `--mask ` | ~0 | Set a mask to only run a subset of tests. + +By default, the source matrices are populated with random data. +When validating results, it is recommended to use either "fixed" or "identity" data. +For best performance, use "zero" data. 
diff --git a/samples/20_matrixexperiments-bf16/main.cpp b/samples/20_matrixexperiments-bf16/main.cpp index 04a784e1..7cfb2c0b 100644 --- a/samples/20_matrixexperiments-bf16/main.cpp +++ b/samples/20_matrixexperiments-bf16/main.cpp @@ -1,5 +1,5 @@ /* -// Copyright (c) 2024-2025 Ben Ashbaugh +// Copyright (c) 2024-2026 Ben Ashbaugh // // SPDX-License-Identifier: MIT */ diff --git a/samples/20_matrixexperiments-bf16/matrix_helpers_bf16.cl b/samples/20_matrixexperiments-bf16/matrix_helpers_bf16.cl index 49961017..f9bf93f2 100644 --- a/samples/20_matrixexperiments-bf16/matrix_helpers_bf16.cl +++ b/samples/20_matrixexperiments-bf16/matrix_helpers_bf16.cl @@ -1,5 +1,5 @@ /* -// Copyright (c) 2024-2025 Ben Ashbaugh +// Copyright (c) 2024-2026 Ben Ashbaugh // // SPDX-License-Identifier: MIT */ diff --git a/samples/20_matrixexperiments-bf16/matrix_kernel_tiled_bf16.cl b/samples/20_matrixexperiments-bf16/matrix_kernel_tiled_bf16.cl index ead10880..34d67b24 100644 --- a/samples/20_matrixexperiments-bf16/matrix_kernel_tiled_bf16.cl +++ b/samples/20_matrixexperiments-bf16/matrix_kernel_tiled_bf16.cl @@ -1,5 +1,5 @@ /* -// Copyright (c) 2024-2025 Ben Ashbaugh +// Copyright (c) 2024-2026 Ben Ashbaugh // // SPDX-License-Identifier: MIT */ diff --git a/samples/20_matrixexperiments-bf16/matrix_kernels_bf16.cl b/samples/20_matrixexperiments-bf16/matrix_kernels_bf16.cl index 2e2c46a9..e799f41d 100644 --- a/samples/20_matrixexperiments-bf16/matrix_kernels_bf16.cl +++ b/samples/20_matrixexperiments-bf16/matrix_kernels_bf16.cl @@ -1,5 +1,5 @@ /* -// Copyright (c) 2024-2025 Ben Ashbaugh +// Copyright (c) 2024-2026 Ben Ashbaugh // // SPDX-License-Identifier: MIT */ diff --git a/samples/20_matrixexperiments-i8/CMakeLists.txt b/samples/20_matrixexperiments-i8/CMakeLists.txt index a8acd899..cc59c28d 100644 --- a/samples/20_matrixexperiments-i8/CMakeLists.txt +++ b/samples/20_matrixexperiments-i8/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2025 Ben Ashbaugh +# Copyright (c) 2024-2026 
Ben Ashbaugh # # SPDX-License-Identifier: MIT diff --git a/samples/20_matrixexperiments-i8/README.md b/samples/20_matrixexperiments-i8/README.md new file mode 100644 index 00000000..8fe63a4b --- /dev/null +++ b/samples/20_matrixexperiments-i8/README.md @@ -0,0 +1,60 @@ +# matrixexperiments-i8 + +## Sample Purpose + +This sample demonstrates various techniques to perform a large matrix multiplication where the matrix elements contain 8-bit integer data. +The sample includes many different implementations: + +1. The "naive" implementation is a very simple implementation. +It is not very fast, but it is easy to understand, and it has no extension dependencies so it will run on many devices. +2. The "dpas" kernels use sub-group extensions to improve performance. +On some devices, they will also use specialized matrix multiplication extensions to further improve performance. +Because these kernels require certain extensions or a specific sub-group size, they may not run on all devices. +3. The "dpas blockread" kernels use additional sub-group extensions to further improve performance. + +Most of the optimized kernels operate on fixed size tiles of matrix data. +For some of these kernels, parameters such as the number of matrix tiles per-sub-group or the number of sub-groups per work-group may be modified via program build options. +Experiment with different options to see what performs the best! + +Note, these kernels are not as highly tuned as the kernels for `bfloat16` and `tf32`! 
+A good place to start for some devices is: + +```sh +./matrixexperiments-i8 -m4096 --zero +``` + +## Key APIs and Concepts + +This sample will optionally use the following OpenCL extensions: + +* cl_intel_required_subgroup_size +* cl_intel_split_work_group_barrier +* cl_intel_subgroup_2d_block_io +* cl_intel_subgroup_matrix_multiply_accumulate +* cl_intel_subgroups +* cl_intel_subgroups_char + +## Command Line Options + +| Option | Default Value | Description | +|:--|:-:|:--| +| `-p ` | 0 | Specify the index of the OpenCL platform to execute the sample on. +| `-d ` | 0 | Specify the index of the OpenCL device in the platform to execute the sample on. +| `--file ` | `matrix_kernels_i8.cl` | Specify the name of the file with the OpenCL kernel source. +| `--options ` | None | Specify optional program build options. +| `--matrixsize ` | 512 | Specify the dimensions of the matrix. +| `--iterations ` | 16 | Specify the number of iterations for performance testing. +| `--validate` | n/a | Validate results for correctness. +| `--zero` | n/a | Initialize all matrices to zero. +| `--identity` | n/a | Initialize all matrices to one. +| `--fixed` | n/a | Initialize all matrices to values computed from the matrix row and column. +| `--emulate` | n/a | Do not use specialized matrix multiplication extensions. +| `--wallclock` | n/a | Measure performance using wallclock time instead of event profiling. +| `--skipinit` | n/a | Skip initialization of source matrices. +| `--roundrobin` | n/a | Use round robin thread scheduling. +| `--threshold ` | 0.01 | Set the threshold used when validating results. +| `--mask ` | ~0 | Set a mask to only run a subset of tests. + +By default, the source matrices are populated with random data. +When validating results, it is recommended to use either "fixed" or "identity" data. +For best performance, use "zero" data. 
diff --git a/samples/20_matrixexperiments-i8/main.cpp b/samples/20_matrixexperiments-i8/main.cpp index 9e860c71..91fe2016 100644 --- a/samples/20_matrixexperiments-i8/main.cpp +++ b/samples/20_matrixexperiments-i8/main.cpp @@ -1,5 +1,5 @@ /* -// Copyright (c) 2024-2025 Ben Ashbaugh +// Copyright (c) 2024-2026 Ben Ashbaugh // // SPDX-License-Identifier: MIT */ diff --git a/samples/20_matrixexperiments-i8/matrix_helpers_i8.cl b/samples/20_matrixexperiments-i8/matrix_helpers_i8.cl index 27ef25cf..360114b6 100644 --- a/samples/20_matrixexperiments-i8/matrix_helpers_i8.cl +++ b/samples/20_matrixexperiments-i8/matrix_helpers_i8.cl @@ -1,5 +1,5 @@ /* -// Copyright (c) 2024-2025 Ben Ashbaugh +// Copyright (c) 2024-2026 Ben Ashbaugh // // SPDX-License-Identifier: MIT */ @@ -634,4 +634,4 @@ void store_c_rowmajor_int32_m8_nx(global int* C, int8 v, int rowStart, int colSt intel_sub_group_block_write(C_ui + offset, v_ui.s7); offset += stride; } -#endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) +#endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_char) diff --git a/samples/20_matrixexperiments-i8/matrix_kernels_i8.cl b/samples/20_matrixexperiments-i8/matrix_kernels_i8.cl index 10100ded..b16f8d94 100644 --- a/samples/20_matrixexperiments-i8/matrix_kernels_i8.cl +++ b/samples/20_matrixexperiments-i8/matrix_kernels_i8.cl @@ -1,5 +1,5 @@ /* -// Copyright (c) 2024-2025 Ben Ashbaugh +// Copyright (c) 2024-2026 Ben Ashbaugh // // SPDX-License-Identifier: MIT */ diff --git a/samples/20_matrixexperiments-tf32/CMakeLists.txt b/samples/20_matrixexperiments-tf32/CMakeLists.txt index 0329e267..fe34bea2 100644 --- a/samples/20_matrixexperiments-tf32/CMakeLists.txt +++ b/samples/20_matrixexperiments-tf32/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2025 Ben Ashbaugh +# Copyright (c) 2024-2026 Ben Ashbaugh # # SPDX-License-Identifier: MIT diff --git a/samples/20_matrixexperiments-tf32/README.md b/samples/20_matrixexperiments-tf32/README.md new 
file mode 100644 index 00000000..62ffadf5 --- /dev/null +++ b/samples/20_matrixexperiments-tf32/README.md @@ -0,0 +1,58 @@ +# matrixexperiments-tf32 + +## Sample Purpose + +This sample demonstrates various techniques to perform a large matrix multiplication where the matrix elements contain 32-bit `tf32` data. +The sample includes many different implementations: + +1. The "naive" implementation is a very simple implementation. +It is not very fast, but it is easy to understand, and it has no extension dependencies so it will run on many devices. +2. The "dpas" kernels use sub-group extensions to improve performance. +On some devices, they will also use specialized matrix multiplication extensions to further improve performance. +Because these kernels require certain extensions or a specific sub-group size, they may not run on all devices. +3. The "dpas blockread" kernels use additional sub-group extensions to further improve performance. + +Most of the optimized kernels operate on fixed size tiles of matrix data. +For some of these kernels, parameters such as the number of matrix tiles per-sub-group or the number of sub-groups per work-group may be modified via program build options. +Experiment with different options to see what performs the best! + +A good place to start for some devices is: + +```sh +./matrixexperiments-tf32 -m4096 --options="-DSGS_PER_WG_X=4 -DSGS_PER_WG_Y=8 -DKK=2 -cl-intel-256-GRF-per-thread" --zero +``` + +## Key APIs and Concepts + +This sample will optionally use the following OpenCL extensions: + +* cl_intel_required_subgroup_size +* cl_intel_split_work_group_barrier +* cl_intel_subgroup_2d_block_io +* cl_intel_subgroup_matrix_multiply_accumulate_tf32 +* cl_intel_subgroups + +## Command Line Options + +| Option | Default Value | Description | +|:--|:-:|:--| +| `-p ` | 0 | Specify the index of the OpenCL platform to execute the sample on. +| `-d ` | 0 | Specify the index of the OpenCL device in the platform to execute the sample on. 
+| `--file ` | `matrix_kernels_tf32.cl` | Specify the name of the file with the OpenCL kernel source. +| `--options ` | None | Specify optional program build options. +| `--matrixsize ` | 512 | Specify the dimensions of the matrix. +| `--iterations ` | 16 | Specify the number of iterations for performance testing. +| `--validate` | n/a | Validate results for correctness. +| `--zero` | n/a | Initialize all matrices to zero. +| `--identity` | n/a | Initialize all matrices to one. +| `--fixed` | n/a | Initialize all matrices to values computed from the matrix row and column. +| `--emulate` | n/a | Do not use specialized matrix multiplication extensions. +| `--wallclock` | n/a | Measure performance using wallclock time instead of event profiling. +| `--skipinit` | n/a | Skip initialization of source matrices. +| `--roundrobin` | n/a | Use round robin thread scheduling. +| `--threshold ` | 0.01 | Set the threshold used when validating results. +| `--mask ` | ~0 | Set a mask to only run a subset of tests. + +By default, the source matrices are populated with random data. +When validating results, it is recommended to use either "fixed" or "identity" data. +For best performance, use "zero" data. 
diff --git a/samples/20_matrixexperiments-tf32/main.cpp b/samples/20_matrixexperiments-tf32/main.cpp index ab6cae87..4bf4bba8 100644 --- a/samples/20_matrixexperiments-tf32/main.cpp +++ b/samples/20_matrixexperiments-tf32/main.cpp @@ -1,5 +1,5 @@ /* -// Copyright (c) 2019-2024 Ben Ashbaugh +// Copyright (c) 2024-2026 Ben Ashbaugh // // SPDX-License-Identifier: MIT */ diff --git a/samples/20_matrixexperiments-tf32/matrix_helpers_tf32.cl b/samples/20_matrixexperiments-tf32/matrix_helpers_tf32.cl index bf81a4b3..da7c1f8a 100644 --- a/samples/20_matrixexperiments-tf32/matrix_helpers_tf32.cl +++ b/samples/20_matrixexperiments-tf32/matrix_helpers_tf32.cl @@ -1,5 +1,5 @@ /* -// Copyright (c) 2024-2025 Ben Ashbaugh +// Copyright (c) 2024-2026 Ben Ashbaugh // // SPDX-License-Identifier: MIT */ diff --git a/samples/20_matrixexperiments-tf32/matrix_kernel_tiled_tf32.cl b/samples/20_matrixexperiments-tf32/matrix_kernel_tiled_tf32.cl index c93e717e..ffc19445 100644 --- a/samples/20_matrixexperiments-tf32/matrix_kernel_tiled_tf32.cl +++ b/samples/20_matrixexperiments-tf32/matrix_kernel_tiled_tf32.cl @@ -1,5 +1,5 @@ /* -// Copyright (c) 2024-2025 Ben Ashbaugh +// Copyright (c) 2024-2026 Ben Ashbaugh // // SPDX-License-Identifier: MIT */ diff --git a/samples/20_matrixexperiments-tf32/matrix_kernels_tf32.cl b/samples/20_matrixexperiments-tf32/matrix_kernels_tf32.cl index ea4526db..8bcbd575 100644 --- a/samples/20_matrixexperiments-tf32/matrix_kernels_tf32.cl +++ b/samples/20_matrixexperiments-tf32/matrix_kernels_tf32.cl @@ -1,5 +1,5 @@ /* -// Copyright (c) 2024-2025 Ben Ashbaugh +// Copyright (c) 2024-2026 Ben Ashbaugh // // SPDX-License-Identifier: MIT */ From 7f5764c079fcf3a66da423ccfb1d9a3a1be72d4c Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Mon, 23 Feb 2026 15:55:37 -0800 Subject: [PATCH 93/99] remove warning when split barriers are unsupported --- samples/20_matrixexperiments-bf16/matrix_kernel_tiled_bf16.cl | 3 --- 
samples/20_matrixexperiments-tf32/matrix_kernel_tiled_tf32.cl | 3 --- 2 files changed, 6 deletions(-) diff --git a/samples/20_matrixexperiments-bf16/matrix_kernel_tiled_bf16.cl b/samples/20_matrixexperiments-bf16/matrix_kernel_tiled_bf16.cl index 34d67b24..afde3730 100644 --- a/samples/20_matrixexperiments-bf16/matrix_kernel_tiled_bf16.cl +++ b/samples/20_matrixexperiments-bf16/matrix_kernel_tiled_bf16.cl @@ -21,9 +21,6 @@ #endif #if !defined(cl_intel_split_work_group_barrier) || defined(NO_SPLIT_BARRIERS) -#if !defined(cl_intel_split_work_group_barrier) -#warning "Unexpected: cl_intel_split_work_group_barrier is not supported?" -#endif #define split_barrier_arrive() #define split_barrier_wait() #else diff --git a/samples/20_matrixexperiments-tf32/matrix_kernel_tiled_tf32.cl b/samples/20_matrixexperiments-tf32/matrix_kernel_tiled_tf32.cl index ffc19445..e5fb52ee 100644 --- a/samples/20_matrixexperiments-tf32/matrix_kernel_tiled_tf32.cl +++ b/samples/20_matrixexperiments-tf32/matrix_kernel_tiled_tf32.cl @@ -21,9 +21,6 @@ #endif #if !defined(cl_intel_split_work_group_barrier) || defined(NO_SPLIT_BARRIERS) -#if !defined(cl_intel_split_work_group_barrier) -#warning "Unexpected: cl_intel_split_work_group_barrier is not supported?" 
-#endif #define split_barrier_arrive() #define split_barrier_wait() #else From 4e31c01ac0349f9b776404608737782160955a0a Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Mon, 2 Mar 2026 16:59:08 -0800 Subject: [PATCH 94/99] fixes for CPU and more --- samples/20_matrixexperiments-bf16/main.cpp | 63 +++++++++++--- .../matrix_helpers_bf16.cl | 32 +++---- .../matrix_kernel_tiled_bf16.cl | 4 +- .../matrix_kernels_bf16.cl | 14 ++-- samples/20_matrixexperiments-i8/main.cpp | 47 +++++++++-- .../matrix_helpers_i8.cl | 84 +++++++++---------- .../matrix_kernels_i8.cl | 30 +++---- samples/20_matrixexperiments-tf32/main.cpp | 37 ++++++-- .../matrix_kernels_tf32.cl | 2 +- 9 files changed, 207 insertions(+), 106 deletions(-) diff --git a/samples/20_matrixexperiments-bf16/main.cpp b/samples/20_matrixexperiments-bf16/main.cpp index 7cfb2c0b..c0c4dce4 100644 --- a/samples/20_matrixexperiments-bf16/main.cpp +++ b/samples/20_matrixexperiments-bf16/main.cpp @@ -76,6 +76,12 @@ static size_t findMinSubGroupSize(cl::Device& device) return 0; } +static bool supportsSubgroupSize(cl::Device& device, size_t subgroupSize) +{ + auto s = device.getInfo(); + return std::find(std::begin(s), std::end(s), subgroupSize) != std::end(s); +} + static void setRoundRobin(cl::Kernel& kernel) { constexpr cl_kernel_exec_info CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_INTEL = 0x10025; @@ -175,6 +181,23 @@ static float hw_time(cl::Event& event) return ns / 1e9f; } +static cl::NDRange getRequiredLocalWorkSize(cl::Kernel& kernel, cl::CommandQueue queue) +{ + // Note: This shouldn't be necessary, and the OpenCL implementation should + // automatically choose the required local work-group size when the local + // work-group size is `nullptr`. This is not working for some OpenCL + // implementations, though, so we will just query and use the required local + // work-group size explicitly. 
+ auto device = queue.getInfo(); + auto reqd_wgs = kernel.getWorkGroupInfo(device); + + if (reqd_wgs[0] > 0 && reqd_wgs[1] > 0 && reqd_wgs[2] > 0) { + return cl::NDRange(reqd_wgs[0], reqd_wgs[1], reqd_wgs[2]); + } + + return cl::NullRange; +} + static void bfloat16_naive( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, @@ -187,6 +210,8 @@ static void bfloat16_naive( if (kernel() == nullptr) { printf("unsupported.\n"); } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + kernel.setArg(0, C); kernel.setArg(1, A); kernel.setArg(2, B); @@ -201,7 +226,7 @@ static void bfloat16_naive( cl::Event event; auto start = test_clock::now(); queue.enqueueNDRangeKernel(kernel, cl::NullRange, - cl::NDRange{N, M}, cl::NullRange, nullptr, &event); + cl::NDRange{N, M}, localWorkSize, nullptr, &event); queue.finish(); auto end = test_clock::now(); std::chrono::duration sw_time = end - start; @@ -237,6 +262,8 @@ static void bfloat16_dpas_rowmajor( if (kernel() == nullptr) { printf("unsupported.\n"); } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + kernel.setArg(0, C); kernel.setArg(1, A); kernel.setArg(2, B); @@ -251,7 +278,7 @@ static void bfloat16_dpas_rowmajor( cl::Event event; auto start = test_clock::now(); queue.enqueueNDRangeKernel(kernel, cl::NullRange, - cl::NDRange{N, M/tM}, cl::NullRange, nullptr, &event); + cl::NDRange{N, M/tM}, localWorkSize, nullptr, &event); queue.finish(); auto end = test_clock::now(); std::chrono::duration sw_time = end - start; @@ -293,6 +320,8 @@ static void bfloat16_dpas_rowmajor_tiled( } else if (tN * NN > N) { printf("N is too small.\n"); } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + kernel.setArg(0, C); kernel.setArg(1, A); kernel.setArg(2, B); @@ -307,7 +336,7 @@ static void bfloat16_dpas_rowmajor_tiled( cl::Event event; auto start = test_clock::now(); 
queue.enqueueNDRangeKernel(kernel, cl::NullRange, - cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, nullptr, &event); + cl::NDRange{N/NN, M/tM/MM}, localWorkSize, nullptr, &event); queue.finish(); auto end = test_clock::now(); std::chrono::duration sw_time = end - start; @@ -343,6 +372,8 @@ static void bfloat16_dpas_vnni( if (kernel() == nullptr) { printf("unsupported.\n"); } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + kernel.setArg(0, C); kernel.setArg(1, A); kernel.setArg(2, B); @@ -357,7 +388,7 @@ static void bfloat16_dpas_vnni( cl::Event event; auto start = test_clock::now(); queue.enqueueNDRangeKernel(kernel, cl::NullRange, - cl::NDRange{N, M/tM}, cl::NullRange, nullptr, &event); + cl::NDRange{N, M/tM}, localWorkSize, nullptr, &event); queue.finish(); auto end = test_clock::now(); std::chrono::duration sw_time = end - start; @@ -399,6 +430,8 @@ static void bfloat16_dpas_vnni_tiled( } else if (tN * NN > N) { printf("N is too small.\n"); } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + kernel.setArg(0, C); kernel.setArg(1, A); kernel.setArg(2, B); @@ -413,7 +446,7 @@ static void bfloat16_dpas_vnni_tiled( cl::Event event; auto start = test_clock::now(); queue.enqueueNDRangeKernel(kernel, cl::NullRange, - cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, nullptr, &event); + cl::NDRange{N/NN, M/tM/MM}, localWorkSize, nullptr, &event); queue.finish(); auto end = test_clock::now(); std::chrono::duration sw_time = end - start; @@ -449,6 +482,8 @@ static void bfloat16_dpas_blockread_rowmajor( if (kernel() == nullptr) { printf("unsupported.\n"); } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + kernel.setArg(0, C); kernel.setArg(1, A); kernel.setArg(2, B); @@ -466,7 +501,7 @@ static void bfloat16_dpas_blockread_rowmajor( cl::Event event; auto start = test_clock::now(); queue.enqueueNDRangeKernel(kernel, cl::NullRange, - cl::NDRange{N, M/tM}, cl::NullRange, nullptr, 
&event); + cl::NDRange{N, M/tM}, localWorkSize, nullptr, &event); queue.finish(); auto end = test_clock::now(); std::chrono::duration sw_time = end - start; @@ -508,6 +543,8 @@ static void bfloat16_dpas_blockread_rowmajor_tiled( } else if (tN * NN > N) { printf("N is too small.\n"); } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + kernel.setArg(0, C); kernel.setArg(1, A); kernel.setArg(2, B); @@ -525,7 +562,7 @@ static void bfloat16_dpas_blockread_rowmajor_tiled( cl::Event event; auto start = test_clock::now(); queue.enqueueNDRangeKernel(kernel, cl::NullRange, - cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, nullptr, &event); + cl::NDRange{N/NN, M/tM/MM}, localWorkSize, nullptr, &event); queue.finish(); auto end = test_clock::now(); std::chrono::duration sw_time = end - start; @@ -561,6 +598,8 @@ static void bfloat16_dpas_blockread_vnni( if (kernel() == nullptr) { printf("unsupported.\n"); } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + kernel.setArg(0, C); kernel.setArg(1, A); kernel.setArg(2, B); @@ -578,7 +617,7 @@ static void bfloat16_dpas_blockread_vnni( cl::Event event; auto start = test_clock::now(); queue.enqueueNDRangeKernel(kernel, cl::NullRange, - cl::NDRange{N, M/tM}, cl::NullRange, nullptr, &event); + cl::NDRange{N, M/tM}, localWorkSize, nullptr, &event); queue.finish(); auto end = test_clock::now(); std::chrono::duration sw_time = end - start; @@ -620,6 +659,8 @@ static void bfloat16_dpas_blockread_vnni_tiled( } else if (tN * NN > N) { printf("N is too small.\n"); } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + kernel.setArg(0, C); kernel.setArg(1, A); kernel.setArg(2, B); @@ -637,7 +678,7 @@ static void bfloat16_dpas_blockread_vnni_tiled( cl::Event event; auto start = test_clock::now(); queue.enqueueNDRangeKernel(kernel, cl::NullRange, - cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, nullptr, &event); + cl::NDRange{N/NN, M/tM/MM}, 
localWorkSize, nullptr, &event); queue.finish(); auto end = test_clock::now(); std::chrono::duration sw_time = end - start; @@ -729,7 +770,7 @@ int main(int argc, char** argv) auto minSubGroupSize = findMinSubGroupSize(device); - bool has_simd8 = minSubGroupSize == 8; + bool has_sg8 = supportsSubgroupSize(device, 8); bool emulate_tN8 = true; bool emulate_tN16 = true; if (!emulate && checkDeviceForExtension(device, "cl_intel_subgroup_matrix_multiply_accumulate")) { @@ -741,7 +782,7 @@ int main(int argc, char** argv) } } - buildOptions += " -DHAS_SIMD8=" + std::to_string(has_simd8); + buildOptions += " -DHAS_SG8=" + std::to_string(has_sg8); buildOptions += " -DEMULATE_tN8=" + std::to_string(emulate_tN8); buildOptions += " -DEMULATE_tN16=" + std::to_string(emulate_tN16); diff --git a/samples/20_matrixexperiments-bf16/matrix_helpers_bf16.cl b/samples/20_matrixexperiments-bf16/matrix_helpers_bf16.cl index f9bf93f2..7dcb2e27 100644 --- a/samples/20_matrixexperiments-bf16/matrix_helpers_bf16.cl +++ b/samples/20_matrixexperiments-bf16/matrix_helpers_bf16.cl @@ -155,22 +155,22 @@ float emu_sub_group_bf16_bf16_matrix_mad_k16(short a, int8 b, float acc) { float res = acc; - res = fma(bf16_to_fp32(sub_group_broadcast(a, 0)), bf16_to_fp32(as_ushort2(b.s0).x), res); - res = fma(bf16_to_fp32(sub_group_broadcast(a, 1)), bf16_to_fp32(as_ushort2(b.s0).y), res); - res = fma(bf16_to_fp32(sub_group_broadcast(a, 2)), bf16_to_fp32(as_ushort2(b.s1).x), res); - res = fma(bf16_to_fp32(sub_group_broadcast(a, 3)), bf16_to_fp32(as_ushort2(b.s1).y), res); - res = fma(bf16_to_fp32(sub_group_broadcast(a, 4)), bf16_to_fp32(as_ushort2(b.s2).x), res); - res = fma(bf16_to_fp32(sub_group_broadcast(a, 5)), bf16_to_fp32(as_ushort2(b.s2).y), res); - res = fma(bf16_to_fp32(sub_group_broadcast(a, 6)), bf16_to_fp32(as_ushort2(b.s3).x), res); - res = fma(bf16_to_fp32(sub_group_broadcast(a, 7)), bf16_to_fp32(as_ushort2(b.s3).y), res); - res = fma(bf16_to_fp32(sub_group_broadcast(a, 8)), 
bf16_to_fp32(as_ushort2(b.s4).x), res); - res = fma(bf16_to_fp32(sub_group_broadcast(a, 9)), bf16_to_fp32(as_ushort2(b.s4).y), res); - res = fma(bf16_to_fp32(sub_group_broadcast(a, 10)), bf16_to_fp32(as_ushort2(b.s5).x), res); - res = fma(bf16_to_fp32(sub_group_broadcast(a, 11)), bf16_to_fp32(as_ushort2(b.s5).y), res); - res = fma(bf16_to_fp32(sub_group_broadcast(a, 12)), bf16_to_fp32(as_ushort2(b.s6).x), res); - res = fma(bf16_to_fp32(sub_group_broadcast(a, 13)), bf16_to_fp32(as_ushort2(b.s6).y), res); - res = fma(bf16_to_fp32(sub_group_broadcast(a, 14)), bf16_to_fp32(as_ushort2(b.s7).x), res); - res = fma(bf16_to_fp32(sub_group_broadcast(a, 15)), bf16_to_fp32(as_ushort2(b.s7).y), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 0)), bf16_to_fp32(as_ushort2(b.s0).x), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 1)), bf16_to_fp32(as_ushort2(b.s0).y), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 2)), bf16_to_fp32(as_ushort2(b.s1).x), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 3)), bf16_to_fp32(as_ushort2(b.s1).y), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 4)), bf16_to_fp32(as_ushort2(b.s2).x), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 5)), bf16_to_fp32(as_ushort2(b.s2).y), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 6)), bf16_to_fp32(as_ushort2(b.s3).x), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 7)), bf16_to_fp32(as_ushort2(b.s3).y), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 8)), bf16_to_fp32(as_ushort2(b.s4).x), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 9)), bf16_to_fp32(as_ushort2(b.s4).y), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 10)), bf16_to_fp32(as_ushort2(b.s5).x), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 11)), bf16_to_fp32(as_ushort2(b.s5).y), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 12)), bf16_to_fp32(as_ushort2(b.s6).x), res); + res = 
fma(bf16_to_fp32(intel_sub_group_broadcast(a, 13)), bf16_to_fp32(as_ushort2(b.s6).y), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 14)), bf16_to_fp32(as_ushort2(b.s7).x), res); + res = fma(bf16_to_fp32(intel_sub_group_broadcast(a, 15)), bf16_to_fp32(as_ushort2(b.s7).y), res); return res; } diff --git a/samples/20_matrixexperiments-bf16/matrix_kernel_tiled_bf16.cl b/samples/20_matrixexperiments-bf16/matrix_kernel_tiled_bf16.cl index afde3730..d76ee526 100644 --- a/samples/20_matrixexperiments-bf16/matrix_kernel_tiled_bf16.cl +++ b/samples/20_matrixexperiments-bf16/matrix_kernel_tiled_bf16.cl @@ -64,7 +64,7 @@ void HELPER_NAME(btile_load_packed, MM, NN)(global ushort* B, int tN, int N, int } } -#if HAS_SIMD8 +#if HAS_SG8 void HELPER_NAME(atile_prefetch_rowmajor_sg8, MM, NN)(global ushort* A, int tM, int K, int m, int prefetch_k) { @@ -236,7 +236,7 @@ kernel void MM_KERNEL_NAME(bfloat16_dpas_vnni_tiled, 8, 8, MM, NN)(global float* } } -#endif // HAS_SIMD8 +#endif // HAS_SG8 void HELPER_NAME(atile_prefetch_rowmajor, MM, NN)(global ushort* A, int tM, int K, int m, int prefetch_k) { diff --git a/samples/20_matrixexperiments-bf16/matrix_kernels_bf16.cl b/samples/20_matrixexperiments-bf16/matrix_kernels_bf16.cl index e799f41d..b98711b6 100644 --- a/samples/20_matrixexperiments-bf16/matrix_kernels_bf16.cl +++ b/samples/20_matrixexperiments-bf16/matrix_kernels_bf16.cl @@ -38,7 +38,7 @@ kernel void bfloat16_naive(global float* C, global ushort* A, global ushort* B, #if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size) -#if HAS_SIMD8 +#if HAS_SG8 // rowmajor kernels: @@ -212,9 +212,9 @@ kernel void bfloat16_dpas_vnni_m8_n8(global float* C, global ushort* A, global u store_c_rowmajor_fp32_8rNc(C, sum, m, n, N); } -#endif // HAS_SIMD8 +#endif // HAS_SG8 -// rowmajor krenels: +// rowmajor kernels: __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void 
bfloat16_dpas_rowmajor_m1_n16(global float* C, global ushort* A, global ushort* B, int K) @@ -224,7 +224,7 @@ kernel void bfloat16_dpas_rowmajor_m1_n16(global float* C, global ushort* A, glo const int tN = 16; const int N = get_global_size(0); const int m = get_group_id(1) * tM; - const int n = get_group_id(0) * get_local_size(0); + const int n = get_group_id(0) * tN; float sum = 0; for (int k = 0; k < K; k += tK) { @@ -245,7 +245,7 @@ kernel void bfloat16_dpas_rowmajor_m2_n16(global float* C, global ushort* A, glo const int tN = 16; const int N = get_global_size(0); const int m = get_group_id(1) * tM; - const int n = get_group_id(0) * get_local_size(0); + const int n = get_group_id(0) * tN; float2 sum = 0; for (int k = 0; k < K; k += tK) { @@ -266,7 +266,7 @@ kernel void bfloat16_dpas_rowmajor_m4_n16(global float* C, global ushort* A, glo const int tN = 16; const int N = get_global_size(0); const int m = get_group_id(1) * tM; - const int n = get_group_id(0) * get_local_size(0); + const int n = get_group_id(0) * tN; float4 sum = 0; for (int k = 0; k < K; k += tK) { @@ -287,7 +287,7 @@ kernel void bfloat16_dpas_rowmajor_m8_n16(global float* C, global ushort* A, glo const int tN = 16; const int N = get_global_size(0); const int m = get_group_id(1) * tM; - const int n = get_group_id(0) * get_local_size(0); + const int n = get_group_id(0) * tN; float8 sum = 0; for (int k = 0; k < K; k += tK) { diff --git a/samples/20_matrixexperiments-i8/main.cpp b/samples/20_matrixexperiments-i8/main.cpp index 91fe2016..8bea9828 100644 --- a/samples/20_matrixexperiments-i8/main.cpp +++ b/samples/20_matrixexperiments-i8/main.cpp @@ -76,6 +76,12 @@ static size_t findMinSubGroupSize(cl::Device& device) return 0; } +static bool supportsSubgroupSize(cl::Device& device, size_t subgroupSize) +{ + auto s = device.getInfo(); + return std::find(std::begin(s), std::end(s), subgroupSize) != std::end(s); +} + static void setRoundRobin(cl::Kernel& kernel) { constexpr cl_kernel_exec_info 
CL_KERNEL_EXEC_INFO_THREAD_ARBITRATION_POLICY_INTEL = 0x10025; @@ -170,6 +176,23 @@ static float hw_time(cl::Event& event) return ns / 1e9f; } +static cl::NDRange getRequiredLocalWorkSize(cl::Kernel& kernel, cl::CommandQueue queue) +{ + // Note: This shouldn't be necessary, and the OpenCL implementation should + // automatically choose the required local work-group size when the local + // work-group size is `nullptr`. This is not working for some OpenCL + // implementations, though, so we will just query and use the required local + // work-group size explicitly. + auto device = queue.getInfo(); + auto reqd_wgs = kernel.getWorkGroupInfo(device); + + if (reqd_wgs[0] > 0 && reqd_wgs[1] > 0 && reqd_wgs[2] > 0) { + return cl::NDRange(reqd_wgs[0], reqd_wgs[1], reqd_wgs[2]); + } + + return cl::NullRange; +} + static void i8_naive( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, @@ -182,6 +205,8 @@ static void i8_naive( if (kernel() == nullptr) { printf("unsupported.\n"); } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + kernel.setArg(0, C); kernel.setArg(1, A); kernel.setArg(2, B); @@ -196,7 +221,7 @@ static void i8_naive( cl::Event event; auto start = test_clock::now(); queue.enqueueNDRangeKernel(kernel, cl::NullRange, - cl::NDRange{N, M}, cl::NullRange, nullptr, &event); + cl::NDRange{N, M}, localWorkSize, nullptr, &event); queue.finish(); auto end = test_clock::now(); std::chrono::duration sw_time = end - start; @@ -232,6 +257,8 @@ static void i8_dpas_rowmajor( if (kernel() == nullptr) { printf("unsupported.\n"); } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + kernel.setArg(0, C); kernel.setArg(1, A); kernel.setArg(2, B); @@ -246,7 +273,7 @@ static void i8_dpas_rowmajor( cl::Event event; auto start = test_clock::now(); queue.enqueueNDRangeKernel(kernel, cl::NullRange, - cl::NDRange{N, M/tM}, cl::NullRange, nullptr, &event); + 
cl::NDRange{N, M/tM}, localWorkSize, nullptr, &event); queue.finish(); auto end = test_clock::now(); std::chrono::duration sw_time = end - start; @@ -282,6 +309,8 @@ static void i8_dpas_vnni( if (kernel() == nullptr) { printf("unsupported.\n"); } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + kernel.setArg(0, C); kernel.setArg(1, A); kernel.setArg(2, B); @@ -296,7 +325,7 @@ static void i8_dpas_vnni( cl::Event event; auto start = test_clock::now(); queue.enqueueNDRangeKernel(kernel, cl::NullRange, - cl::NDRange{N, M/tM}, cl::NullRange, nullptr, &event); + cl::NDRange{N, M/tM}, localWorkSize, nullptr, &event); queue.finish(); auto end = test_clock::now(); std::chrono::duration sw_time = end - start; @@ -334,6 +363,8 @@ static void i8_dpas_blockread_rowmajor( } else if (K < 64 || N < 64) { printf("matrix pitch for block reads must be >= 64 bytes.\n"); } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + kernel.setArg(0, C); kernel.setArg(1, A); kernel.setArg(2, B); @@ -351,7 +382,7 @@ static void i8_dpas_blockread_rowmajor( cl::Event event; auto start = test_clock::now(); queue.enqueueNDRangeKernel(kernel, cl::NullRange, - cl::NDRange{N, M/tM}, cl::NullRange, nullptr, &event); + cl::NDRange{N, M/tM}, localWorkSize, nullptr, &event); queue.finish(); auto end = test_clock::now(); std::chrono::duration sw_time = end - start; @@ -389,6 +420,8 @@ static void i8_dpas_blockread_vnni( } else if (K < 64 || N < 64/4) { printf("matrix pitch for block reads must be >= 64 bytes.\n"); } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + kernel.setArg(0, C); kernel.setArg(1, A); kernel.setArg(2, B); @@ -406,7 +439,7 @@ static void i8_dpas_blockread_vnni( cl::Event event; auto start = test_clock::now(); queue.enqueueNDRangeKernel(kernel, cl::NullRange, - cl::NDRange{N, M/tM}, cl::NullRange, nullptr, &event); + cl::NDRange{N, M/tM}, localWorkSize, nullptr, &event); 
queue.finish(); auto end = test_clock::now(); std::chrono::duration sw_time = end - start; @@ -498,7 +531,7 @@ int main(int argc, char** argv) auto minSubGroupSize = findMinSubGroupSize(device); - bool has_simd8 = minSubGroupSize == 8; + bool has_sg8 = supportsSubgroupSize(device, 8); bool emulate_tN8 = true; bool emulate_tN16 = true; if (!emulate && checkDeviceForExtension(device, "cl_intel_subgroup_matrix_multiply_accumulate")) { @@ -510,7 +543,7 @@ int main(int argc, char** argv) } } - buildOptions += " -DHAS_SIMD8=" + std::to_string(has_simd8); + buildOptions += " -DHAS_SG8=" + std::to_string(has_sg8); buildOptions += " -DEMULATE_tN8=" + std::to_string(emulate_tN8); buildOptions += " -DEMULATE_tN16=" + std::to_string(emulate_tN16); diff --git a/samples/20_matrixexperiments-i8/matrix_helpers_i8.cl b/samples/20_matrixexperiments-i8/matrix_helpers_i8.cl index 360114b6..d2a0e055 100644 --- a/samples/20_matrixexperiments-i8/matrix_helpers_i8.cl +++ b/samples/20_matrixexperiments-i8/matrix_helpers_i8.cl @@ -55,7 +55,7 @@ int8 activation(int8 i) #define __builtin_expect(x) #endif -#if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_char) +#if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_subgroups_char) typedef global char* global_aligned_char_ptr __attribute__((align_value(4))); @@ -171,45 +171,45 @@ int emu_sub_group_i8_i8_matrix_mad_k32(short a, int8 b, int acc) { float res = acc; - res = as_char2(sub_group_broadcast(a, 0)).x * as_char4(b.s0).x + res; - res = as_char2(sub_group_broadcast(a, 0)).y * as_char4(b.s0).y + res; - res = as_char2(sub_group_broadcast(a, 1)).x * as_char4(b.s0).z + res; - res = as_char2(sub_group_broadcast(a, 1)).y * as_char4(b.s0).w + res; - - res = as_char2(sub_group_broadcast(a, 2)).x * as_char4(b.s1).x + res; - res = as_char2(sub_group_broadcast(a, 2)).y * as_char4(b.s1).y + res; - res = as_char2(sub_group_broadcast(a, 3)).x * as_char4(b.s1).z + res; - res = as_char2(sub_group_broadcast(a, 
3)).y * as_char4(b.s1).w + res; - - res = as_char2(sub_group_broadcast(a, 4)).x * as_char4(b.s2).x + res; - res = as_char2(sub_group_broadcast(a, 4)).y * as_char4(b.s2).y + res; - res = as_char2(sub_group_broadcast(a, 5)).x * as_char4(b.s2).z + res; - res = as_char2(sub_group_broadcast(a, 5)).y * as_char4(b.s2).w + res; - - res = as_char2(sub_group_broadcast(a, 6)).x * as_char4(b.s3).x + res; - res = as_char2(sub_group_broadcast(a, 6)).y * as_char4(b.s3).y + res; - res = as_char2(sub_group_broadcast(a, 7)).x * as_char4(b.s3).z + res; - res = as_char2(sub_group_broadcast(a, 7)).y * as_char4(b.s3).w + res; - - res = as_char2(sub_group_broadcast(a, 8)).x * as_char4(b.s4).x + res; - res = as_char2(sub_group_broadcast(a, 8)).y * as_char4(b.s4).y + res; - res = as_char2(sub_group_broadcast(a, 9)).x * as_char4(b.s4).z + res; - res = as_char2(sub_group_broadcast(a, 9)).y * as_char4(b.s4).w + res; - - res = as_char2(sub_group_broadcast(a, 10)).x * as_char4(b.s5).x + res; - res = as_char2(sub_group_broadcast(a, 10)).y * as_char4(b.s5).y + res; - res = as_char2(sub_group_broadcast(a, 11)).x * as_char4(b.s5).z + res; - res = as_char2(sub_group_broadcast(a, 11)).y * as_char4(b.s5).w + res; - - res = as_char2(sub_group_broadcast(a, 12)).x * as_char4(b.s6).x + res; - res = as_char2(sub_group_broadcast(a, 12)).y * as_char4(b.s6).y + res; - res = as_char2(sub_group_broadcast(a, 13)).x * as_char4(b.s6).z + res; - res = as_char2(sub_group_broadcast(a, 13)).y * as_char4(b.s6).w + res; - - res = as_char2(sub_group_broadcast(a, 14)).x * as_char4(b.s7).x + res; - res = as_char2(sub_group_broadcast(a, 14)).y * as_char4(b.s7).y + res; - res = as_char2(sub_group_broadcast(a, 15)).x * as_char4(b.s7).z + res; - res = as_char2(sub_group_broadcast(a, 15)).y * as_char4(b.s7).w + res; + res = as_char2(intel_sub_group_broadcast(a, 0)).x * as_char4(b.s0).x + res; + res = as_char2(intel_sub_group_broadcast(a, 0)).y * as_char4(b.s0).y + res; + res = as_char2(intel_sub_group_broadcast(a, 1)).x * 
as_char4(b.s0).z + res; + res = as_char2(intel_sub_group_broadcast(a, 1)).y * as_char4(b.s0).w + res; + + res = as_char2(intel_sub_group_broadcast(a, 2)).x * as_char4(b.s1).x + res; + res = as_char2(intel_sub_group_broadcast(a, 2)).y * as_char4(b.s1).y + res; + res = as_char2(intel_sub_group_broadcast(a, 3)).x * as_char4(b.s1).z + res; + res = as_char2(intel_sub_group_broadcast(a, 3)).y * as_char4(b.s1).w + res; + + res = as_char2(intel_sub_group_broadcast(a, 4)).x * as_char4(b.s2).x + res; + res = as_char2(intel_sub_group_broadcast(a, 4)).y * as_char4(b.s2).y + res; + res = as_char2(intel_sub_group_broadcast(a, 5)).x * as_char4(b.s2).z + res; + res = as_char2(intel_sub_group_broadcast(a, 5)).y * as_char4(b.s2).w + res; + + res = as_char2(intel_sub_group_broadcast(a, 6)).x * as_char4(b.s3).x + res; + res = as_char2(intel_sub_group_broadcast(a, 6)).y * as_char4(b.s3).y + res; + res = as_char2(intel_sub_group_broadcast(a, 7)).x * as_char4(b.s3).z + res; + res = as_char2(intel_sub_group_broadcast(a, 7)).y * as_char4(b.s3).w + res; + + res = as_char2(intel_sub_group_broadcast(a, 8)).x * as_char4(b.s4).x + res; + res = as_char2(intel_sub_group_broadcast(a, 8)).y * as_char4(b.s4).y + res; + res = as_char2(intel_sub_group_broadcast(a, 9)).x * as_char4(b.s4).z + res; + res = as_char2(intel_sub_group_broadcast(a, 9)).y * as_char4(b.s4).w + res; + + res = as_char2(intel_sub_group_broadcast(a, 10)).x * as_char4(b.s5).x + res; + res = as_char2(intel_sub_group_broadcast(a, 10)).y * as_char4(b.s5).y + res; + res = as_char2(intel_sub_group_broadcast(a, 11)).x * as_char4(b.s5).z + res; + res = as_char2(intel_sub_group_broadcast(a, 11)).y * as_char4(b.s5).w + res; + + res = as_char2(intel_sub_group_broadcast(a, 12)).x * as_char4(b.s6).x + res; + res = as_char2(intel_sub_group_broadcast(a, 12)).y * as_char4(b.s6).y + res; + res = as_char2(intel_sub_group_broadcast(a, 13)).x * as_char4(b.s6).z + res; + res = as_char2(intel_sub_group_broadcast(a, 13)).y * as_char4(b.s6).w + res; + + 
res = as_char2(intel_sub_group_broadcast(a, 14)).x * as_char4(b.s7).x + res; + res = as_char2(intel_sub_group_broadcast(a, 14)).y * as_char4(b.s7).y + res; + res = as_char2(intel_sub_group_broadcast(a, 15)).x * as_char4(b.s7).z + res; + res = as_char2(intel_sub_group_broadcast(a, 15)).y * as_char4(b.s7).w + res; return res; } @@ -459,9 +459,9 @@ void prefetch_a_rowmajor_d16_m8v2_k16v2_sg16(global ushort* A, int rowStart, int #endif // K rows x N columns: -// Each work-item loads K values and converts to VNNI. +// Each work-item loads K values and packs into 32-bits. // Stride is in units of elements. -int8 load_b_rowmajor_d8_k32_nx(global char* B, int rowStart, int colStart, int stride) +int8 load_b_rowmajor_8b_32rNc(global char* B, int rowStart, int colStart, int stride) { int8 ret; diff --git a/samples/20_matrixexperiments-i8/matrix_kernels_i8.cl b/samples/20_matrixexperiments-i8/matrix_kernels_i8.cl index b16f8d94..6bec5f2f 100644 --- a/samples/20_matrixexperiments-i8/matrix_kernels_i8.cl +++ b/samples/20_matrixexperiments-i8/matrix_kernels_i8.cl @@ -38,7 +38,7 @@ kernel void i8_naive(global int* C, global char* A, global char* B, int K) #if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_char) && defined(cl_intel_required_subgroup_size) -#if HAS_SIMD8 +#if HAS_SG8 // rowmajor kernels: @@ -55,7 +55,7 @@ kernel void i8_dpas_rowmajor_m1_n8(global int* C, global char* A, global char* B int sum = 0; for (int k = 0; k < K; k += tK) { int aData = load_a_rowmajor_d8_m1_k32_sg8(A, m, k, K); - int8 bData = load_b_rowmajor_d8_k32_nx(B, k, n, N); + int8 bData = load_b_rowmajor_8b_32rNc(B, k, n, N); sum = mat_mul_sg8(aData, bData, sum); } @@ -76,7 +76,7 @@ kernel void i8_dpas_rowmajor_m2_n8(global int* C, global char* A, global char* B int2 sum = 0; for (int k = 0; k < K; k += tK) { int2 aData = load_a_rowmajor_d8_m2_k32_sg8(A, m, k, K); - int8 bData = load_b_rowmajor_d8_k32_nx(B, k, n, N); + int8 bData = load_b_rowmajor_8b_32rNc(B, k, n, N); sum = 
mat_mul_sg8(aData, bData, sum); } @@ -97,7 +97,7 @@ kernel void i8_dpas_rowmajor_m4_n8(global int* C, global char* A, global char* B int4 sum = 0; for (int k = 0; k < K; k += tK) { int4 aData = load_a_rowmajor_d8_m4_k32_sg8(A, m, k, K); - int8 bData = load_b_rowmajor_d8_k32_nx(B, k, n, N); + int8 bData = load_b_rowmajor_8b_32rNc(B, k, n, N); sum = mat_mul_sg8(aData, bData, sum); } @@ -118,7 +118,7 @@ kernel void i8_dpas_rowmajor_m8_n8(global int* C, global char* A, global char* B int8 sum = 0; for (int k = 0; k < K; k += tK) { int8 aData = load_a_rowmajor_d8_m8_k32_sg8(A, m, k, K); - int8 bData = load_b_rowmajor_d8_k32_nx(B, k, n, N); + int8 bData = load_b_rowmajor_8b_32rNc(B, k, n, N); sum = mat_mul_sg8(aData, bData, sum); } @@ -212,9 +212,9 @@ kernel void i8_dpas_vnni_m8_n8(global int* C, global char* A, global char* B, in store_c_rowmajor_int32_m8_nx(C, sum, m, n, N); } -#endif // HAS_SIMD8 +#endif // HAS_SG8 -// rowmajor krenels: +// rowmajor kernels: __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void i8_dpas_rowmajor_m1_n16(global int* C, global char* A, global char* B, int K) @@ -224,12 +224,12 @@ kernel void i8_dpas_rowmajor_m1_n16(global int* C, global char* A, global char* const int tN = 16; const int N = get_global_size(0); const int m = get_group_id(1) * tM; - const int n = get_group_id(0) * get_local_size(0); + const int n = get_group_id(0) * tN; int sum = 0; for (int k = 0; k < K; k += tK) { short aData = load_a_rowmajor_d8_m1_k32_sg16(A, m, k, K); - int8 bData = load_b_rowmajor_d8_k32_nx(B, k, n, N); + int8 bData = load_b_rowmajor_8b_32rNc(B, k, n, N); sum = mat_mul_sg16(aData, bData, sum); } @@ -245,12 +245,12 @@ kernel void i8_dpas_rowmajor_m2_n16(global int* C, global char* A, global char* const int tN = 16; const int N = get_global_size(0); const int m = get_group_id(1) * tM; - const int n = get_group_id(0) * get_local_size(0); + const int n = get_group_id(0) * tN; int2 sum = 0; for (int k = 
0; k < K; k += tK) { short2 aData = load_a_rowmajor_d8_m2_k32_sg16(A, m, k, K); - int8 bData = load_b_rowmajor_d8_k32_nx(B, k, n, N); + int8 bData = load_b_rowmajor_8b_32rNc(B, k, n, N); sum = mat_mul_sg16(aData, bData, sum); } @@ -266,12 +266,12 @@ kernel void i8_dpas_rowmajor_m4_n16(global int* C, global char* A, global char* const int tN = 16; const int N = get_global_size(0); const int m = get_group_id(1) * tM; - const int n = get_group_id(0) * get_local_size(0); + const int n = get_group_id(0) * tN; int4 sum = 0; for (int k = 0; k < K; k += tK) { short4 aData = load_a_rowmajor_d8_m4_k32_sg16(A, m, k, K); - int8 bData = load_b_rowmajor_d8_k32_nx(B, k, n, N); + int8 bData = load_b_rowmajor_8b_32rNc(B, k, n, N); sum = mat_mul_sg16(aData, bData, sum); } @@ -287,12 +287,12 @@ kernel void i8_dpas_rowmajor_m8_n16(global int* C, global char* A, global char* const int tN = 16; const int N = get_global_size(0); const int m = get_group_id(1) * tM; - const int n = get_group_id(0) * get_local_size(0); + const int n = get_group_id(0) * tN; int8 sum = 0; for (int k = 0; k < K; k += tK) { short8 aData = load_a_rowmajor_d8_m8_k32_sg16(A, m, k, K); - int8 bData = load_b_rowmajor_d8_k32_nx(B, k, n, N); + int8 bData = load_b_rowmajor_8b_32rNc(B, k, n, N); sum = mat_mul_sg16(aData, bData, sum); } diff --git a/samples/20_matrixexperiments-tf32/main.cpp b/samples/20_matrixexperiments-tf32/main.cpp index 4bf4bba8..fcfcc2fc 100644 --- a/samples/20_matrixexperiments-tf32/main.cpp +++ b/samples/20_matrixexperiments-tf32/main.cpp @@ -179,6 +179,23 @@ static float hw_time(cl::Event& event) return ns / 1e9f; } +static cl::NDRange getRequiredLocalWorkSize(cl::Kernel& kernel, cl::CommandQueue queue) +{ + // Note: This shouldn't be necessary, and the OpenCL implementation should + // automatically choose the required local work-group size when the local + // work-group size is `nullptr`. 
This is not working for some OpenCL + // implementations, though, so we will just query and use the required local + // work-group size explicitly. + auto device = queue.getInfo(); + auto reqd_wgs = kernel.getWorkGroupInfo(device); + + if (reqd_wgs[0] > 0 && reqd_wgs[1] > 0 && reqd_wgs[2] > 0) { + return cl::NDRange(reqd_wgs[0], reqd_wgs[1], reqd_wgs[2]); + } + + return cl::NullRange; +} + static void tf32_naive( cl::Context& context, cl::Program& program, cl::CommandQueue& queue, cl::Buffer& C, cl::Buffer& A, cl::Buffer& B, @@ -191,6 +208,8 @@ static void tf32_naive( if (kernel() == nullptr) { printf("unsupported.\n"); } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + kernel.setArg(0, C); kernel.setArg(1, A); kernel.setArg(2, B); @@ -205,7 +224,7 @@ static void tf32_naive( cl::Event event; auto start = test_clock::now(); queue.enqueueNDRangeKernel(kernel, cl::NullRange, - cl::NDRange{N, M}, cl::NullRange, nullptr, &event); + cl::NDRange{N, M}, localWorkSize, nullptr, &event); queue.finish(); auto end = test_clock::now(); std::chrono::duration sw_time = end - start; @@ -241,6 +260,8 @@ static void tf32_dpas_rowmajor( if (kernel() == nullptr) { printf("unsupported.\n"); } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + kernel.setArg(0, C); kernel.setArg(1, A); kernel.setArg(2, B); @@ -255,7 +276,7 @@ static void tf32_dpas_rowmajor( cl::Event event; auto start = test_clock::now(); queue.enqueueNDRangeKernel(kernel, cl::NullRange, - cl::NDRange{N, M/tM}, cl::NullRange, nullptr, &event); + cl::NDRange{N, M/tM}, localWorkSize, nullptr, &event); queue.finish(); auto end = test_clock::now(); std::chrono::duration sw_time = end - start; @@ -297,6 +318,8 @@ static void tf32_dpas_rowmajor_tiled( } else if (tN * NN > N) { printf("N is too small.\n"); } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + kernel.setArg(0, C); kernel.setArg(1, A); kernel.setArg(2, B); @@ 
-311,7 +334,7 @@ static void tf32_dpas_rowmajor_tiled( cl::Event event; auto start = test_clock::now(); queue.enqueueNDRangeKernel(kernel, cl::NullRange, - cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, nullptr, &event); + cl::NDRange{N/NN, M/tM/MM}, localWorkSize, nullptr, &event); queue.finish(); auto end = test_clock::now(); std::chrono::duration sw_time = end - start; @@ -347,6 +370,8 @@ static void tf32_dpas_blockread_rowmajor( if (kernel() == nullptr) { printf("unsupported.\n"); } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + kernel.setArg(0, C); kernel.setArg(1, A); kernel.setArg(2, B); @@ -364,7 +389,7 @@ static void tf32_dpas_blockread_rowmajor( cl::Event event; auto start = test_clock::now(); queue.enqueueNDRangeKernel(kernel, cl::NullRange, - cl::NDRange{N, M/tM}, cl::NullRange, nullptr, &event); + cl::NDRange{N, M/tM}, localWorkSize, nullptr, &event); queue.finish(); auto end = test_clock::now(); std::chrono::duration sw_time = end - start; @@ -406,6 +431,8 @@ static void tf32_dpas_blockread_rowmajor_tiled( } else if (tN * NN > N) { printf("N is too small.\n"); } else { + const cl::NDRange localWorkSize = getRequiredLocalWorkSize(kernel, queue); + kernel.setArg(0, C); kernel.setArg(1, A); kernel.setArg(2, B); @@ -423,7 +450,7 @@ static void tf32_dpas_blockread_rowmajor_tiled( cl::Event event; auto start = test_clock::now(); queue.enqueueNDRangeKernel(kernel, cl::NullRange, - cl::NDRange{N/NN, M/tM/MM}, cl::NullRange, nullptr, &event); + cl::NDRange{N/NN, M/tM/MM}, localWorkSize, nullptr, &event); queue.finish(); auto end = test_clock::now(); std::chrono::duration sw_time = end - start; diff --git a/samples/20_matrixexperiments-tf32/matrix_kernels_tf32.cl b/samples/20_matrixexperiments-tf32/matrix_kernels_tf32.cl index 8bcbd575..bdac8e37 100644 --- a/samples/20_matrixexperiments-tf32/matrix_kernels_tf32.cl +++ b/samples/20_matrixexperiments-tf32/matrix_kernels_tf32.cl @@ -32,7 +32,7 @@ kernel void tf32_naive(global 
float* C, global float* A, global float* B, int K) #if defined(cl_intel_subgroups) && defined(cl_intel_required_subgroup_size) -// rowmajor krenels: +// rowmajor kernels: __attribute__((intel_reqd_sub_group_size(16))) __attribute__((reqd_work_group_size(16, 1, 1))) kernel void tf32_dpas_rowmajor_m1_n16(global float* C, global float* A, global float* B, int K) From ada8a4a46bc379fe892b6294d370ace5791075d9 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Tue, 3 Mar 2026 20:43:18 -0800 Subject: [PATCH 95/99] switch to kernels that use integer dot products --- .../matrix_helpers_i8.cl | 113 +++++------------- 1 file changed, 33 insertions(+), 80 deletions(-) diff --git a/samples/20_matrixexperiments-i8/matrix_helpers_i8.cl b/samples/20_matrixexperiments-i8/matrix_helpers_i8.cl index d2a0e055..9aa6ffcb 100644 --- a/samples/20_matrixexperiments-i8/matrix_helpers_i8.cl +++ b/samples/20_matrixexperiments-i8/matrix_helpers_i8.cl @@ -55,6 +55,23 @@ int8 activation(int8 i) #define __builtin_expect(x) #endif +#if defined(__opencl_c_integer_dot_product_input_4x8bit_packed) +#define dp4 dot_4x8packed_ss_int +#else +#define dp4 emu_dot_4x8packed_ss_int + +int emu_dot_4x8packed_ss_int(const uint a, const uint b) +{ + const char4 a_c4 = as_char4(a); + const char4 b_c4 = as_char4(b); + + return a_c4.x * b_c4.x + + a_c4.y * b_c4.y + + a_c4.z * b_c4.z + + a_c4.w * b_c4.w; +} +#endif + #if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_subgroups_char) typedef global char* global_aligned_char_ptr __attribute__((align_value(4))); @@ -79,47 +96,14 @@ int emu_sub_group_i8_i8_matrix_mad_k32(int a, int8 b, int acc) { int res = acc; - // TODO: this could use integer dot products instead? 
- - res = as_char4(sub_group_broadcast(a, 0)).x * as_char4(b.s0).x + res; - res = as_char4(sub_group_broadcast(a, 0)).y * as_char4(b.s0).y + res; - res = as_char4(sub_group_broadcast(a, 0)).z * as_char4(b.s0).z + res; - res = as_char4(sub_group_broadcast(a, 0)).w * as_char4(b.s0).w + res; - - res = as_char4(sub_group_broadcast(a, 1)).x * as_char4(b.s1).x + res; - res = as_char4(sub_group_broadcast(a, 1)).y * as_char4(b.s1).y + res; - res = as_char4(sub_group_broadcast(a, 1)).z * as_char4(b.s1).z + res; - res = as_char4(sub_group_broadcast(a, 1)).w * as_char4(b.s1).w + res; - - res = as_char4(sub_group_broadcast(a, 2)).x * as_char4(b.s2).x + res; - res = as_char4(sub_group_broadcast(a, 2)).y * as_char4(b.s2).y + res; - res = as_char4(sub_group_broadcast(a, 2)).z * as_char4(b.s2).z + res; - res = as_char4(sub_group_broadcast(a, 2)).w * as_char4(b.s2).w + res; - - res = as_char4(sub_group_broadcast(a, 3)).x * as_char4(b.s3).x + res; - res = as_char4(sub_group_broadcast(a, 3)).y * as_char4(b.s3).y + res; - res = as_char4(sub_group_broadcast(a, 3)).z * as_char4(b.s3).z + res; - res = as_char4(sub_group_broadcast(a, 3)).w * as_char4(b.s3).w + res; - - res = as_char4(sub_group_broadcast(a, 4)).x * as_char4(b.s4).x + res; - res = as_char4(sub_group_broadcast(a, 4)).y * as_char4(b.s4).y + res; - res = as_char4(sub_group_broadcast(a, 4)).z * as_char4(b.s4).z + res; - res = as_char4(sub_group_broadcast(a, 4)).w * as_char4(b.s4).w + res; - - res = as_char4(sub_group_broadcast(a, 5)).x * as_char4(b.s5).x + res; - res = as_char4(sub_group_broadcast(a, 5)).y * as_char4(b.s5).y + res; - res = as_char4(sub_group_broadcast(a, 5)).z * as_char4(b.s5).z + res; - res = as_char4(sub_group_broadcast(a, 5)).w * as_char4(b.s5).w + res; - - res = as_char4(sub_group_broadcast(a, 6)).x * as_char4(b.s6).x + res; - res = as_char4(sub_group_broadcast(a, 6)).y * as_char4(b.s6).y + res; - res = as_char4(sub_group_broadcast(a, 6)).z * as_char4(b.s6).z + res; - res = as_char4(sub_group_broadcast(a, 
6)).w * as_char4(b.s6).w + res; - - res = as_char4(sub_group_broadcast(a, 7)).x * as_char4(b.s7).x + res; - res = as_char4(sub_group_broadcast(a, 7)).y * as_char4(b.s7).y + res; - res = as_char4(sub_group_broadcast(a, 7)).z * as_char4(b.s7).z + res; - res = as_char4(sub_group_broadcast(a, 7)).w * as_char4(b.s7).w + res; + res = dp4(sub_group_broadcast(a, 0), b.s0) + res; + res = dp4(sub_group_broadcast(a, 1), b.s1) + res; + res = dp4(sub_group_broadcast(a, 2), b.s2) + res; + res = dp4(sub_group_broadcast(a, 3), b.s3) + res; + res = dp4(sub_group_broadcast(a, 4), b.s4) + res; + res = dp4(sub_group_broadcast(a, 5), b.s5) + res; + res = dp4(sub_group_broadcast(a, 6), b.s6) + res; + res = dp4(sub_group_broadcast(a, 7), b.s7) + res; return res; } @@ -171,45 +155,14 @@ int emu_sub_group_i8_i8_matrix_mad_k32(short a, int8 b, int acc) { float res = acc; - res = as_char2(intel_sub_group_broadcast(a, 0)).x * as_char4(b.s0).x + res; - res = as_char2(intel_sub_group_broadcast(a, 0)).y * as_char4(b.s0).y + res; - res = as_char2(intel_sub_group_broadcast(a, 1)).x * as_char4(b.s0).z + res; - res = as_char2(intel_sub_group_broadcast(a, 1)).y * as_char4(b.s0).w + res; - - res = as_char2(intel_sub_group_broadcast(a, 2)).x * as_char4(b.s1).x + res; - res = as_char2(intel_sub_group_broadcast(a, 2)).y * as_char4(b.s1).y + res; - res = as_char2(intel_sub_group_broadcast(a, 3)).x * as_char4(b.s1).z + res; - res = as_char2(intel_sub_group_broadcast(a, 3)).y * as_char4(b.s1).w + res; - - res = as_char2(intel_sub_group_broadcast(a, 4)).x * as_char4(b.s2).x + res; - res = as_char2(intel_sub_group_broadcast(a, 4)).y * as_char4(b.s2).y + res; - res = as_char2(intel_sub_group_broadcast(a, 5)).x * as_char4(b.s2).z + res; - res = as_char2(intel_sub_group_broadcast(a, 5)).y * as_char4(b.s2).w + res; - - res = as_char2(intel_sub_group_broadcast(a, 6)).x * as_char4(b.s3).x + res; - res = as_char2(intel_sub_group_broadcast(a, 6)).y * as_char4(b.s3).y + res; - res = 
as_char2(intel_sub_group_broadcast(a, 7)).x * as_char4(b.s3).z + res; - res = as_char2(intel_sub_group_broadcast(a, 7)).y * as_char4(b.s3).w + res; - - res = as_char2(intel_sub_group_broadcast(a, 8)).x * as_char4(b.s4).x + res; - res = as_char2(intel_sub_group_broadcast(a, 8)).y * as_char4(b.s4).y + res; - res = as_char2(intel_sub_group_broadcast(a, 9)).x * as_char4(b.s4).z + res; - res = as_char2(intel_sub_group_broadcast(a, 9)).y * as_char4(b.s4).w + res; - - res = as_char2(intel_sub_group_broadcast(a, 10)).x * as_char4(b.s5).x + res; - res = as_char2(intel_sub_group_broadcast(a, 10)).y * as_char4(b.s5).y + res; - res = as_char2(intel_sub_group_broadcast(a, 11)).x * as_char4(b.s5).z + res; - res = as_char2(intel_sub_group_broadcast(a, 11)).y * as_char4(b.s5).w + res; - - res = as_char2(intel_sub_group_broadcast(a, 12)).x * as_char4(b.s6).x + res; - res = as_char2(intel_sub_group_broadcast(a, 12)).y * as_char4(b.s6).y + res; - res = as_char2(intel_sub_group_broadcast(a, 13)).x * as_char4(b.s6).z + res; - res = as_char2(intel_sub_group_broadcast(a, 13)).y * as_char4(b.s6).w + res; - - res = as_char2(intel_sub_group_broadcast(a, 14)).x * as_char4(b.s7).x + res; - res = as_char2(intel_sub_group_broadcast(a, 14)).y * as_char4(b.s7).y + res; - res = as_char2(intel_sub_group_broadcast(a, 15)).x * as_char4(b.s7).z + res; - res = as_char2(intel_sub_group_broadcast(a, 15)).y * as_char4(b.s7).w + res; + res = dp4(as_uint((short2)(sub_group_broadcast(a, 0), sub_group_broadcast(a, 1))), b.s0) + res; + res = dp4(as_uint((short2)(sub_group_broadcast(a, 2), sub_group_broadcast(a, 3))), b.s1) + res; + res = dp4(as_uint((short2)(sub_group_broadcast(a, 4), sub_group_broadcast(a, 5))), b.s2) + res; + res = dp4(as_uint((short2)(sub_group_broadcast(a, 6), sub_group_broadcast(a, 7))), b.s3) + res; + res = dp4(as_uint((short2)(sub_group_broadcast(a, 8), sub_group_broadcast(a, 9))), b.s4) + res; + res = dp4(as_uint((short2)(sub_group_broadcast(a, 10), sub_group_broadcast(a, 11))), b.s5) 
+ res; + res = dp4(as_uint((short2)(sub_group_broadcast(a, 12), sub_group_broadcast(a, 13))), b.s6) + res; + res = dp4(as_uint((short2)(sub_group_broadcast(a, 14), sub_group_broadcast(a, 15))), b.s7) + res; return res; } From 2ae7f9b4fccec5974e2c70c7642641ebcef0bf81 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Mon, 1 Jun 2026 22:47:57 -0700 Subject: [PATCH 96/99] minor fixes --- include/bfloat16.hpp | 4 ++-- include/util.hpp | 6 +----- samples/20_matrixexperiments-bf16/README.md | 6 +++--- samples/20_matrixexperiments-bf16/main.cpp | 1 + samples/20_matrixexperiments-bf16/matrix_helpers_bf16.cl | 1 + .../20_matrixexperiments-bf16/matrix_kernel_tiled_bf16.cl | 2 +- samples/20_matrixexperiments-i8/README.md | 8 ++++---- samples/20_matrixexperiments-i8/main.cpp | 1 + samples/20_matrixexperiments-i8/matrix_helpers_i8.cl | 3 ++- samples/20_matrixexperiments-i8/matrix_kernels_i8.cl | 4 ++-- samples/20_matrixexperiments-tf32/README.md | 6 +++--- samples/20_matrixexperiments-tf32/main.cpp | 1 + samples/20_matrixexperiments-tf32/matrix_helpers_tf32.cl | 1 + .../20_matrixexperiments-tf32/matrix_kernel_tiled_tf32.cl | 2 +- 14 files changed, 24 insertions(+), 22 deletions(-) diff --git a/include/bfloat16.hpp b/include/bfloat16.hpp index 5e9541bd..6a36fe69 100644 --- a/include/bfloat16.hpp +++ b/include/bfloat16.hpp @@ -48,10 +48,10 @@ class bfloat16 { operator float() const { return to_float(value); } // Logical operators (!,||,&&) are covered if we can cast to bool - explicit operator bool() { return to_float(value) != 0.0f; } + explicit operator bool() const { return to_float(value) != 0.0f; } // Unary minus operator overloading - friend bfloat16 operator-(bfloat16 &lhs) { + friend bfloat16 operator-(const bfloat16 &lhs) { return -to_float(lhs.value); } diff --git a/include/util.hpp b/include/util.hpp index 2b65600f..68f3014d 100644 --- a/include/util.hpp +++ b/include/util.hpp @@ -8,6 +8,7 @@ #include #include +#include #include static cl_version getDeviceOpenCLVersion( 
@@ -79,11 +80,6 @@ static std::string readStringFromFile( return ""; } - size_t filesize = 0; - is.seekg(0, std::ios::end); - filesize = (size_t)is.tellg(); - is.seekg(0, std::ios::beg); - std::string source{ std::istreambuf_iterator(is), std::istreambuf_iterator() }; diff --git a/samples/20_matrixexperiments-bf16/README.md b/samples/20_matrixexperiments-bf16/README.md index 793bba22..893f32c7 100644 --- a/samples/20_matrixexperiments-bf16/README.md +++ b/samples/20_matrixexperiments-bf16/README.md @@ -2,7 +2,7 @@ ## Sample Purpose -This sample demonstrates various techniques to perform a large matrix multiplcation where the matrix elements contain 16-bit `bfloat16` data. +This sample demonstrates various techniques to perform a large matrix multiplication where the matrix elements contain 16-bit `bfloat16` data. The sample includes many different implementations: 1. The "naive" implementation is a very simple implementation. @@ -46,7 +46,7 @@ This sample will optionally use the following OpenCL extensions: | `--iterations ` | 16 | Specify the number of iterations for performance testing. | `--validate` | n/a | Validate results for correctness. | `--zero` | n/a | Initialize all matrices to zero. -| `--identity` | n/a | Initialize all matrices to to one. +| `--identity` | n/a | Initialize all matrices to one. | `--fixed` | n/a | Initialize all matrices to values computed from the matrix row and column. | `--emulate` | n/a | Do not use specialized matrix multiplication extensions. | `--wallclock` | n/a | Measure performance using wallclock time instead of event profiling. @@ -57,4 +57,4 @@ This sample will optionally use the following OpenCL extensions: By default, the source matrices are populated with random data. When validating results, it is recommended to use either "fixed" or "identity" data. -For best performance, use "zero" data". +For best performance, use "zero" data. 
diff --git a/samples/20_matrixexperiments-bf16/main.cpp b/samples/20_matrixexperiments-bf16/main.cpp index c0c4dce4..e90d691e 100644 --- a/samples/20_matrixexperiments-bf16/main.cpp +++ b/samples/20_matrixexperiments-bf16/main.cpp @@ -758,6 +758,7 @@ int main(int argc, char** argv) if (deviceIndex >= devices.size()) { printf("Requested device index is %d, but only %zu devices were found.\n", deviceIndex, devices.size()); + return -1; } cl::Device& device = devices[deviceIndex]; diff --git a/samples/20_matrixexperiments-bf16/matrix_helpers_bf16.cl b/samples/20_matrixexperiments-bf16/matrix_helpers_bf16.cl index 7dcb2e27..a792bee5 100644 --- a/samples/20_matrixexperiments-bf16/matrix_helpers_bf16.cl +++ b/samples/20_matrixexperiments-bf16/matrix_helpers_bf16.cl @@ -43,6 +43,7 @@ float4 activation(float4 f) return res; } +__attribute__((overloadable)) float8 activation(float8 f) { float8 res; diff --git a/samples/20_matrixexperiments-bf16/matrix_kernel_tiled_bf16.cl b/samples/20_matrixexperiments-bf16/matrix_kernel_tiled_bf16.cl index d76ee526..892d7fe0 100644 --- a/samples/20_matrixexperiments-bf16/matrix_kernel_tiled_bf16.cl +++ b/samples/20_matrixexperiments-bf16/matrix_kernel_tiled_bf16.cl @@ -5,7 +5,7 @@ */ #if !defined(tK) -#error "tK is undefined! This should be defined as the K dimension of the matrix tiles, which is dependent on the elemement type, likely 16 or 32." +#error "tK is undefined! This should be defined as the K dimension of the matrix tiles, which is dependent on the element type, likely 16 or 32." #endif #if !defined(MM) diff --git a/samples/20_matrixexperiments-i8/README.md b/samples/20_matrixexperiments-i8/README.md index 8fe63a4b..53ac92f5 100644 --- a/samples/20_matrixexperiments-i8/README.md +++ b/samples/20_matrixexperiments-i8/README.md @@ -2,7 +2,7 @@ ## Sample Purpose -This sample demonstrates various techniques to perform a large matrix multiplcation where the matrix elements contain 8-bit integer data. 
+This sample demonstrates various techniques to perform a large matrix multiplication where the matrix elements contain 8-bit integer data. The sample includes many different implementations: 1. The "naive" implementation is a very simple implementation. @@ -40,13 +40,13 @@ This sample will optionally use the following OpenCL extensions: |:--|:-:|:--| | `-p ` | 0 | Specify the index of the OpenCL platform to execute the sample on. | `-d ` | 0 | Specify the index of the OpenCL device in the platform to execute on the sample on. -| `--file ` | `matrix_kernels_bf16.cl` | Specify the name of the file with the OpenCL kernel source. +| `--file ` | `matrix_kernels_i8.cl` | Specify the name of the file with the OpenCL kernel source. | `--options ` | None | Specify optional program build options. | `--matrixsize ` | 512 | Specify the dimensions of the matrix. | `--iterations ` | 16 | Specify the number of iterations for performance testing. | `--validate` | n/a | Validate results for correctness. | `--zero` | n/a | Initialize all matrices to zero. -| `--identity` | n/a | Initialize all matrices to to one. +| `--identity` | n/a | Initialize all matrices to one. | `--fixed` | n/a | Initialize all matrices to values computed from the matrix row and column. | `--emulate` | n/a | Do not use specialized matrix multiplication extensions. | `--wallclock` | n/a | Measure performance using wallclock time instead of event profiling. @@ -57,4 +57,4 @@ This sample will optionally use the following OpenCL extensions: By default, the source matrices are populated with random data. When validating results, it is recommended to use either "fixed" or "identity" data. -For best performance, use "zero" data". +For best performance, use "zero" data. 
diff --git a/samples/20_matrixexperiments-i8/main.cpp b/samples/20_matrixexperiments-i8/main.cpp index 8bea9828..9f891b60 100644 --- a/samples/20_matrixexperiments-i8/main.cpp +++ b/samples/20_matrixexperiments-i8/main.cpp @@ -519,6 +519,7 @@ int main(int argc, char** argv) if (deviceIndex >= devices.size()) { printf("Requested device index is %d, but only %zu devices were found.\n", deviceIndex, devices.size()); + return -1; } cl::Device& device = devices[deviceIndex]; diff --git a/samples/20_matrixexperiments-i8/matrix_helpers_i8.cl b/samples/20_matrixexperiments-i8/matrix_helpers_i8.cl index 9aa6ffcb..f4b0bd4a 100644 --- a/samples/20_matrixexperiments-i8/matrix_helpers_i8.cl +++ b/samples/20_matrixexperiments-i8/matrix_helpers_i8.cl @@ -34,6 +34,7 @@ int4 activation(int4 i) return res; } +__attribute__((overloadable)) int8 activation(int8 i) { int8 res; @@ -153,7 +154,7 @@ int8 emu_sub_group_i8_i8_matrix_mad_k32(int8 a, int8 b, int8 acc) __attribute__((overloadable)) int emu_sub_group_i8_i8_matrix_mad_k32(short a, int8 b, int acc) { - float res = acc; + int res = acc; res = dp4(as_uint((short2)(sub_group_broadcast(a, 0), sub_group_broadcast(a, 1))), b.s0) + res; res = dp4(as_uint((short2)(sub_group_broadcast(a, 2), sub_group_broadcast(a, 3))), b.s1) + res; diff --git a/samples/20_matrixexperiments-i8/matrix_kernels_i8.cl b/samples/20_matrixexperiments-i8/matrix_kernels_i8.cl index 6bec5f2f..58732493 100644 --- a/samples/20_matrixexperiments-i8/matrix_kernels_i8.cl +++ b/samples/20_matrixexperiments-i8/matrix_kernels_i8.cl @@ -36,7 +36,7 @@ kernel void i8_naive(global int* C, global char* A, global char* B, int K) // For all i8 kernels tK == 32: #define tK 32 -#if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_char) && defined(cl_intel_required_subgroup_size) +#if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_subgroups_char) && defined(cl_intel_required_subgroup_size) #if HAS_SG8 @@ -582,6 +582,6 @@ kernel void 
i8_dpas_blockread_vnni_m8_n16(global int* C, global char* A, global #endif // cl_intel_subgroup_2d_block_io -#endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_required_subgroup_size) +#endif // defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) && defined(cl_intel_subgroups_char) && defined(cl_intel_required_subgroup_size) #undef tK diff --git a/samples/20_matrixexperiments-tf32/README.md b/samples/20_matrixexperiments-tf32/README.md index 62ffadf5..f54dd919 100644 --- a/samples/20_matrixexperiments-tf32/README.md +++ b/samples/20_matrixexperiments-tf32/README.md @@ -2,7 +2,7 @@ ## Sample Purpose -This sample demonstrates various techniques to perform a large matrix multiplcation where the matrix elements contain 32-bit `tf32` data. +This sample demonstrates various techniques to perform a large matrix multiplication where the matrix elements contain 32-bit `tf32` data. The sample includes many different implementations: 1. The "naive" implementation is a very simple implementation. @@ -44,7 +44,7 @@ This sample will optionally use the following OpenCL extensions: | `--iterations ` | 16 | Specify the number of iterations for performance testing. | `--validate` | n/a | Validate results for correctness. | `--zero` | n/a | Initialize all matrices to zero. -| `--identity` | n/a | Initialize all matrices to to one. +| `--identity` | n/a | Initialize all matrices to one. | `--fixed` | n/a | Initialize all matrices to values computed from the matrix row and column. | `--emulate` | n/a | Do not use specialized matrix multiplication extensions. | `--wallclock` | n/a | Measure performance using wallclock time instead of event profiling. @@ -55,4 +55,4 @@ This sample will optionally use the following OpenCL extensions: By default, the source matrices are populated with random data. When validating results, it is recommended to use either "fixed" or "identity" data. -For best performance, use "zero" data". 
+For best performance, use "zero" data. diff --git a/samples/20_matrixexperiments-tf32/main.cpp b/samples/20_matrixexperiments-tf32/main.cpp index fcfcc2fc..2dddc9c1 100644 --- a/samples/20_matrixexperiments-tf32/main.cpp +++ b/samples/20_matrixexperiments-tf32/main.cpp @@ -530,6 +530,7 @@ int main(int argc, char** argv) if (deviceIndex >= devices.size()) { printf("Requested device index is %d, but only %zu devices were found.\n", deviceIndex, devices.size()); + return -1; } cl::Device& device = devices[deviceIndex]; diff --git a/samples/20_matrixexperiments-tf32/matrix_helpers_tf32.cl b/samples/20_matrixexperiments-tf32/matrix_helpers_tf32.cl index da7c1f8a..2944f65e 100644 --- a/samples/20_matrixexperiments-tf32/matrix_helpers_tf32.cl +++ b/samples/20_matrixexperiments-tf32/matrix_helpers_tf32.cl @@ -34,6 +34,7 @@ float4 activation(float4 f) return res; } +__attribute__((overloadable)) float8 activation(float8 f) { float8 res; diff --git a/samples/20_matrixexperiments-tf32/matrix_kernel_tiled_tf32.cl b/samples/20_matrixexperiments-tf32/matrix_kernel_tiled_tf32.cl index e5fb52ee..f7f90dc5 100644 --- a/samples/20_matrixexperiments-tf32/matrix_kernel_tiled_tf32.cl +++ b/samples/20_matrixexperiments-tf32/matrix_kernel_tiled_tf32.cl @@ -5,7 +5,7 @@ */ #if !defined(tK) -#error "tK is undefined! This should be defined as the K dimension of the matrix tiles, which is dependent on the elemement type, likely 16 or 32." +#error "tK is undefined! This should be defined as the K dimension of the matrix tiles, which is dependent on the element type, likely 16 or 32." 
#endif #if !defined(MM) From cdee85b9d3758f2fb3ffd478c0b66914f3c7b18a Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Tue, 2 Jun 2026 13:18:51 -0700 Subject: [PATCH 97/99] fix error calculation --- samples/20_matrixexperiments-bf16/main.cpp | 35 +++++++++++++++------- samples/20_matrixexperiments-tf32/main.cpp | 35 +++++++++++++++------- 2 files changed, 50 insertions(+), 20 deletions(-) diff --git a/samples/20_matrixexperiments-bf16/main.cpp b/samples/20_matrixexperiments-bf16/main.cpp index e90d691e..7db55e23 100644 --- a/samples/20_matrixexperiments-bf16/main.cpp +++ b/samples/20_matrixexperiments-bf16/main.cpp @@ -156,22 +156,37 @@ void check_results( const std::vector& C, const std::vector& C_ref) { - float err = 0.f; + const float absolute = 1e-4f; + + float maxErr = 0.f; + int errorCount = 0; + for (size_t m = 0; m < M; m++) { for (size_t n = 0; n < N; n++) { auto index = m * N + n; - auto localErr = std::fabs(C[index] - C_ref[index]) / - std::max(std::fabs(C[index]), - std::fabs(C_ref[index])); - err = std::max(localErr, err); - if (localErr >= threshold) { - std::cerr << "Error at m = " << m << ", n = " << n - << ": (local error " << localErr << "): Wanted " - << C_ref[index] << ", got " << C[index] << std::endl; - return; + float got = static_cast(C[index]); + float want = static_cast(C_ref[index]); + float localErr = std::fabs(got - want); + float localThreshold = absolute + threshold * std::fabs(want); + + maxErr = std::max(localErr, maxErr); + if (localErr > localThreshold) { + if (errorCount < 1) { + std::cerr << "Error at m = " << m << ", n = " << n + << ": (abs error " << localErr << ", threshold " + << localThreshold << "): Wanted " << want + << ", got " << got << std::endl; + } + ++errorCount; } } } + + if (errorCount > 0) { + std::cerr << "FAILED: " << errorCount << " of " << M * N + << " elements exceeded tolerance. 
Max abs error: " + << maxErr << std::endl; + } } static float hw_time(cl::Event& event) diff --git a/samples/20_matrixexperiments-tf32/main.cpp b/samples/20_matrixexperiments-tf32/main.cpp index 2dddc9c1..a9a80fc6 100644 --- a/samples/20_matrixexperiments-tf32/main.cpp +++ b/samples/20_matrixexperiments-tf32/main.cpp @@ -154,22 +154,37 @@ void check_results( const std::vector& C, const std::vector& C_ref) { - float err = 0.f; + const float absolute = 1e-4f; + + float maxErr = 0.f; + int errorCount = 0; + for (size_t m = 0; m < M; m++) { for (size_t n = 0; n < N; n++) { auto index = m * N + n; - auto localErr = std::fabs(C[index] - C_ref[index]) / - std::max(std::fabs(C[index]), - std::fabs(C_ref[index])); - err = std::max(localErr, err); - if (localErr >= threshold) { - std::cerr << "Error at m = " << m << ", n = " << n - << ": (local error " << localErr << "): Wanted " - << C_ref[index] << ", got " << C[index] << std::endl; - return; + float got = static_cast(C[index]); + float want = static_cast(C_ref[index]); + float localErr = std::fabs(got - want); + float localThreshold = absolute + threshold * std::fabs(want); + + maxErr = std::max(localErr, maxErr); + if (localErr > localThreshold) { + if (errorCount < 1) { + std::cerr << "Error at m = " << m << ", n = " << n + << ": (abs error " << localErr << ", threshold " + << localThreshold << "): Wanted " << want + << ", got " << got << std::endl; + } + ++errorCount; } } } + + if (errorCount > 0) { + std::cerr << "FAILED: " << errorCount << " of " << M * N + << " elements exceeded tolerance. 
Max abs error: " + << maxErr << std::endl; + } } static float hw_time(cl::Event& event) From e261d9c9ebdae03ca6e4959bfbc87a85021878a3 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Tue, 2 Jun 2026 22:45:48 -0700 Subject: [PATCH 98/99] a few more fixes --- samples/20_matrixexperiments-bf16/main.cpp | 18 ++++++++++++------ .../matrix_helpers_bf16.cl | 4 ++-- samples/20_matrixexperiments-i8/main.cpp | 19 ++++++++++++------- .../matrix_helpers_i8.cl | 4 ++-- samples/20_matrixexperiments-tf32/main.cpp | 13 +++++++------ .../matrix_helpers_tf32.cl | 4 ++-- 6 files changed, 37 insertions(+), 25 deletions(-) diff --git a/samples/20_matrixexperiments-bf16/main.cpp b/samples/20_matrixexperiments-bf16/main.cpp index 7db55e23..cd2810b6 100644 --- a/samples/20_matrixexperiments-bf16/main.cpp +++ b/samples/20_matrixexperiments-bf16/main.cpp @@ -10,6 +10,7 @@ #include #include +#include #include #include #include @@ -68,18 +69,23 @@ std::string makeTestName( static size_t findMinSubGroupSize(cl::Device& device) { - auto s = device.getInfo(); - auto it = std::min_element(std::begin(s), std::end(s)); - if (it != std::end(s)) { - return *it; + if (checkDeviceForExtension(device, CL_INTEL_REQUIRED_SUBGROUP_SIZE_EXTENSION_NAME)) { + auto s = device.getInfo(); + auto it = std::min_element(std::begin(s), std::end(s)); + if (it != std::end(s)) { + return *it; + } } return 0; } static bool supportsSubgroupSize(cl::Device& device, size_t subgroupSize) { - auto s = device.getInfo(); - return std::find(std::begin(s), std::end(s), subgroupSize) != std::end(s); + if (checkDeviceForExtension(device, CL_INTEL_REQUIRED_SUBGROUP_SIZE_EXTENSION_NAME)) { + auto s = device.getInfo(); + return std::find(std::begin(s), std::end(s), subgroupSize) != std::end(s); + } + return false; } static void setRoundRobin(cl::Kernel& kernel) diff --git a/samples/20_matrixexperiments-bf16/matrix_helpers_bf16.cl b/samples/20_matrixexperiments-bf16/matrix_helpers_bf16.cl index a792bee5..d9d42d5a 100644 --- 
a/samples/20_matrixexperiments-bf16/matrix_helpers_bf16.cl +++ b/samples/20_matrixexperiments-bf16/matrix_helpers_bf16.cl @@ -61,8 +61,8 @@ float8 activation(float8 f) #ifndef __has_builtin #define __has_builtin(x) 0 #endif -#if __has_builtin(__builtin_expect) == 0 -#define __builtin_expect(x) +#if __has_builtin(__builtin_assume) == 0 +#define __builtin_assume(x) #endif #if defined(cl_intel_subgroups) && defined(cl_intel_subgroups_short) diff --git a/samples/20_matrixexperiments-i8/main.cpp b/samples/20_matrixexperiments-i8/main.cpp index 9f891b60..be6d0948 100644 --- a/samples/20_matrixexperiments-i8/main.cpp +++ b/samples/20_matrixexperiments-i8/main.cpp @@ -10,7 +10,7 @@ #include #include -#include +#include #include #include #include @@ -68,18 +68,23 @@ std::string makeTestName( static size_t findMinSubGroupSize(cl::Device& device) { - auto s = device.getInfo(); - auto it = std::min_element(std::begin(s), std::end(s)); - if (it != std::end(s)) { - return *it; + if (checkDeviceForExtension(device, CL_INTEL_REQUIRED_SUBGROUP_SIZE_EXTENSION_NAME)) { + auto s = device.getInfo(); + auto it = std::min_element(std::begin(s), std::end(s)); + if (it != std::end(s)) { + return *it; + } } return 0; } static bool supportsSubgroupSize(cl::Device& device, size_t subgroupSize) { - auto s = device.getInfo(); - return std::find(std::begin(s), std::end(s), subgroupSize) != std::end(s); + if (checkDeviceForExtension(device, CL_INTEL_REQUIRED_SUBGROUP_SIZE_EXTENSION_NAME)) { + auto s = device.getInfo(); + return std::find(std::begin(s), std::end(s), subgroupSize) != std::end(s); + } + return false; } static void setRoundRobin(cl::Kernel& kernel) diff --git a/samples/20_matrixexperiments-i8/matrix_helpers_i8.cl b/samples/20_matrixexperiments-i8/matrix_helpers_i8.cl index f4b0bd4a..ef83ef19 100644 --- a/samples/20_matrixexperiments-i8/matrix_helpers_i8.cl +++ b/samples/20_matrixexperiments-i8/matrix_helpers_i8.cl @@ -52,8 +52,8 @@ int8 activation(int8 i) #ifndef __has_builtin 
#define __has_builtin(x) 0 #endif -#if __has_builtin(__builtin_expect) == 0 -#define __builtin_expect(x) +#if __has_builtin(__builtin_assume) == 0 +#define __builtin_assume(x) #endif #if defined(__opencl_c_integer_dot_product_input_4x8bit_packed) diff --git a/samples/20_matrixexperiments-tf32/main.cpp b/samples/20_matrixexperiments-tf32/main.cpp index a9a80fc6..d9bd0a94 100644 --- a/samples/20_matrixexperiments-tf32/main.cpp +++ b/samples/20_matrixexperiments-tf32/main.cpp @@ -10,8 +10,7 @@ #include #include -#include -#include +#include #include #include #include @@ -69,10 +68,12 @@ std::string makeTestName( static size_t findMinSubGroupSize(cl::Device& device) { - auto s = device.getInfo(); - auto it = std::min_element(std::begin(s), std::end(s)); - if (it != std::end(s)) { - return *it; + if (checkDeviceForExtension(device, CL_INTEL_REQUIRED_SUBGROUP_SIZE_EXTENSION_NAME)) { + auto s = device.getInfo(); + auto it = std::min_element(std::begin(s), std::end(s)); + if (it != std::end(s)) { + return *it; + } } return 0; } diff --git a/samples/20_matrixexperiments-tf32/matrix_helpers_tf32.cl b/samples/20_matrixexperiments-tf32/matrix_helpers_tf32.cl index 2944f65e..a8e96a85 100644 --- a/samples/20_matrixexperiments-tf32/matrix_helpers_tf32.cl +++ b/samples/20_matrixexperiments-tf32/matrix_helpers_tf32.cl @@ -52,8 +52,8 @@ float8 activation(float8 f) #ifndef __has_builtin #define __has_builtin(x) 0 #endif -#if __has_builtin(__builtin_expect) == 0 -#define __builtin_expect(x) +#if __has_builtin(__builtin_assume) == 0 +#define __builtin_assume(x) #endif #if defined(cl_intel_subgroups) From f0c05d0de30fb0c312770c5c0a1f0f7180ad9656 Mon Sep 17 00:00:00 2001 From: Ben Ashbaugh Date: Tue, 2 Jun 2026 23:07:49 -0700 Subject: [PATCH 99/99] final cleanup --- include/bfloat16.hpp | 5 +++++ include/util.hpp | 3 +++ samples/20_matrixexperiments-bf16/main.cpp | 4 +++- samples/20_matrixexperiments-i8/README.md | 1 - samples/20_matrixexperiments-i8/main.cpp | 7 +++---- 
samples/20_matrixexperiments-tf32/main.cpp | 4 +++- 6 files changed, 17 insertions(+), 7 deletions(-) diff --git a/include/bfloat16.hpp b/include/bfloat16.hpp index 6a36fe69..143588ea 100644 --- a/include/bfloat16.hpp +++ b/include/bfloat16.hpp @@ -1,3 +1,8 @@ +/* +// Copyright (c) 2024-2026 Ben Ashbaugh +// +// SPDX-License-Identifier: MIT +*/ #pragma once #include diff --git a/include/util.hpp b/include/util.hpp index 68f3014d..847ffb50 100644 --- a/include/util.hpp +++ b/include/util.hpp @@ -7,6 +7,9 @@ #include +#include +#include +#include #include #include #include diff --git a/samples/20_matrixexperiments-bf16/main.cpp b/samples/20_matrixexperiments-bf16/main.cpp index cd2810b6..cb48f0bc 100644 --- a/samples/20_matrixexperiments-bf16/main.cpp +++ b/samples/20_matrixexperiments-bf16/main.cpp @@ -10,10 +10,12 @@ #include #include +#include +#include #include +#include #include #include -#include #include #include "bfloat16.hpp" diff --git a/samples/20_matrixexperiments-i8/README.md b/samples/20_matrixexperiments-i8/README.md index 53ac92f5..40c0336f 100644 --- a/samples/20_matrixexperiments-i8/README.md +++ b/samples/20_matrixexperiments-i8/README.md @@ -52,7 +52,6 @@ This sample will optionally use the following OpenCL extensions: | `--wallclock` | n/a | Measure performance using wallclock time instead of event profiling. | `--skipinit` | n/a | Skip initialization of source matrices. | `--roundrobin` | n/a | Use round robin thread scheduling. -| `--threshold ` | 0.01 | Set the threshold used when validating results. | `--mask ` | ~0 | Set a mask to only run a subset of tests. By default, the source matrices are populated with random data. 
diff --git a/samples/20_matrixexperiments-i8/main.cpp b/samples/20_matrixexperiments-i8/main.cpp index be6d0948..98e0d8bd 100644 --- a/samples/20_matrixexperiments-i8/main.cpp +++ b/samples/20_matrixexperiments-i8/main.cpp @@ -10,10 +10,12 @@ #include #include +#include +#include #include +#include #include #include -#include #include #include "util.hpp" @@ -29,7 +31,6 @@ bool wallclock = false; bool skipinit = false; bool roundRobin = false; int testIterations = 16; -float threshold = 0.01f; std::string makeTestName( const std::string &func, @@ -160,7 +161,6 @@ void check_results( const std::vector& C, const std::vector& C_ref) { - float err = 0.f; for (size_t m = 0; m < M; m++) { for (size_t n = 0; n < N; n++) { auto index = m * N + n; @@ -491,7 +491,6 @@ int main(int argc, char** argv) op.add("", "wallclock", "Measure Wallclock Time", &wallclock); op.add("", "skipinit", "Do Not Initialize Buffers", &skipinit); op.add("", "roundrobin", "Use Round Robin Scheduling", &roundRobin); - op.add>("", "threshold", "Local Error Threshold", threshold, &threshold); op.add, popl::Attribute::advanced>("", "mask", "Test Mask", mask, &mask); bool printUsage = false; try { diff --git a/samples/20_matrixexperiments-tf32/main.cpp b/samples/20_matrixexperiments-tf32/main.cpp index d9bd0a94..0b89c5b3 100644 --- a/samples/20_matrixexperiments-tf32/main.cpp +++ b/samples/20_matrixexperiments-tf32/main.cpp @@ -10,10 +10,12 @@ #include #include +#include +#include #include +#include #include #include -#include #include #include "util.hpp"