diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index dbc8fc89716..75f2e1a98a6 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -253,6 +253,67 @@ jobs: run: ccache -s shell: bash + gn-windows-x86_64: + runs-on: 'windows-2022-8core' + timeout-minutes: 30 + env: + DEPOT_TOOLS_WIN_TOOLCHAIN: 0 + steps: + - name: Setup ccache + uses: hendrikmuhs/ccache-action@v1.2 + with: + key: ${{ github.job }} + max-size: "500M" + save: ${{ inputs.update-caches }} + - name: Install Depot Tools + uses: newkdev/setup-depot-tools@v1.0.1 + - name: Write .gclient + shell: powershell + run: | + Set-Content -Path .gclient -Value " + solutions = [ + { 'name': 'XNNPACK', + 'url': 'https://github.com/google/XNNPACK', + 'deps_file': 'DEPS', + 'managed': False, + 'custom_deps': {}, + }, + ] + " + working-directory: ${{ github.workspace }} + - name: Setup build environment + shell: bash + run: | + echo "VCVARSALL=$(vswhere -products \* -latest -property installationPath)\\VC\\Auxiliary\\Build\\vcvarsall.bat" >> $GITHUB_ENV + - name: Sync to commit and run hooks + shell: powershell + run: gclient sync -vv --revision $env:GITHUB_SHA + - name: Write LASTCHANGE files + shell: powershell + run: | + $LastChangeContents = @( + "LASTCHANGE=$env:GITHUB_SHA-$env:GITHUB_REF_NAME" + "LASTCHANGE_YEAR=$(git log -1 --pretty=%Y)" + ) + + Set-Content -Path build/util/LASTCHANGE -Value $LastChangeContents + $UnixTime = git log -1 --pretty=%ct + Set-Content -Path build/util/LASTCHANGE.committime -Value $UnixTime + working-directory: ${{ github.workspace }}\XNNPACK + - name: Generate build files (x64, dchecks) + run: | + gn gen --check --args="is_debug=false symbol_level=0 dcheck_always_on=true cc_wrapper=\`"ccache\`" target_cpu=\`"x64\`"" out/x64.dchecks + working-directory: ${{ github.workspace }}\XNNPACK + shell: powershell + - name: Build all targets (x64 Release + debug checks) + run: | + autoninja -C out/x64.dchecks + working-directory: ${{ github.workspace }}\XNNPACK + - name: Run tests (x64) + run: | + python3 scripts/run-gn-tests.py out/x64.dchecks + working-directory: ${{ github.workspace }}\XNNPACK + cmake-macos-arm64: runs-on: macos-latest timeout-minutes: 60 diff --git a/BUILD.gn b/BUILD.gn index 26b9901b995..3536475988d 100644 --- a/BUILD.gn +++ b/BUILD.gn @@ -97,8 +97,10 @@ declare_args() { # Enables AVX2 support for x86 processors xnnpack_enable_avx2 = target_cpu == "x64" || target_cpu == "x86" - # Enables AVX512 support for x86 processors - xnnpack_enable_avx512 = target_cpu == "x64" || target_cpu == "x86" + # Enables AVX-512 support for x86 processors. Temporarily switched + # off on Windows because some AVX-512 assembly kernels contain unsupported + # syntax. TODO: crbug.com/523327327 - Re-enable this. + xnnpack_enable_avx512 = target_cpu == "x64" && !is_win # Enables VNNI extensions, which are separate from AVX512-VNNI xnnpack_enable_avx_vnni = target_cpu == "x64" || target_cpu == "x86" @@ -177,7 +179,6 @@ config("xnnpack_private_config") { "XNN_ENABLE_ARM_I8MM=1", "XNN_ENABLE_ASSEMBLY=1", "XNN_ENABLE_ARM_FP16_VECTOR=1", - "XNN_ENABLE_RNDNU16=1", ] if (xnnpack_enable_arm_kleidiai) { defines += [ "XNN_ENABLE_KLEIDIAI=1" ] @@ -918,6 +919,7 @@ xnnpack_source_set("scalar_microkernels") { deps = [ ":microkernel_defs", ":microkernel_headers", + "//third_party/fxdiv", ] sources = ALL_SCALAR_MICROKERNEL_SRCS } @@ -965,6 +967,7 @@ xnnpack_source_set("xnnpack") { ":scalar_microkernels", ":subgraph", ":table", + "//third_party/fxdiv", ] if (xnnpack_enable_arm_kleidiai) { deps += [ "//third_party/kleidiai" ] @@ -1025,7 +1028,7 @@ xnnpack_source_set("xnnpack") { if (xnnpack_enable_avx512) { deps += [ ":avx512_microkernels" ] } - if (current_cpu == "x64") { + if (!is_win) { sources += AMD64_ASM_MICROKERNEL_SRCS } sources += ALL_SSE_MICROKERNEL_SRCS diff --git a/DEPS b/DEPS index 5cc10311009..6e795463991 100644 --- a/DEPS +++ b/DEPS @@ -318,6 +318,13 @@ deps = { } hooks = [ + { + # Update the Windows toolchain if necessary. + 'name': 'win_toolchain', + 'pattern': '.', + 'condition': 'checkout_win', + 'action': ['python3', 'build/vs_toolchain.py', 'update', '--force'], + }, { # Update the Mac toolchain if necessary. 'name': 'mac_toolchain', diff --git a/bench/subgraph/attention.cc b/bench/subgraph/attention.cc index 31fd92e9dd8..bcc1283527d 100644 --- a/bench/subgraph/attention.cc +++ b/bench/subgraph/attention.cc @@ -47,7 +47,7 @@ xnn_subgraph_t FP32Attention(size_t b, size_t t, size_t h, size_t n, size_t s) { xnnpack::ReplicableRandomDevice rng; uint32_t v0 = XNN_INVALID_VALUE_ID; - std::array v0_dims = {{b, s, n, t}}; + std::array v0_dims = {{b, s, n, h}}; status = xnn_define_tensor_value( subgraph.get(), xnn_datatype_fp32, v0_dims.size(), v0_dims.data(), /*data=*/nullptr, 0, XNN_VALUE_FLAG_EXTERNAL_INPUT, &v0); @@ -126,15 +126,7 @@ xnn_subgraph_t FP32Attention(size_t b, size_t t, size_t h, size_t n, size_t s) { return nullptr; } - uint32_t v8 = XNN_INVALID_VALUE_ID; - std::array v8_dims = {{b, n, t, s}}; - status = xnn_define_tensor_value( - subgraph.get(), xnn_datatype_fp32, v8_dims.size(), v8_dims.data(), - /*data=*/nullptr, XNN_INVALID_VALUE_ID, /*flags=*/0, &v8); - if (status != xnn_status_success) { - std::cerr << "failed to create tensor v8" << std::endl; - return nullptr; - } + uint32_t v9 = XNN_INVALID_VALUE_ID; std::array v9_dims = {{b, n, t, s}}; @@ -157,7 +149,7 @@ xnn_subgraph_t FP32Attention(size_t b, size_t t, size_t h, size_t n, size_t s) { } uint32_t v11 = XNN_INVALID_VALUE_ID; - std::array v11_dims = {{b, n, t, t}}; + std::array v11_dims = {{b, n, t, h}}; status = xnn_define_tensor_value( subgraph.get(), xnn_datatype_fp32, v11_dims.size(), v11_dims.data(), /*data=*/nullptr, XNN_INVALID_VALUE_ID, /*flags=*/0, &v11); @@ -167,7 +159,7 @@ xnn_subgraph_t FP32Attention(size_t b, size_t t, size_t h, size_t n, size_t s) { } uint32_t v12 = XNN_INVALID_VALUE_ID; - std::array v12_dims = {{b, t, n, t}}; + std::array v12_dims = {{b, t, n, h}}; status = xnn_define_tensor_value( subgraph.get(), xnn_datatype_fp32, v12_dims.size(), v12_dims.data(), /*data=*/nullptr, 3, XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &v12); @@ -286,14 +278,7 @@ xnn_subgraph_t FP32Attention(size_t b, size_t t, size_t h, size_t n, size_t s) { return nullptr; } - status = xnn_define_unary(subgraph.get(), xnn_unary_tanh, - /*params=*/nullptr, v7, v8, 0); - if (status != xnn_status_success) { - std::cerr << "failed to create node #5" << std::endl; - return nullptr; - } - - status = xnn_define_softmax(subgraph.get(), v8, v9, + status = xnn_define_softmax(subgraph.get(), v7, v9, /*flags=*/0); if (status != xnn_status_success) { std::cerr << "failed to create node #6" << std::endl; @@ -335,10 +320,11 @@ xnn_subgraph_t FP32Attention(size_t b, size_t t, size_t h, size_t n, size_t s) { } xnn_subgraph_t QD8Attention(size_t batch_size, size_t seq_len, - size_t embedding_dim, size_t num_heads, - size_t head_dim, QD8AttentionWeights& weights) { + size_t head_dim, size_t num_heads, + size_t key_len, QD8AttentionWeights& weights) { + size_t embedding_dim = num_heads * head_dim; xnn_status status; - auto subgraph = xnnpack::CreateUniqueSubgraph(/*num_external_values=*/2, 0); + auto subgraph = xnnpack::CreateUniqueSubgraph(/*num_external_values=*/3, 0); if (!subgraph) { std::cerr << "failed to create subgrpah" << std::endl; return nullptr; @@ -352,33 +338,63 @@ xnn_subgraph_t QD8Attention(size_t batch_size, size_t seq_len, std::bind(std::uniform_int_distribution(-127, 127), std::ref(rng)); // External inputs and outputs. - uint32_t input_id = XNN_INVALID_VALUE_ID; - std::array input_dims = {{batch_size, seq_len, embedding_dim}}; + uint32_t query_input_id = XNN_INVALID_VALUE_ID; + std::array query_input_dims = {{batch_size, seq_len, embedding_dim}}; status = xnn_define_tensor_value( - subgraph.get(), xnn_datatype_fp32, input_dims.size(), input_dims.data(), + subgraph.get(), xnn_datatype_fp32, query_input_dims.size(), query_input_dims.data(), /*data=*/nullptr, /*external_id=*/0, - /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &input_id); + /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &query_input_id); if (status != xnn_status_success) { - std::cerr << "failed to create input tensor " << std::endl; + std::cerr << "failed to create query input tensor " << std::endl; return nullptr; } - uint32_t quantized_input_id = XNN_INVALID_VALUE_ID; + uint32_t kv_input_id = XNN_INVALID_VALUE_ID; + std::array kv_input_dims = {{batch_size, key_len, embedding_dim}}; + status = xnn_define_tensor_value( + subgraph.get(), xnn_datatype_fp32, kv_input_dims.size(), kv_input_dims.data(), + /*data=*/nullptr, /*external_id=*/1, + /*flags=*/XNN_VALUE_FLAG_EXTERNAL_INPUT, &kv_input_id); + if (status != xnn_status_success) { + std::cerr << "failed to create kv input tensor " << std::endl; + return nullptr; + } + + uint32_t quantized_query_input_id = XNN_INVALID_VALUE_ID; status = xnn_define_dynamically_quantized_tensor_value( - subgraph.get(), xnn_datatype_qdint8, input_dims.size(), - /*num_non_batch_dims=*/1, input_dims.data(), XNN_INVALID_VALUE_ID, - /*flags=*/0, &quantized_input_id); + subgraph.get(), xnn_datatype_qdint8, query_input_dims.size(), + /*num_non_batch_dims=*/1, query_input_dims.data(), XNN_INVALID_VALUE_ID, + /*flags=*/0, &quantized_query_input_id); if (status != xnn_status_success) { - std::cerr << "failed to create dynamically quantized input tensor " + std::cerr << "failed to create dynamically quantized query input tensor " << std::endl; return nullptr; } status = xnn_define_unary(subgraph.get(), xnn_unary_convert, /*params=*/nullptr, - input_id, quantized_input_id, /*flags=*/0); + query_input_id, quantized_query_input_id, /*flags=*/0); if (status != xnn_status_success) { - std::cerr << "failed to create create convert " << std::endl; + std::cerr << "failed to create convert for query input " << std::endl; + return nullptr; + } + + uint32_t quantized_kv_input_id = XNN_INVALID_VALUE_ID; + status = xnn_define_dynamically_quantized_tensor_value( + subgraph.get(), xnn_datatype_qdint8, kv_input_dims.size(), + /*num_non_batch_dims=*/1, kv_input_dims.data(), XNN_INVALID_VALUE_ID, + /*flags=*/0, &quantized_kv_input_id); + if (status != xnn_status_success) { + std::cerr << "failed to create dynamically quantized kv input tensor " + << std::endl; + return nullptr; + } + + status = + xnn_define_unary(subgraph.get(), xnn_unary_convert, /*params=*/nullptr, + kv_input_id, quantized_kv_input_id, /*flags=*/0); + if (status != xnn_status_success) { + std::cerr << "failed to create convert for kv input " << std::endl; return nullptr; } @@ -386,7 +402,7 @@ xnn_subgraph_t QD8Attention(size_t batch_size, size_t seq_len, std::array output_dims = {{batch_size, seq_len, embedding_dim}}; status = xnn_define_tensor_value( subgraph.get(), xnn_datatype_fp32, output_dims.size(), output_dims.data(), - /*data=*/nullptr, /*external_id=*/1, + /*data=*/nullptr, /*external_id=*/2, /*flags=*/XNN_VALUE_FLAG_EXTERNAL_OUTPUT, &output_id); if (status != xnn_status_success) { std::cerr << "failed to create output tensor " << std::endl; @@ -458,7 +474,7 @@ xnn_subgraph_t QD8Attention(size_t batch_size, size_t seq_len, } uint32_t key_proj_id = XNN_INVALID_VALUE_ID; - std::array key_proj_dims = {{seq_len, head_dim}}; + std::array key_proj_dims = {{batch_size, key_len, head_dim}}; status = xnn_define_tensor_value(subgraph.get(), xnn_datatype_fp32, key_proj_dims.size(), key_proj_dims.data(), /*data=*/nullptr, XNN_INVALID_VALUE_ID, @@ -469,7 +485,7 @@ xnn_subgraph_t QD8Attention(size_t batch_size, size_t seq_len, } uint32_t value_proj_id = XNN_INVALID_VALUE_ID; - std::array value_proj_dims = {{seq_len, head_dim}}; + std::array value_proj_dims = {{batch_size, key_len, head_dim}}; status = xnn_define_tensor_value( subgraph.get(), xnn_datatype_fp32, value_proj_dims.size(), value_proj_dims.data(), @@ -479,27 +495,39 @@ xnn_subgraph_t QD8Attention(size_t batch_size, size_t seq_len, return nullptr; } + uint32_t value_reshaped_id = XNN_INVALID_VALUE_ID; + std::array value_reshaped_dims = { + {batch_size, 1, key_len, head_dim}}; + status = xnn_define_tensor_value( + subgraph.get(), xnn_datatype_fp32, value_reshaped_dims.size(), + value_reshaped_dims.data(), + /*data=*/nullptr, XNN_INVALID_VALUE_ID, /*flags=*/0, &value_reshaped_id); + if (status != xnn_status_success) { + std::cerr << "failed to create tensor value reshaped" << std::endl; + return nullptr; + } + const float output_min = -std::numeric_limits::infinity(); const float output_max = std::numeric_limits::infinity(); status = xnn_define_fully_connected( - subgraph.get(), output_min, output_max, quantized_input_id, query_id, + subgraph.get(), output_min, output_max, quantized_query_input_id, query_id, XNN_INVALID_VALUE_ID, query_proj_id, /*flags=*/0); if (status != xnn_status_success) { - std::cerr << "failed to create FC node" << std::endl; + std::cerr << "failed to create FC node for query" << std::endl; return nullptr; } status = xnn_define_fully_connected( - subgraph.get(), output_min, output_max, quantized_input_id, key_id, + subgraph.get(), output_min, output_max, quantized_kv_input_id, key_id, XNN_INVALID_VALUE_ID, key_proj_id, /*flags=*/0); if (status != xnn_status_success) { - std::cerr << "failed to create FC node" << std::endl; + std::cerr << "failed to create FC node for key" << std::endl; return nullptr; } status = xnn_define_fully_connected( - subgraph.get(), output_min, output_max, quantized_input_id, value_id, + subgraph.get(), output_min, output_max, quantized_kv_input_id, value_id, XNN_INVALID_VALUE_ID, value_proj_id, /*flags=*/0); if (status != xnn_status_success) { std::cerr << "failed to create FC node" << std::endl; @@ -528,7 +556,7 @@ xnn_subgraph_t QD8Attention(size_t batch_size, size_t seq_len, uint32_t key_reshaped_id = XNN_INVALID_VALUE_ID; std::array key_reshaped_dims = { - {batch_size, 1, seq_len, head_dim}}; + {batch_size, 1, key_len, head_dim}}; status = xnn_define_tensor_value( subgraph.get(), xnn_datatype_fp32, key_reshaped_dims.size(), key_reshaped_dims.data(), @@ -546,9 +574,18 @@ xnn_subgraph_t QD8Attention(size_t batch_size, size_t seq_len, return nullptr; } + status = + xnn_define_static_reshape(subgraph.get(), value_reshaped_dims.size(), + value_reshaped_dims.data(), value_proj_id, + value_reshaped_id, /*flags=*/0); + if (status != xnn_status_success) { + std::cerr << "failed to reshape value_proj" << std::endl; + return nullptr; + } + uint32_t logits_id = XNN_INVALID_VALUE_ID; std::array logits_dims = { - {batch_size, seq_len, num_heads, seq_len}}; + {batch_size, seq_len, num_heads, key_len}}; status = xnn_define_tensor_value( subgraph.get(), xnn_datatype_fp32, logits_dims.size(), logits_dims.data(), /*data=*/nullptr, XNN_INVALID_VALUE_ID, /*flags=*/0, &logits_id); @@ -567,7 +604,7 @@ xnn_subgraph_t QD8Attention(size_t batch_size, size_t seq_len, uint32_t probs_id = XNN_INVALID_VALUE_ID; std::array probs_dims = { - {batch_size, seq_len, num_heads, seq_len}}; + {batch_size, seq_len, num_heads, key_len}}; status = xnn_define_tensor_value( subgraph.get(), xnn_datatype_fp32, probs_dims.size(), probs_dims.data(), /*data=*/nullptr, XNN_INVALID_VALUE_ID, /*flags=*/0, &probs_id); @@ -595,7 +632,7 @@ xnn_subgraph_t QD8Attention(size_t batch_size, size_t seq_len, } status = xnn_define_batch_matrix_multiply( - subgraph.get(), probs_id, value_proj_id, outcome_id, /*flags=*/0); + subgraph.get(), probs_id, value_reshaped_id, outcome_id, /*flags=*/0); if (status != xnn_status_success) { std::cerr << "failed to create batch matrix multiply" << std::endl; return nullptr; @@ -707,15 +744,10 @@ static void QD8Attention(benchmark::State& state) { static void AttentionArguments(benchmark::Benchmark* b) { b->ArgNames({"T", "H", "N", "S"}); - b->Args({16, 25, 24, 4}); - b->Args({1536, 128, 12, 18}); - b->Args({1024, 256, 4, 46}); - b->Args({1792, 256, 8, 36}); - b->Args({1536, 256, 6, 22}); - b->Args({2048, 256, 8, 18}); - b->Args({3072, 256, 16, 28}); - b->Args({2304, 256, 8, 26}); - b->Args({2048, 64, 32, 24}); + b->Args({64, 64, 32, 64}); + b->Args({256, 64, 32, 256}); + b->Args({1024, 64, 32, 1024}); + b->Args({4096, 64, 32, 4096}); } BENCHMARK(FP32Attention) diff --git a/litert/tensor/arithmetic.h b/litert/tensor/arithmetic.h index b3de14f6ccf..c9033b0ebeb 100644 --- a/litert/tensor/arithmetic.h +++ b/litert/tensor/arithmetic.h @@ -899,6 +899,15 @@ Tensor BatchMatMul( std::vector x_shape = x_info.shape; std::vector y_shape = y_info.shape; + if (x_shape.size() < 2 || y_shape.size() < 2) { + std::string x_shape_str = absl::StrJoin(x_shape, ","); + std::string y_shape_str = absl::StrJoin(y_shape, ","); + return Tensor( + graph::ErrorTensor(absl::InvalidArgumentError(absl::StrCat( + "Input tensors for BatchMatMul must have rank >= 2. x_name: ", + x.GetName(), " y_name: ", y.GetName(), " x_shape: ", x_shape_str, + " y_shape: ", y_shape_str)))); + } if (adj_x) { std::swap(x_shape[x_shape.size() - 2], x_shape[x_shape.size() - 1]); } @@ -916,15 +925,24 @@ Tensor BatchMatMul( " y_shape: ", y_shape_str, " adj_x: ", adj_x, " adj_y: ", adj_y)))); } - // Batch dimensions should be broadcastable. - if (x_shape.size() != y_shape.size()) { - // For now, we only support same rank BatchMatMul. - // TODO(piyu): Support different ranks. - } - output_info.shape.reserve(x_shape.size()); - for (size_t i = 0; i < x_shape.size() - 2; ++i) { - const auto x_dim = x_shape[i]; - const auto y_dim = y_shape[i]; + // Compute multi-rank batch broadcasting for the outer dimensions, aligning + // dimensions right-to-left and treating missing outer dims as 1. + size_t x_rank = x_shape.size(); + size_t y_rank = y_shape.size(); + size_t max_rank = std::max(x_rank, y_rank); + + output_info.shape.resize(max_rank); + output_info.shape[max_rank - 2] = x_shape[x_rank - 2]; + output_info.shape[max_rank - 1] = y_shape[y_rank - 1]; + + for (size_t i = 1; i <= max_rank - 2; ++i) { + int out_idx = max_rank - 2 - i; + int x_idx = static_cast(x_rank - 2) - static_cast(i); + int y_idx = static_cast(y_rank - 2) - static_cast(i); + + int x_dim = (x_idx >= 0) ? x_shape[x_idx] : 1; + int y_dim = (y_idx >= 0) ? y_shape[y_idx] : 1; + if (x_dim != y_dim && x_dim != 1 && y_dim != 1) { return Tensor(graph::ErrorTensor(absl::InvalidArgumentError( absl::StrCat("The batch dimensions of the input tensors must be " @@ -932,10 +950,8 @@ Tensor BatchMatMul( absl::StrJoin(x_info.shape, ","), " y_shape: ", absl::StrJoin(y_info.shape, ","))))); } - output_info.shape.push_back(std::max(x_dim, y_dim)); + output_info.shape[out_idx] = std::max(x_dim, y_dim); } - output_info.shape.push_back(x_shape[x_shape.size() - 2]); - output_info.shape.push_back(y_shape[y_shape.size() - 1]); output_info.type = x_info.type; @@ -948,11 +964,15 @@ Tensor FullyConnected( Tensor input, Tensor weights, std::optional> bias, FusedActivation activation = kActNone, bool keep_num_dims = true, + bool asymmetric_quantize_inputs = false, + FullyConnectedWeightsFormat weights_format = kWeightsFormatDefault, source_location loc = source_location::current()) { auto op = std::make_shared(); RegisterMixins(op); op->activation = activation; op->keep_num_dims = keep_num_dims; + op->asymmetric_quantize_inputs = asymmetric_quantize_inputs; + op->weights_format = weights_format; AddInputs(op, input, weights); if (bias.has_value()) { AddInputs(op, bias.value()); @@ -965,7 +985,11 @@ Tensor FullyConnected( output_info.shape = input_info.shape; output_info.shape.back() = weights_info.shape[0]; } else { - output_info.shape = {input_info.shape[0], weights_info.shape[0]}; + int batch = 1; + for (size_t i = 0; i < input_info.shape.size() - 1; ++i) { + batch *= input_info.shape[i]; + } + output_info.shape = {batch, weights_info.shape[0]}; } output_info.type = input_info.type; @@ -977,19 +1001,25 @@ template Tensor FullyConnected( Tensor input, Tensor weights, Tensor bias, FusedActivation activation = kActNone, bool keep_num_dims = true, + bool asymmetric_quantize_inputs = false, + FullyConnectedWeightsFormat weights_format = kWeightsFormatDefault, source_location loc = source_location::current()) { return FullyConnected(input, weights, std::optional(std::move(bias)), - activation, keep_num_dims, loc); + activation, keep_num_dims, asymmetric_quantize_inputs, + weights_format, loc); } template Tensor FullyConnected( Tensor input, Tensor weights, FusedActivation activation = kActNone, bool keep_num_dims = true, + bool asymmetric_quantize_inputs = false, + FullyConnectedWeightsFormat weights_format = kWeightsFormatDefault, source_location loc = source_location::current()) { return FullyConnected(input, weights, /*bias=*/std::optional>(std::nullopt), - activation, keep_num_dims, loc); + activation, keep_num_dims, asymmetric_quantize_inputs, + weights_format, loc); } template diff --git a/litert/tensor/arithmetic_graph.h b/litert/tensor/arithmetic_graph.h index e23b0827db2..b69fc46ae72 100644 --- a/litert/tensor/arithmetic_graph.h +++ b/litert/tensor/arithmetic_graph.h @@ -36,6 +36,12 @@ enum FusedActivation { kActSigmoid, }; +// Possible weights formats for FullyConnected. +enum FullyConnectedWeightsFormat { + kWeightsFormatDefault = 0, + kWeightsFormatShuffled4x16Int8 = 1, +}; + // Possible padding types. enum Padding { kPaddingSame = 0, @@ -276,6 +282,9 @@ struct BatchMatMulOperation : BatchMatMulOperationData, Operation { struct FullyConnectedOperationData { litert::tensor::FusedActivation activation; bool keep_num_dims; + bool asymmetric_quantize_inputs = false; + litert::tensor::FullyConnectedWeightsFormat weights_format = + litert::tensor::kWeightsFormatDefault; }; struct FullyConnectedOperation : FullyConnectedOperationData, Operation { diff --git a/litert/tensor/tensor_test.cc b/litert/tensor/tensor_test.cc index 54bdaf40544..b18b0773c85 100644 --- a/litert/tensor/tensor_test.cc +++ b/litert/tensor/tensor_test.cc @@ -178,7 +178,7 @@ TEST(TensorTest, FullyConnectedFlatten) { Tensor output = FullyConnected(input, weights, bias, kActNone, /*keep_num_dims=*/false); LRT_TENSOR_ASSERT_OK_AND_ASSIGN(const auto& output_info, GetInfo(output)); - EXPECT_THAT(output_info.shape, ElementsAre(2, 5)); + EXPECT_THAT(output_info.shape, ElementsAre(12, 5)); } TEST(TensorTest, SetQuantizationWorks) { diff --git a/scripts/run-gn-tests.py b/scripts/run-gn-tests.py index e0e337f63db..b290d312c98 100644 --- a/scripts/run-gn-tests.py +++ b/scripts/run-gn-tests.py @@ -15,6 +15,7 @@ import datetime import glob import os +import platform import sys # Add tests that require sharding here. @@ -172,7 +173,10 @@ async def main() -> None: ) # Pick up the executables - must be named in this way to work + # If we're on Windows, the executable will have a .exe extension. test_suites = list(sorted(glob.glob(args.out_dir + '/xnnpack_*_test'))) + if platform.system() == "Windows": + test_suites = list(sorted(glob.glob(args.out_dir + '/xnnpack_*_test.exe'))) print(f'Discovered {len(test_suites)} test suites...') # Create the list of tests to run, sharding the long ones. diff --git a/src/f16-vcos/f16-vcos.inc b/src/f16-vcos/f16-vcos.inc index 0fa6f9448db..295836edc9f 100644 --- a/src/f16-vcos/f16-vcos.inc +++ b/src/f16-vcos/f16-vcos.inc @@ -19,11 +19,11 @@ XNN_UKERNEL(xnn_arch_arm_neon_fp16_arith, xnn_f16_vcos_ukernel__neonfp16arith_ra XNN_UKERNEL(xnn_arch_arm_neon_fp16_arith, xnn_f16_vcos_ukernel__neonfp16arith_rational_3_2_div_u32, 32, false, xnn_float16, struct xnn_f16_default_params, NULL) #endif // XNN_ENABLE_ARM_FP16_VECTOR && (XNN_ARCH_ARM || XNN_ARCH_ARM64) -#if XNN_ARCH_WASMRELAXEDSIMD +#if XNN_ARCH_WASMRELAXEDSIMDFP16 XNN_UKERNEL(xnn_arch_none, xnn_f16_vcos_ukernel__wasmrelaxedsimd_rational_3_2_div_u8, 8, false, xnn_float16, struct xnn_f16_default_params, NULL) XNN_UKERNEL(xnn_arch_none, xnn_f16_vcos_ukernel__wasmrelaxedsimd_rational_3_2_div_u16, 16, false, xnn_float16, struct xnn_f16_default_params, NULL) XNN_UKERNEL(xnn_arch_none, xnn_f16_vcos_ukernel__wasmrelaxedsimd_rational_3_2_div_u32, 32, false, xnn_float16, struct xnn_f16_default_params, NULL) -#endif // XNN_ARCH_WASMRELAXEDSIMD +#endif // XNN_ARCH_WASMRELAXEDSIMDFP16 #if XNN_ENABLE_AVX512FP16 && (XNN_ARCH_X86 || XNN_ARCH_X86_64) XNN_UKERNEL(xnn_arch_x86_avx512fp16, xnn_f16_vcos_ukernel__avx512fp16_rational_3_2_div_u32, 32, false, xnn_float16, struct xnn_f16_default_params, NULL) diff --git a/src/f16-vlog/f16-vlog.inc b/src/f16-vlog/f16-vlog.inc index 75ffabbd20d..34b4a173adf 100644 --- a/src/f16-vlog/f16-vlog.inc +++ b/src/f16-vlog/f16-vlog.inc @@ -14,7 +14,7 @@ XNN_UKERNEL(xnn_arch_none, xnn_f16_f32acc_vlog_ukernel__scalar_rational_1_3_div_ XNN_UKERNEL(xnn_arch_none, xnn_f16_f32acc_vlog_ukernel__scalar_rational_1_3_div_u2, 2, false, xnn_float16, struct xnn_f16_default_params, NULL) XNN_UKERNEL(xnn_arch_none, xnn_f16_f32acc_vlog_ukernel__scalar_rational_1_3_div_u4, 4, false, xnn_float16, struct xnn_f16_default_params, NULL) -#if XNN_ARCH_ARM || XNN_ARCH_ARM64 +#if XNN_ENABLE_ARM_FP16_VECTOR && (XNN_ARCH_ARM || XNN_ARCH_ARM64) // ARM NEON XNN_UKERNEL(xnn_arch_arm_neon_fp16_arith, xnn_f16_vlog_ukernel__neonfp16arith_rational_1_3_div_u8, 8, false, xnn_float16, struct xnn_f16_default_params, NULL) XNN_UKERNEL(xnn_arch_arm_neon_fp16_arith, xnn_f16_vlog_ukernel__neonfp16arith_rational_1_3_div_u16, 16, false, xnn_float16, struct xnn_f16_default_params, NULL) @@ -22,7 +22,9 @@ XNN_UKERNEL(xnn_arch_arm_neon_fp16_arith, xnn_f16_vlog_ukernel__neonfp16arith_ra XNN_UKERNEL(xnn_arch_arm_neon_fp16_arith, xnn_f16_vlog_ukernel__neonfp16arith_rational_1_3_nr_u8, 8, false, xnn_float16, struct xnn_f16_default_params, NULL) XNN_UKERNEL(xnn_arch_arm_neon_fp16_arith, xnn_f16_vlog_ukernel__neonfp16arith_rational_1_3_nr_u16, 16, false, xnn_float16, struct xnn_f16_default_params, NULL) XNN_UKERNEL(xnn_arch_arm_neon_fp16_arith, xnn_f16_vlog_ukernel__neonfp16arith_rational_1_3_nr_u32, 32, false, xnn_float16, struct xnn_f16_default_params, NULL) +#endif // XNN_ENABLE_ARM_FP16_VECTOR && (XNN_ARCH_ARM || XNN_ARCH_ARM64) +#if XNN_ARCH_ARM || XNN_ARCH_ARM64 XNN_UKERNEL(xnn_arch_arm_neon_fp16, xnn_f16_f32acc_vlog_ukernel__neonfp16_rational_1_3_div_u4, 4, false, xnn_float16, struct xnn_f16_default_params, NULL) XNN_UKERNEL(xnn_arch_arm_neon_fp16, xnn_f16_f32acc_vlog_ukernel__neonfp16_rational_1_3_div_u8, 8, false, xnn_float16, struct xnn_f16_default_params, NULL) XNN_UKERNEL(xnn_arch_arm_neon_fp16, xnn_f16_f32acc_vlog_ukernel__neonfp16_rational_1_3_div_u16, 16, false, xnn_float16, struct xnn_f16_default_params, NULL) @@ -31,7 +33,7 @@ XNN_UKERNEL(xnn_arch_arm_neon_fp16, xnn_f16_f32acc_vlog_ukernel__neonfp16_ration XNN_UKERNEL(xnn_arch_arm_neon_fp16, xnn_f16_f32acc_vlog_ukernel__neonfp16_rational_1_3_nr_u16, 16, false, xnn_float16, struct xnn_f16_default_params, NULL) #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 -#if XNN_ARCH_X86 || XNN_ARCH_X86_64 +#if XNN_ENABLE_F16C && (XNN_ARCH_X86 || XNN_ARCH_X86_64) // X86 XNN_UKERNEL(xnn_arch_x86_f16c, xnn_f16_f32acc_vlog_ukernel__f16c_rational_1_3_div_u8, 8, false, xnn_float16, struct xnn_f16_default_params, NULL) XNN_UKERNEL(xnn_arch_x86_f16c, xnn_f16_f32acc_vlog_ukernel__f16c_rational_1_3_div_u16, 16, false, xnn_float16, struct xnn_f16_default_params, NULL) @@ -39,16 +41,16 @@ XNN_UKERNEL(xnn_arch_x86_f16c, xnn_f16_f32acc_vlog_ukernel__f16c_rational_1_3_di XNN_UKERNEL(xnn_arch_x86_f16c, xnn_f16_f32acc_vlog_ukernel__f16c_rational_1_3_nr_u8, 8, false, xnn_float16, struct xnn_f16_default_params, NULL) XNN_UKERNEL(xnn_arch_x86_f16c, xnn_f16_f32acc_vlog_ukernel__f16c_rational_1_3_nr_u16, 16, false, xnn_float16, struct xnn_f16_default_params, NULL) XNN_UKERNEL(xnn_arch_x86_f16c, xnn_f16_f32acc_vlog_ukernel__f16c_rational_1_3_nr_u32, 32, false, xnn_float16, struct xnn_f16_default_params, NULL) +#endif // XNN_ENABLE_F16C && (XNN_ARCH_X86 || XNN_ARCH_X86_64) -#if XNN_ENABLE_AVX512F +#if XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) XNN_UKERNEL(xnn_arch_x86_avx512f, xnn_f16_f32acc_vlog_ukernel__avx512f_rational_1_3_div_u16, 16, false, xnn_float16, struct xnn_f16_default_params, NULL) XNN_UKERNEL(xnn_arch_x86_avx512f, xnn_f16_f32acc_vlog_ukernel__avx512f_rational_1_3_div_u32, 32, false, xnn_float16, struct xnn_f16_default_params, NULL) XNN_UKERNEL(xnn_arch_x86_avx512f, xnn_f16_f32acc_vlog_ukernel__avx512f_rational_1_3_div_u64, 64, false, xnn_float16, struct xnn_f16_default_params, NULL) XNN_UKERNEL(xnn_arch_x86_avx512f, xnn_f16_f32acc_vlog_ukernel__avx512f_rational_1_3_nr_u16, 16, false, xnn_float16, struct xnn_f16_default_params, NULL) XNN_UKERNEL(xnn_arch_x86_avx512f, xnn_f16_f32acc_vlog_ukernel__avx512f_rational_1_3_nr_u32, 32, false, xnn_float16, struct xnn_f16_default_params, NULL) XNN_UKERNEL(xnn_arch_x86_avx512f, xnn_f16_f32acc_vlog_ukernel__avx512f_rational_1_3_nr_u64, 64, false, xnn_float16, struct xnn_f16_default_params, NULL) -#endif // XNN_ENABLE_AVX512F -#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 +#endif // XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) #if XNN_ENABLE_AVX512FP16 && (XNN_ARCH_X86 || XNN_ARCH_X86_64) XNN_UKERNEL(xnn_arch_x86_avx512fp16, xnn_f16_vlog_ukernel__avx512fp16_rational_1_3_div_u32, 32, false, xnn_float16, struct xnn_f16_default_params, NULL) diff --git a/src/f16-vsin/f16-vsin.inc b/src/f16-vsin/f16-vsin.inc index ab4fa2a3c0d..8d75151933b 100644 --- a/src/f16-vsin/f16-vsin.inc +++ b/src/f16-vsin/f16-vsin.inc @@ -19,11 +19,11 @@ XNN_UKERNEL(xnn_arch_arm_neon_fp16_arith, xnn_f16_vsin_ukernel__neonfp16arith_ra XNN_UKERNEL(xnn_arch_arm_neon_fp16_arith, xnn_f16_vsin_ukernel__neonfp16arith_rational_3_2_div_u32, 32, false, xnn_float16, struct xnn_f16_default_params, NULL) #endif // XNN_ENABLE_ARM_FP16_VECTOR && (XNN_ARCH_ARM || XNN_ARCH_ARM64) -#if XNN_ARCH_WASMRELAXEDSIMD +#if XNN_ARCH_WASMRELAXEDSIMDFP16 XNN_UKERNEL(xnn_arch_none, xnn_f16_vsin_ukernel__wasmrelaxedsimd_rational_3_2_div_u8, 8, false, xnn_float16, struct xnn_f16_default_params, NULL) XNN_UKERNEL(xnn_arch_none, xnn_f16_vsin_ukernel__wasmrelaxedsimd_rational_3_2_div_u16, 16, false, xnn_float16, struct xnn_f16_default_params, NULL) XNN_UKERNEL(xnn_arch_none, xnn_f16_vsin_ukernel__wasmrelaxedsimd_rational_3_2_div_u32, 32, false, xnn_float16, struct xnn_f16_default_params, NULL) -#endif // XNN_ARCH_WASMRELAXEDSIMD +#endif // XNN_ARCH_WASMRELAXEDSIMDFP16 #if XNN_ENABLE_AVX512FP16 && (XNN_ARCH_X86 || XNN_ARCH_X86_64) XNN_UKERNEL(xnn_arch_x86_avx512fp16, xnn_f16_vsin_ukernel__avx512fp16_rational_3_2_div_u32, 32, false, xnn_float16, struct xnn_f16_default_params, NULL) diff --git a/ynnpack/include/ynnpack.h b/ynnpack/include/ynnpack.h index ce39365e186..dd68a93f857 100644 --- a/ynnpack/include/ynnpack.h +++ b/ynnpack/include/ynnpack.h @@ -260,11 +260,14 @@ enum ynn_status ynn_define_binary(ynn_subgraph_t subgraph, uint32_t input_a_id, uint32_t input_b_id, uint32_t* output_id, uint32_t flags); -// Defines a lookup table operation. `output_id` will have the same shape as -// `input_id`. -enum ynn_status ynn_define_lut(ynn_subgraph_t subgraph, uint32_t input_id, - uint32_t lut_id, uint32_t* output_id, - uint32_t flags); +// Defines a gather operation. This computes: +// `output[...i, j, ...k] = input[...i, index[...i, j, ...k], ...k]`, where `j` +// is the `axis` dimension. This operation supports broadcasting of the input, +// but not the index. Such broadcasting should be performed with a subsequent +// broadcast operation. +enum ynn_status ynn_define_gather(ynn_subgraph_t subgraph, int32_t axis, + uint32_t input_id, uint32_t index_id, + uint32_t* output_id, uint32_t flags); // Changes the shape of `input_id` to have the shape `new_dims`, by broadcasting // extent 1 dimensions. If `new_dims[d]` is zero, dimension `d` is passed diff --git a/ynnpack/subgraph/BUILD b/ynnpack/subgraph/BUILD index 9d58b246525..68fd3f201e1 100644 --- a/ynnpack/subgraph/BUILD +++ b/ynnpack/subgraph/BUILD @@ -78,6 +78,8 @@ cc_library( "fusion_lut.h", "fusion_types.cc", "fusion_types.h", + "gather.cc", + "gather.h", "get_tensor_shape.cc", "iota.cc", "iota.h", diff --git a/ynnpack/subgraph/elementwise.cc b/ynnpack/subgraph/elementwise.cc index 4dc9b273e8d..8491dd5eb6a 100644 --- a/ynnpack/subgraph/elementwise.cc +++ b/ynnpack/subgraph/elementwise.cc @@ -8,7 +8,6 @@ #include #include #include -#include #include "ynnpack/base/log.h" #include "ynnpack/base/to_string.h" @@ -16,7 +15,6 @@ #include "ynnpack/include/ynnpack.h" #include "ynnpack/kernels/binary/binary.h" #include "ynnpack/kernels/dequantize_dot/dequantize_dot.h" -#include "ynnpack/kernels/lut/lut.h" #include "ynnpack/kernels/ternary/ternary.h" #include "ynnpack/kernels/unary/unary.h" #include "ynnpack/subgraph/runtime.h" @@ -62,30 +60,6 @@ auto make_unary_elementwise_impl(unary_kernel_fn kernel, unary_params params) { }; } -// Call a lut kernel. -auto make_lut_impl(lut_kernel_fn kernel) { - return [kernel](slinky::raw_buffer a, slinky::raw_buffer lut, - slinky::raw_buffer x) -> slinky::index_t { - slinky::dim a_dims[1], x_dims[1]; - - if (!fuse_and_slice_leading_dims<1>(&x_dims[0], x, &a_dims[0], a)) { - return 0; - } - - // We don't support broadcasting of `a` here in the innermost - // dimension (and it would waste computation). - assert(is_contiguous(a_dims[0], a.elem_size)); - assert(is_contiguous(x_dims[0], x.elem_size)); - - const slinky::index_t x_n_extent = x_dims[0].extent(); - - slinky::for_each_element( - [=](void* x, const void* a) { kernel(x_n_extent, a, lut.base, x); }, x, - a); - return 0; - }; -} - // Call a binary kernel. auto make_binary_elementwise_impl(binary_kernel_fn kernel) { return [kernel](slinky::raw_buffer a, slinky::raw_buffer b, @@ -115,17 +89,6 @@ auto make_binary_elementwise_impl(binary_kernel_fn kernel) { }; } -int compute_allow_in_place(const ynn_node& node, const ynn_subgraph& subgraph) { - assert(node.outputs.size() == 1); - int result = 0; - for (int i = 0; i < node.inputs.size(); ++i) { - if (allow_in_place(node.inputs[i], node.outputs[0], subgraph)) { - result |= 1 << i; - } - } - return result; -} - auto make_ternary_elementwise_impl(ternary_kernel_fn kernel) { return [kernel](slinky::raw_buffer a, slinky::raw_buffer b, slinky::raw_buffer c, @@ -270,40 +233,6 @@ ynn_status create_unary(const ynn_node& node, ynn_runtime& runtime, return ynn_status_success; } -ynn_status create_lut(const ynn_node& node, ynn_runtime& runtime, - lut_kernel_fn kernel) { - assert(node.inputs.size() == 2); - assert(node.outputs.size() == 1); - - const ynn_runtime_value& a = runtime.value(node.inputs[0]); - const ynn_runtime_value& lut = runtime.value(node.inputs[1]); - ynn_runtime_value& x = runtime.value(node.outputs[0]); - - x.make_buffer(runtime); - std::vector dims = runtime.globals.make_dims(x.rank()); - slinky::box_expr bounds = make_elementwise_bounds(dims, a.physical_extents()); - - slinky::box_expr lut_bounds = { - slinky::interval_expr(0, 1 << type_size_bytes(a.type))}; - - slinky::call_stmt::attributes attrs; - attrs.name = "lut"; - attrs.allow_in_place = compute_allow_in_place(node, *runtime.subgraph); - - auto func = slinky::func::make( - make_lut_impl(kernel), - {{a.buffer, std::move(bounds)}, {lut.buffer, std::move(lut_bounds)}}, - {{x.buffer, dims}}, std::move(attrs)); - - auto sched = - runtime.make_schedule(dims, x.physical_extents(), x.buffer->elem_size()); - func.user_data() = sched.get(); - runtime.scheduling_info_storage.push_back(std::move(sched)); - runtime.funcs.push_back(std::move(func)); - - return ynn_status_success; -} - ynn_status create_binary(const ynn_node& node, ynn_runtime& runtime, binary_kernel_fn kernel) { assert(node.inputs.size() == 2); @@ -441,31 +370,6 @@ void define_ternary(ynn_subgraph& subgraph, ynn_node& node, uint32_t input_a_id, }; } -void define_lut(ynn_subgraph& subgraph, ynn_node& node, uint32_t input_id, - uint32_t lut_id, uint32_t& output_id) { - const ynn_value& a = subgraph.value(input_id); - ynn_value& x = subgraph.get_output_value(&output_id, a); - - // Find kernel. - lut_kernel_fn kernel = get_lut_kernel(a.type, x.type); - assert(kernel); - - node.inputs = {input_id, lut_id}; - node.outputs = {output_id}; - node.op = ynn_node::lut{}; - - // Propagate shape from A only. - x.extents.resize(a.rank()); - for (size_t d = 0; d < x.rank(); ++d) { - subgraph.infer_elementwise_shape(node, /*input_idx=*/0, /*output_idx=*/0, - /*input_dim=*/d, /*output_dim=*/d); - } - - node.create = [kernel](const ynn_node& node, ynn_runtime& runtime) { - return create_lut(node, runtime, kernel); - }; -} - bool define_dequantize_dot(ynn_subgraph& subgraph, ynn_node& node, ynn_type output_type, uint32_t dot_id, uint32_t a_offset_id, uint32_t b_offset_id, @@ -984,38 +888,6 @@ ynn_status ynn_define_binary(ynn_subgraph_t subgraph, ynn_binary_operator op, return ynn_status_success; } -ynn_status ynn_define_lut(ynn_subgraph_t subgraph, uint32_t input_id, - uint32_t lut_id, uint32_t* output_id, - uint32_t flags) { - YNN_RETURN_IF_ERROR(validate_subgraph("lut", subgraph)); - YNN_RETURN_IF_ERROR( - validate_input_tensor("lut", subgraph, "input_id", input_id)); - YNN_RETURN_IF_ERROR(validate_input_tensor("lut", subgraph, "lut_id", lut_id)); - YNN_RETURN_IF_ERROR( - validate_output_tensor("lut", subgraph, "output_id", output_id)); - - const ynn_value& a = subgraph->value(input_id); - const ynn_value& lut = subgraph->value(lut_id); - - if (!ynn::type_is_integral(a.type)) { - YNN_LOG_ERROR() << "For node `lut`, input must be integral, got " << a.type; - return ynn_status_invalid_parameter; - } - if (!ynn::type_is_integral(lut.type)) { - YNN_LOG_ERROR() << "For node `lut`, lut must be integral, got " << lut.type; - return ynn_status_invalid_parameter; - } - if (lut.rank() != 1) { - YNN_LOG_ERROR() << "For node `lut`, lut must be 1D, got " << lut.rank(); - return ynn_status_invalid_parameter; - } - - ynn_node node; - define_lut(*subgraph, node, input_id, lut_id, *output_id); - subgraph->add_node(std::move(node)); - return ynn_status_success; -} - } // extern "C" } // namespace ynn diff --git a/ynnpack/subgraph/elementwise.h b/ynnpack/subgraph/elementwise.h index 49a98485bf9..1bf9a013c64 100644 --- a/ynnpack/subgraph/elementwise.h +++ b/ynnpack/subgraph/elementwise.h @@ -27,8 +27,6 @@ void define_ternary(ynn_subgraph& subgraph, ynn_node& node, uint32_t input_a_id, uint32_t input_b_id, uint32_t input_c_id, uint32_t output_id, ternary_op op, ternary_kernel_fn kernel); -void define_lut(ynn_subgraph& subgraph, ynn_node& node, uint32_t input_id, - uint32_t lut_id, uint32_t& output_id); bool define_dequantize_dot(ynn_subgraph& subgraph, ynn_node& node, ynn_type output_type, uint32_t dot_id, diff --git a/ynnpack/subgraph/fusion.cc b/ynnpack/subgraph/fusion.cc index 2b1858dfb3c..4b9a21c50ab 100644 --- a/ynnpack/subgraph/fusion.cc +++ b/ynnpack/subgraph/fusion.cc @@ -487,9 +487,11 @@ bool is_broadcast_noop(const ynn_subgraph& subgraph, const ynn_node& node, } } return true; - } else if (std::holds_alternative(node.op) && - input_id == node.inputs[0]) { - return true; + } else if (const auto* g = std::get_if(&node.op)) { + const ynn_value& table = subgraph.value(node.inputs[0]); + if (table.rank() == 1 && g->axis == 0 && input_id == node.inputs[1]) { + return true; + } } else if (const auto* t = std::get_if(&node.op)) { assert(input_id == node.inputs[0]); diff --git a/ynnpack/subgraph/fusion_lut.cc b/ynnpack/subgraph/fusion_lut.cc index 19c334ce7eb..567aaa3daf7 100644 --- a/ynnpack/subgraph/fusion_lut.cc +++ b/ynnpack/subgraph/fusion_lut.cc @@ -15,6 +15,7 @@ #include "ynnpack/include/ynnpack.h" #include "ynnpack/subgraph/elementwise.h" #include "ynnpack/subgraph/fusion_types.h" +#include "ynnpack/subgraph/gather.h" #include "ynnpack/subgraph/runtime.h" #include "ynnpack/subgraph/subgraph.h" #include "ynnpack/subgraph/utils.h" @@ -328,8 +329,8 @@ bool rewrite_subgraph_for_unary_lut(ynn_subgraph& subgraph, // Replace and invalidate all nodes in `candidate` with the new LUT node. ynn_node* output_node = analysis.producers[best_candidate.get_output_id()]; - define_lut(subgraph, *output_node, best_candidate.get_input_id(), lut_id, - best_candidate.output_id()); + define_gather(subgraph, *output_node, /*axis=*/0, lut_id, + best_candidate.get_input_id(), best_candidate.output_id()); return true; } diff --git a/ynnpack/subgraph/gather.cc b/ynnpack/subgraph/gather.cc new file mode 100644 index 00000000000..333e04dbaad --- /dev/null +++ b/ynnpack/subgraph/gather.cc @@ -0,0 +1,314 @@ +// Copyright 2026 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ynnpack/base/log.h" +#include "ynnpack/base/type.h" +#include "ynnpack/include/ynnpack.h" +#include "ynnpack/kernels/lut/lut.h" +#include "ynnpack/subgraph/runtime.h" +#include "ynnpack/subgraph/slinky.h" +#include "ynnpack/subgraph/subgraph.h" +#include "ynnpack/subgraph/utils.h" +#include "slinky/builder/pipeline.h" +#include "slinky/runtime/buffer.h" +#include "slinky/runtime/expr.h" +#include "slinky/runtime/stmt.h" + +namespace ynn { + +namespace { + +// Call a lut kernel. +auto make_lut_impl(lut_kernel_fn kernel) { + return [kernel](slinky::raw_buffer a, slinky::raw_buffer lut, + slinky::raw_buffer x) -> slinky::index_t { + assert(is_contiguous(x.dim(0), x.elem_size)); + + slinky::dim x_dim = ynn::slice_dim0(x); + if (x_dim.empty()) { + return 0; + } + slinky::dim a_dim = ynn::slice_dim0(a, slinky::in_bounds{x_dim.min()}); + + assert(is_contiguous(a_dim, a.elem_size)); + assert(is_contiguous(x_dim, x.elem_size)); + (void)a_dim; + + const slinky::index_t x_n_extent = x_dim.extent(); + + // Slice the lookup dimension of lut (dim 0). + lut.slice(0, 0); + + slinky::for_each_element( + [=](void* x, const void* a, const void* lut_ptr) { + kernel(x_n_extent, a, lut_ptr, x); + }, + x, a, lut); + return 0; + }; +} + +ynn_status create_gather_lut(const ynn_node& node, ynn_runtime& runtime, + lut_kernel_fn kernel) { + assert(node.inputs.size() == 2); + assert(node.outputs.size() == 1); + + const ynn_runtime_value& lut = runtime.value(node.inputs[0]); // table + const ynn_runtime_value& a = runtime.value(node.inputs[1]); // index + ynn_runtime_value& x = runtime.value(node.outputs[0]); + + x.make_buffer(runtime); + std::vector dims = runtime.globals.make_dims(x.rank()); + slinky::box_expr bounds = make_elementwise_bounds(dims, a.physical_extents()); + + slinky::box_expr lut_bounds(lut.rank()); + lut_bounds[0] = slinky::interval_expr(0, 1 << type_size_bytes(a.type)); + for (size_t d = 1; d < lut.rank(); ++d) { + lut_bounds[d] = elementwise_bounds(dims[d], lut.physical_extents()[d]); + } + + slinky::call_stmt::attributes attrs; + attrs.name = "lut"; + attrs.allow_in_place = compute_allow_in_place(node, *runtime.subgraph); + + auto func = slinky::func::make( + make_lut_impl(kernel), + {{a.buffer, std::move(bounds)}, {lut.buffer, std::move(lut_bounds)}}, + {{x.buffer, dims}}, std::move(attrs)); + + auto sched = + runtime.make_schedule(dims, x.physical_extents(), x.buffer->elem_size()); + func.user_data() = sched.get(); + runtime.scheduling_info_storage.push_back(std::move(sched)); + runtime.funcs.push_back(std::move(func)); + + return ynn_status_success; +} + +slinky::index_t read_index_value(const void* ptr, ynn_type type) { + switch (type) { + case ynn_type_int8: + return *reinterpret_cast(ptr); + case ynn_type_uint8: + return *reinterpret_cast(ptr); + case ynn_type_int32: + return *reinterpret_cast(ptr); + default: + assert(false && "Unsupported index type"); + return 0; + } +} + +auto make_gather_impl(int32_t axis, ynn_type index_type) { + return + [axis, index_type]( + slinky::buffer input, + slinky::buffer index, + slinky::buffer output) -> slinky::index_t { + slinky::dim input_axis = input.dim(axis); + + std::size_t elem_size = output.elem_size; + const size_t R_loop = index.rank; + const size_t R_rem = output.rank - R_loop; + + // 1. Modify input in place to crop the axis. + if (index.rank == 0) { + input.slice(axis, input_axis.min()); + } else { + input.crop(axis, input_axis.min(), input_axis.min()); + if (axis < input.rank) { + input.mutable_dim(axis).set_stride(0); + } + } + + // 3. Prepare slices for copy outside the loop. + // They represent the outermost R_rem dimensions. + slinky::raw_buffer output_slice; + output_slice.elem_size = elem_size; + output_slice.rank = R_rem; + output_slice.dims = (R_rem > 0) ? (output.dims + R_loop) : nullptr; + + slinky::raw_buffer input_slice; + input_slice.elem_size = elem_size; + input_slice.rank = (input.rank > static_cast(R_loop)) + ? (input.rank - static_cast(R_loop)) + : 0; + input_slice.dims = (input.rank > static_cast(R_loop)) + ? (input.dims + R_loop) + : nullptr; + + // Buffers for the for_each_element loop. + // They represent the innermost R_loop dimensions. + slinky::raw_buffer output_for_loop = output; + output_for_loop.dims = output.dims; + output_for_loop.rank = R_loop; + + slinky::raw_buffer input_for_loop = input; + input_for_loop.dims = input.dims; + input_for_loop.rank = + std::min(input.rank, static_cast(R_loop)); + + // 4. Run gather. + slinky::for_each_element( + [=, &output_slice, &input_slice](void* output_ptr, + const void* index_ptr, + const void* input_dummy_ptr) { + slinky::index_t idx = read_index_value(index_ptr, index_type); + if (idx < 0) { + idx += input_axis.extent(); + } + assert(idx >= 0 && idx < input_axis.extent()); + + const void* input_ptr = slinky::offset_bytes( + input_dummy_ptr, input_axis.flat_offset_bytes(idx)); + + output_slice.base = output_ptr; + input_slice.base = const_cast(input_ptr); + + slinky::copy(input_slice, output_slice); + }, + output_for_loop, index, input_for_loop); + + return 0; + }; +} + +} // namespace + +void define_gather(ynn_subgraph& subgraph, ynn_node& node, int32_t axis, + uint32_t input_id, uint32_t index_id, uint32_t& output_id) { + const ynn_value& input = subgraph.value(input_id); + const ynn_value& index = subgraph.value(index_id); + + ynn_value& output = subgraph.get_output_value(&output_id, input); + + node.inputs = {input_id, index_id}; + node.outputs = {output_id}; + node.op = ynn_node::gather{axis}; + + size_t output_rank = index.rank(); + for (size_t d = 0; d < input.rank(); ++d) { + if (d != static_cast(axis)) { + output_rank = std::max(output_rank, d + 1); + } + } + output.extents.resize(output_rank); + + for (size_t d = 0; d < index.rank(); ++d) { + subgraph.infer_elementwise_shape(node, /*input_idx=*/1, + /*output_idx=*/0, + /*input_dim=*/d, /*output_dim=*/d); + } + + for (size_t d = 0; d < input.rank(); ++d) { + if (d != static_cast(axis)) { + subgraph.infer_elementwise_shape(node, /*input_idx=*/0, + /*output_idx=*/0, + /*input_dim=*/d, /*output_dim=*/d); + } + } + + if (axis == 0) { + // If we are doing the gather in axis 0, we might be able to use a LUT + // kernel for this. + lut_kernel_fn kernel = get_lut_kernel(index.type, input.type); + if (kernel) { + node.create = [kernel](const ynn_node& node, ynn_runtime& runtime) { + return create_gather_lut(node, runtime, kernel); + }; + return; + } + } + + node.create = [](const ynn_node& node, ynn_runtime& runtime) { + int32_t axis = std::get(node.op).axis; + const ynn_runtime_value& input = runtime.value(node.inputs[0]); + const ynn_runtime_value& index = runtime.value(node.inputs[1]); + ynn_runtime_value& output = runtime.value(node.outputs[0]); + + output.make_buffer(runtime, input.buffer->elem_size()); + + std::vector dims = runtime.globals.make_dims(output.rank()); + + slinky::box_expr input_bounds(input.rank()); + for (size_t d = 0; d < input.rank(); ++d) { + if (d == static_cast(axis)) { + input_bounds[d] = all_bounds(input.physical_extents()[d]); + } else { + input_bounds[d] = + elementwise_bounds(dims[d], input.physical_extents()[d]); + } + } + + slinky::box_expr index_bounds(index.rank()); + for (size_t j = 0; j < index.rank(); ++j) { + index_bounds[j] = + elementwise_bounds(dims[j], index.physical_extents()[j]); + } + + auto func = slinky::func::make(make_gather_impl(axis, index.type), + {{input.buffer, std::move(input_bounds)}, + {index.buffer, std::move(index_bounds)}}, + {{output.buffer, dims}}); + + auto sched = runtime.make_schedule(dims, output.physical_extents(), + output.buffer->elem_size()); + func.user_data() = sched.get(); + runtime.scheduling_info_storage.push_back(std::move(sched)); + runtime.funcs.push_back(std::move(func)); + return ynn_status_success; + }; +} + +} // namespace ynn + +extern "C" { + +ynn_status ynn_define_gather(ynn_subgraph_t subgraph, int32_t axis, + uint32_t input_id, uint32_t index_id, + uint32_t* output_id, uint32_t flags) { + YNN_RETURN_IF_ERROR(ynn::validate_subgraph("gather", subgraph)); + YNN_RETURN_IF_ERROR( + ynn::validate_input_tensor("gather", subgraph, "input_id", input_id)); + YNN_RETURN_IF_ERROR( + ynn::validate_input_tensor("gather", subgraph, "index_id", index_id)); + YNN_RETURN_IF_ERROR( + ynn::validate_output_tensor("gather", subgraph, "output_id", output_id)); + const ynn_value& input = subgraph->value(input_id); + YNN_RETURN_IF_ERROR( + ynn::validate_axis("gather", "input", input.rank(), axis)); + const ynn_value& index = subgraph->value(index_id); + if (input.rank() > index.rank() && + !(index.rank() == 0 && input.rank() == 1)) { + YNN_LOG_ERROR() + << "For node `gather`, input rank must be less than or equal to index " + "rank (unless index is scalar and input is 1D). Got input rank " + << input.rank() << " and index rank " << index.rank(); + return ynn_status_invalid_parameter; + } + if (!ynn::type_is_integral(index.type)) { + YNN_LOG_ERROR() << "For node `gather`, index must be integral, got " + << index.type; + return ynn_status_invalid_parameter; + } + + axis = ynn::axis_to_slinky_dim(input.rank(), axis); + + ynn_node node; + ynn::define_gather(*subgraph, node, axis, input_id, index_id, *output_id); + subgraph->add_node(std::move(node)); + return ynn_status_success; +} + +} // extern "C" diff --git a/ynnpack/subgraph/gather.h b/ynnpack/subgraph/gather.h new file mode 100644 index 00000000000..58241f1e6ee --- /dev/null +++ b/ynnpack/subgraph/gather.h @@ -0,0 +1,20 @@ +// Copyright 2026 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#ifndef XNNPACK_YNNPACK_SUBGRAPH_GATHER_H_ +#define XNNPACK_YNNPACK_SUBGRAPH_GATHER_H_ + +#include + +#include "ynnpack/subgraph/subgraph.h" + +namespace ynn { + +void define_gather(ynn_subgraph& subgraph, ynn_node& node, int32_t axis, + uint32_t input_id, uint32_t index_id, uint32_t& output_id); + +} // namespace ynn + +#endif // XNNPACK_YNNPACK_SUBGRAPH_GATHER_H_ diff --git a/ynnpack/subgraph/subgraph.cc b/ynnpack/subgraph/subgraph.cc index 40b8b8d37ab..5cece3d3c56 100644 --- a/ynnpack/subgraph/subgraph.cc +++ b/ynnpack/subgraph/subgraph.cc @@ -952,7 +952,7 @@ const char* name_of(const ynn_node::opaque&) { return "opaque"; } const char* name_of(const ynn_node::unary_elementwise&) { return "unary_elementwise"; } -const char* name_of(const ynn_node::lut&) { return "lut"; } + const char* name_of(const ynn_node::binary_elementwise&) { return "binary_elementwise"; } @@ -967,6 +967,7 @@ const char* name_of(const ynn_node::concatenate&) { return "concatenate"; } const char* name_of(const ynn_node::stack&) { return "stack"; } const char* name_of(const ynn_node::even_split&) { return "even_split"; } const char* name_of(const ynn_node::copy&) { return "copy"; } +const char* name_of(const ynn_node::gather&) { return "gather"; } const char* name_of(const ynn_node::fuse_dim&) { return "fuse_dim"; } const char* name_of(const ynn_node::fuse_dims&) { return "fuse_dims"; } const char* name_of(const ynn_node::split_dim&) { return "split_dim"; } @@ -1058,7 +1059,6 @@ void print(std::ostream& os, const ynn_node::unary_elementwise& op) { os << "op=" << op.op; } -void print(std::ostream& os, const ynn_node::lut& op) {} void print(std::ostream& os, const ynn_node::binary_elementwise& op) { os << "op=" << op.op; @@ -1092,6 +1092,9 @@ void print(std::ostream& os, const ynn_node::even_split& op) { } void print(std::ostream& os, const ynn_node::copy& op) { os << "copy"; } +void print(std::ostream& os, const ynn_node::gather& op) { + os << "axis=" << op.axis; +} void print(std::ostream& os, const ynn_node::fuse_dim& op) { os << "axis=" << op.axis << " axes_count=" << op.axes_count; diff --git a/ynnpack/subgraph/subgraph.h b/ynnpack/subgraph/subgraph.h index f077714d4b9..e8a8572c636 100644 --- a/ynnpack/subgraph/subgraph.h +++ b/ynnpack/subgraph/subgraph.h @@ -287,10 +287,7 @@ struct ynn_node { return false; } }; - struct lut { - friend bool operator==(const lut&, const lut&) { return true; } - friend bool operator<(const lut&, const lut&) { return false; } - }; + struct binary_elementwise { ynn_binary_operator op; friend bool operator==(const binary_elementwise& a, @@ -317,6 +314,15 @@ struct ynn_node { friend bool operator==(const copy&, const copy&) { return true; } friend bool operator<(const copy&, const copy&) { return false; } }; + struct gather { + int32_t axis; + friend bool operator==(const gather& a, const gather& b) { + return a.axis == b.axis; + } + friend bool operator<(const gather& a, const gather& b) { + return a.axis < b.axis; + } + }; struct fuse_dim { // Fuse `axes_count` dimensions starting at `axis` into one dimension. int32_t axis; @@ -569,10 +575,10 @@ struct ynn_node { std::vector inputs; std::vector outputs; std::variant op; diff --git a/ynnpack/subgraph/test/BUILD b/ynnpack/subgraph/test/BUILD index 92273a8a204..8a17a991043 100644 --- a/ynnpack/subgraph/test/BUILD +++ b/ynnpack/subgraph/test/BUILD @@ -187,9 +187,9 @@ cc_test( "even_split", "fuse_dim", "fuse_dims", + "gather", "get_tensor_shape", "iota", - "lut", "reduce", "reduce_dot", "runtime", diff --git a/ynnpack/subgraph/test/fusion_lut.cc b/ynnpack/subgraph/test/fusion_lut.cc index 28ccb3a3cdd..7072dec67dc 100644 --- a/ynnpack/subgraph/test/fusion_lut.cc +++ b/ynnpack/subgraph/test/fusion_lut.cc @@ -109,9 +109,9 @@ TEST(fusion_lut, single_node_simple) { output_id, 256, [&](const ynn_subgraph& subgraph) { ASSERT_THAT(subgraph, AllOf(HasValidNodeCount(1), HasValidValueIds(input_id, output_id))); - EXPECT_THAT( - ProducerOf(output_id, subgraph), - AllOf(IsLut(), InputsAre(input_id, IsValidValueIn(subgraph)))); + EXPECT_THAT(ProducerOf(output_id, subgraph), + AllOf(IsLut(subgraph), + InputsAre(IsValidValueIn(subgraph), input_id))); }); } @@ -157,8 +157,9 @@ TEST(fusion_lut, single_node) { ASSERT_THAT(subgraph, AllOf(HasValidNodeCount(9), HasValidValueIds(a_id, b_id, c_id, d_id))); EXPECT_THAT(ProducerOf(x_id, subgraph), IsUnary(ynn_unary_convert)); - EXPECT_THAT(ProducerOf(y_id, subgraph), - AllOf(IsLut(), InputsAre(x_id, IsValidValueIn(subgraph)))); + EXPECT_THAT( + ProducerOf(y_id, subgraph), + AllOf(IsLut(subgraph), InputsAre(IsValidValueIn(subgraph), x_id))); EXPECT_THAT(ProducerOf(d_id, subgraph), IsUnary(ynn_unary_convert)); }); } @@ -193,8 +194,9 @@ TEST(fusion_lut, multiple_unary_chain) { 256, [&](const ynn_subgraph& subgraph) { ASSERT_THAT(subgraph, AllOf(HasValidNodeCount(1), HasValidValueIds(x_id, y_id))); - EXPECT_THAT(ProducerOf(y_id, subgraph), - AllOf(IsLut(), InputsAre(x_id, IsValidValueIn(subgraph)))); + EXPECT_THAT( + ProducerOf(y_id, subgraph), + AllOf(IsLut(subgraph), InputsAre(IsValidValueIn(subgraph), x_id))); }); } @@ -260,8 +262,9 @@ TEST(fusion_lut, elu_chain) { ASSERT_THAT(subgraph, AllOf(HasValidNodeCount(1), HasValidValueIds(x_id, y_id))); // LUT inputs: First is index (x), second is table (generated const). - EXPECT_THAT(ProducerOf(y_id, subgraph), - AllOf(IsLut(), InputsAre(x_id, IsValidValueIn(subgraph)))); + EXPECT_THAT( + ProducerOf(y_id, subgraph), + AllOf(IsLut(subgraph), InputsAre(IsValidValueIn(subgraph), x_id))); }); } @@ -293,10 +296,12 @@ TEST(fusion_lut, branching_2_luts) { ASSERT_THAT(subgraph, AllOf(HasValidNodeCount(2), HasValidValueIds(x_id, y_id, z_id), Not(HasValidValueId(t_id)))); - EXPECT_THAT(ProducerOf(y_id, subgraph), - AllOf(IsLut(), InputsAre(x_id, IsValidValueIn(subgraph)))); - EXPECT_THAT(ProducerOf(z_id, subgraph), - AllOf(IsLut(), InputsAre(x_id, IsValidValueIn(subgraph)))); + EXPECT_THAT( + ProducerOf(y_id, subgraph), + AllOf(IsLut(subgraph), InputsAre(IsValidValueIn(subgraph), x_id))); + EXPECT_THAT( + ProducerOf(z_id, subgraph), + AllOf(IsLut(subgraph), InputsAre(IsValidValueIn(subgraph), x_id))); }); } @@ -321,7 +326,7 @@ TEST(fusion_lut, input_type_unsupported) { HasValidValueIds(x_id, a_id, y_id))); EXPECT_THAT(ProducerOf(a_id, subgraph), IsUnary(ynn_unary_exp)); EXPECT_THAT(ProducerOf(y_id, subgraph), - AllOf(Not(IsLut()), IsUnary(ynn_unary_convert))); + AllOf(Not(IsLut(subgraph)), IsUnary(ynn_unary_convert))); }); } @@ -357,8 +362,9 @@ TEST(fusion_lut, binary_scalar_constant) { 256, [&](const ynn_subgraph& subgraph) { ASSERT_THAT(subgraph, AllOf(HasValidNodeCount(1), HasValidValueIds(x_id, y_id))); - EXPECT_THAT(ProducerOf(y_id, subgraph), - AllOf(IsLut(), InputsAre(x_id, IsValidValueIn(subgraph)))); + EXPECT_THAT( + ProducerOf(y_id, subgraph), + AllOf(IsLut(subgraph), InputsAre(IsValidValueIn(subgraph), x_id))); }); } @@ -397,7 +403,7 @@ TEST(fusion_lut, binary_nonscalar_constant_unsupported) { 256, [&](const ynn_subgraph& subgraph) { ASSERT_THAT(subgraph, HasValidValueIds(x_id, y_id, b_id)); EXPECT_THAT(ProducerOf(y_id, subgraph), - AllOf(Not(IsLut()), IsUnary(ynn_unary_convert))); + AllOf(Not(IsLut(subgraph)), IsUnary(ynn_unary_convert))); EXPECT_THAT(ProducerOf(b_id, subgraph), IsBinary(ynn_binary_add)); }); } diff --git a/ynnpack/subgraph/test/gather.cc b/ynnpack/subgraph/test/gather.cc new file mode 100644 index 00000000000..4053e52ae4e --- /dev/null +++ b/ynnpack/subgraph/test/gather.cc @@ -0,0 +1,152 @@ +// Copyright 2026 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include + +#include +#include +#include "ynnpack/base/bfloat16.h" +#include "ynnpack/base/half.h" +#include "ynnpack/base/type.h" +#include "ynnpack/include/ynnpack.h" +#include "ynnpack/subgraph/test/subgraph_builder.h" + +namespace ynn { +namespace { + +template +void TestGather(int32_t axis, std::vector input_shape, + std::vector input_data, std::vector index_shape, + std::vector index_data, + std::vector expected_output_shape, + std::vector expected_output_data) { + SubgraphBuilder subgraph(3); + uint32_t input_id = 0; + uint32_t index_id = 1; + uint32_t output_id = 2; + + subgraph.AddInput(type_of(), input_shape, input_id) + .AddInput(type_of(), index_shape, index_id) + .AddOutput(type_of(), expected_output_shape, output_id) + .AddGather(axis, input_id, index_id, output_id); + + Runtime runtime(subgraph.GetSubgraph()); + ASSERT_EQ(runtime.Status(), ynn_status_success); + + runtime.ReshapeExternalTensor(input_shape, input_data.data(), input_id); + runtime.ReshapeExternalTensor(index_shape, index_data.data(), index_id); + runtime.ReshapeRuntime(); + + ASSERT_EQ(runtime.GetExternalTensorShape(output_id), expected_output_shape); + + std::vector output_data(expected_output_data.size()); + runtime.SetupExternalTensor(output_data.data(), output_id).InvokeRuntime(); + + EXPECT_THAT(output_data, testing::ElementsAreArray(expected_output_data)); +} + +template +class GatherTest : public ::testing::Test { + protected: + using InputType = typename std::tuple_element<0, T>::type; + using IndexType = typename std::tuple_element<1, T>::type; +}; + +using GatherTestTypes = + ::testing::Types, std::tuple, + std::tuple, std::tuple, + std::tuple, std::tuple, + std::tuple, std::tuple, + std::tuple, std::tuple, + std::tuple, std::tuple, + std::tuple, + std::tuple, + std::tuple >; +TYPED_TEST_SUITE(GatherTest, GatherTestTypes); + +TYPED_TEST(GatherTest, Index0D) { + using T = typename TestFixture::InputType; + using IndexType = typename TestFixture::IndexType; + + // 4. 1D input, 0D index + TestGather( + /*axis=*/0, + /*input_shape=*/{3}, /*input_data=*/{1, 2, 3}, + /*index_shape=*/{}, /*index_data=*/{1}, + /*expected_output_shape=*/{}, /*expected_output_data=*/{2}); +} + +TYPED_TEST(GatherTest, Index1D) { + using T = typename TestFixture::InputType; + using IndexType = typename TestFixture::IndexType; + + // 5. 1D input, 1D index + TestGather( + /*axis=*/0, + /*input_shape=*/{3}, /*input_data=*/{1, 2, 3}, + /*index_shape=*/{2}, /*index_data=*/{2, 0}, + /*expected_output_shape=*/{2}, /*expected_output_data=*/{3, 1}); +} + +TYPED_TEST(GatherTest, Index2D) { + using T = typename TestFixture::InputType; + using IndexType = typename TestFixture::IndexType; + + // 6. 1D input, 2D index + TestGather( + /*axis=*/0, + /*input_shape=*/{3}, /*input_data=*/{1, 2, 3}, + /*index_shape=*/{2, 3}, /*index_data=*/{2, 0, 1, 1, 2, 0}, + /*expected_output_shape=*/{2, 3}, + /*expected_output_data=*/{3, 1, 2, 2, 3, 1}); +} + +TYPED_TEST(GatherTest, Input2DIndex2D) { + using T = typename TestFixture::InputType; + using IndexType = typename TestFixture::IndexType; + + // 9. 2D input, 2D index + TestGather( + /*axis=*/0, + /*input_shape=*/{3, 3}, /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9}, + /*index_shape=*/{2, 3}, /*index_data=*/{1, 0, 2, 1, 2, 0}, + /*expected_output_shape=*/{2, 3}, + /*expected_output_data=*/{4, 2, 9, 4, 8, 3}); +} + +TYPED_TEST(GatherTest, IndexBroadcasting) { + using T = typename TestFixture::InputType; + using IndexType = typename TestFixture::IndexType; + + // Index broadcasting: index shape {2, 1} broadcasted to match input shape {2, + // 3} (excluding axis 0) + TestGather( + /*axis=*/0, + /*input_shape=*/{3, 3}, /*input_data=*/{1, 2, 3, 4, 5, 6, 7, 8, 9}, + /*index_shape=*/{2, 1}, /*index_data=*/{2, 0}, + /*expected_output_shape=*/{2, 3}, + /*expected_output_data=*/{7, 8, 9, 1, 2, 3}); +} + +TYPED_TEST(GatherTest, InputBroadcasting) { + using T = typename TestFixture::InputType; + using IndexType = typename TestFixture::IndexType; + + // Input broadcasting: input shape {2, 1} broadcasted to match index shape {2, + // 3} (excluding axis 0) + TestGather( + /*axis=*/0, + /*input_shape=*/{3, 1}, /*input_data=*/{1, 2, 3}, + /*index_shape=*/{2, 3}, /*index_data=*/{2, 0, 1, 0, 2, 0}, + /*expected_output_shape=*/{2, 3}, + /*expected_output_data=*/{3, 1, 2, 1, 3, 1}); +} + +} // namespace +} // namespace ynn diff --git a/ynnpack/subgraph/test/lut.cc b/ynnpack/subgraph/test/lut.cc deleted file mode 100644 index 86660d867b6..00000000000 --- a/ynnpack/subgraph/test/lut.cc +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright 2025 Google LLC -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include "ynnpack/base/test/fuzz_test.h" -#include "ynnpack/base/test/random.h" -#include "ynnpack/base/test/tensor.h" -#include "ynnpack/base/type.h" -#include "ynnpack/include/ynnpack.h" -#include "ynnpack/subgraph/test/subgraph_builder.h" - -namespace ynn { -namespace { - -template -void TestLut() { - ReplicableRandomDevice rng; - std::uniform_int_distribution rank_dist(1, YNN_MAX_TENSOR_RANK); - for (auto _ : FuzzTest(std::chrono::milliseconds(100))) { - size_t rank = rank_dist(rng); - // We want the total number of elements to be reasonable, so choose max_dim - // such that a random shape of rank `p.rank` produces this max size. - constexpr size_t max_size = 1024; - const size_t max_dim = static_cast(std::ceil( - std::pow(static_cast(max_size), - 1.0 / static_cast(std::max(1, rank))))); - std::vector lut_data(256); - std::uniform_int_distribution lut_dist(std::numeric_limits::min(), - std::numeric_limits::max()); - std::generate(lut_data.begin(), lut_data.end(), - [&]() { return static_cast(lut_dist(rng)); }); - - SubgraphBuilder subgraph(2); - uint32_t input_id = 0; - uint32_t output_id = 1; - uint32_t lut_id = YNN_INVALID_VALUE_ID; - - std::vector input_shape = random_shape(rng, rank, 0, max_dim); - subgraph.AddInput(type_of(), input_shape, input_id) - .AddOutput(type_of(), rank, output_id) - .AddTensor(type_of(), {256}, lut_id, lut_data.data()); - - ASSERT_EQ( - ynn_define_lut(subgraph.GetSubgraph(), input_id, lut_id, &output_id, 0), - ynn_status_success); - - Runtime runtime(subgraph.GetSubgraph()); - ASSERT_EQ(runtime.Status(), ynn_status_success); - - for (int reshape = 0; reshape < 2; ++reshape) { - std::vector shape = random_shape(rng, input_shape, 1, max_dim); - - Tensor a(shape); - Tensor output(shape); - - std::uniform_int_distribution a_dist(std::numeric_limits::min(), - std::numeric_limits::max()); - std::generate(a.data(), a.data() + a.size(), - [&]() { return static_cast(a_dist(rng)); }); - - runtime.ReshapeExternalTensor(shape, a.data(), input_id).ReshapeRuntime(); - ASSERT_EQ(runtime.GetExternalTensorShape(output_id), shape); - runtime.SetupExternalTensor(output.data(), output_id).InvokeRuntime(); - - for (size_t i = 0; i < a.size(); ++i) { - size_t index = static_cast(a.data()[i]); - ASSERT_EQ(output.data()[i], lut_data[index]); - } - } - } -} - -TEST(LutTest, LutUint8) { TestLut(); } - -TEST(LutTest, LutInt8) { TestLut(); } - -} // namespace -} // namespace ynn diff --git a/ynnpack/subgraph/test/matchers.h b/ynnpack/subgraph/test/matchers.h index ef48b2bede0..40eee2219be 100644 --- a/ynnpack/subgraph/test/matchers.h +++ b/ynnpack/subgraph/test/matchers.h @@ -113,8 +113,14 @@ MATCHER_P(HasInputCount, count, "") { // Checks that the given node is a LUT. // // Example: -// EXPECT_THAT(ProducerOf(y_id, subgraph), IsLut()); -MATCHER(IsLut, "") { return std::holds_alternative(arg.op); } +// EXPECT_THAT(ProducerOf(y_id, subgraph), IsLut(subgraph)); +MATCHER_P(IsLut, subgraph, "") { + const auto* g = std::get_if(&arg.op); + if (!g) return false; + if (g->axis != 0) return false; + const ynn_value& table = subgraph.value(arg.inputs[0]); + return table.rank() == 1; +} // Checks that the given node is a binary elementwise with the given operator. // diff --git a/ynnpack/subgraph/test/subgraph_builder.cc b/ynnpack/subgraph/test/subgraph_builder.cc index 118d429d558..7e7b7825388 100644 --- a/ynnpack/subgraph/test/subgraph_builder.cc +++ b/ynnpack/subgraph/test/subgraph_builder.cc @@ -178,6 +178,16 @@ SubgraphBuilder& SubgraphBuilder::AddCopy(uint32_t input_id, uint32_t output_id, return *this; } +SubgraphBuilder& SubgraphBuilder::AddGather(int32_t axis, uint32_t input_id, + uint32_t index_id, + uint32_t output_id, + uint32_t flags) { + assert(status_ == ynn_status_success); + status_ = ynn_define_gather(subgraph_.get(), axis, input_id, index_id, + &output_id, flags); + return *this; +} + SubgraphBuilder& SubgraphBuilder::AddFuseDim(int32_t first_dim, size_t num_dims, uint32_t input_id, uint32_t output_id) { diff --git a/ynnpack/subgraph/test/subgraph_builder.h b/ynnpack/subgraph/test/subgraph_builder.h index f90cef50e02..13ccbb1f2fa 100644 --- a/ynnpack/subgraph/test/subgraph_builder.h +++ b/ynnpack/subgraph/test/subgraph_builder.h @@ -117,6 +117,8 @@ class SubgraphBuilder { uint32_t flags = 0); SubgraphBuilder& AddCopy(uint32_t input_id, uint32_t output_id, uint32_t flags = 0); + SubgraphBuilder& AddGather(int32_t axis, uint32_t input_id, uint32_t index_id, + uint32_t output_id, uint32_t flags = 0); SubgraphBuilder& AddFuseDim(int32_t first_dim, size_t num_dims, uint32_t input_id, uint32_t output_id); diff --git a/ynnpack/subgraph/utils.cc b/ynnpack/subgraph/utils.cc index 7384f5f2c33..d8a3926ac30 100644 --- a/ynnpack/subgraph/utils.cc +++ b/ynnpack/subgraph/utils.cc @@ -234,4 +234,15 @@ bool allow_in_place(uint32_t input_id, uint32_t output_id, return true; } +int compute_allow_in_place(const ynn_node& node, const ynn_subgraph& subgraph) { + assert(node.outputs.size() == 1); + int result = 0; + for (int i = 0; i < node.inputs.size(); ++i) { + if (allow_in_place(node.inputs[i], node.outputs[0], subgraph)) { + result |= 1 << i; + } + } + return result; +} + } // namespace ynn diff --git a/ynnpack/subgraph/utils.h b/ynnpack/subgraph/utils.h index 752128e24ab..ecc0e16cc18 100644 --- a/ynnpack/subgraph/utils.h +++ b/ynnpack/subgraph/utils.h @@ -7,7 +7,6 @@ #define XNNPACK_YNNPACK_SUBGRAPH_UTILS_H_ #include -#include #include "ynnpack/subgraph/subgraph.h" @@ -17,6 +16,9 @@ namespace ynn { bool allow_in_place(uint32_t input_id, uint32_t output_id, const ynn_subgraph& subgraph); +// Computes the slinky::call_stmt::attributes::allow_in_place mask for a node. +int compute_allow_in_place(const ynn_node& node, const ynn_subgraph& subgraph); + // Clone a subset of the subgraph that is required to compute `output_id` from // `input_id`. //