Submission 2025-05-15

Neon Batch-Reduce GEMM

This section considers a batch-reduce matrix-matrix multiplication that has a fourth dimension in addition to the known M, N, and K dimensions.

1. Implement a batch-reduce kernel matmul_64_48_64_16

File: neon_6_1.s

We started by implementing a kernel matmul_64_48_64 with a batch dimension of one which is in the file neon_6_1_batch1.s.

 1...
 2    mov x17, #12 // x17 iterator for N loop
 3matmul_loop_over_N:
 4    sub x17, x17, #1
 5
 6    ...
 7
 8    mov x16, #4 // x16 iterator for M loop
 9matmul_loop_over_M:
10    sub x16, x16, #1
11
12    ...
13
14    mov x15, #64 // x15 iterator for K loop
15matmul_loop_over_K:
16    sub x15, x15, #1
17
18    ... matmul_16_4_1 kernel ...
19
20    // Loop back to K
21    cbnz x15, matmul_loop_over_K
22
23    ...
24
25    // Loop back to M
26    cbnz x16, matmul_loop_over_M
27
28    ...
29
30    // Loop back to N
31    cbnz x17, matmul_loop_over_N

Then we wrapped the matmul_64_48_64 kernel inside another batch loop of size 16:

 1...
 2    mov x19, #16 // x19 iterator for the batch dimension
 3matmul_loop_batch_dimension:
 4    sub x19, x19, #1
 5
 6    ...
 7
 8    mov x17, #12 // x17 iterator for N loop
 9matmul_loop_over_N:
10    sub x17, x17, #1
11
12    ...
13
14    mov x16, #4 // x16 iterator for M loop
15matmul_loop_over_M:
16    sub x16, x16, #1
17
18    ...
19
20    mov x15, #64 // x15 iterator for K loop
21matmul_loop_over_K:
22    sub x15, x15, #1
23
24    ...
25
26    // Loop back to K
27    cbnz x15, matmul_loop_over_K
28
29    ... matmul_16_4_1 kernel ...
30
31    // Loop back to M
32    cbnz x16, matmul_loop_over_M
33
34    ...
35
36    // Loop back to N
37    cbnz x17, matmul_loop_over_N
38
39    ...
40
41    // Loop back to batch dimension
42    cbnz x19, matmul_loop_batch_dimension

2. Test and optimize

We tested a variation in which the batch loop was positioned between the M and K loops. This approach achieved around \(73\) GFLOPS. We suspect that the reason for this was that the matrices did not fit into the cache. We do not follow this approach due to the poor performance, and we lost the file due to a false rm statement.

However, this leads us to assume that our result of putting the batch loop outside is satisfactory.

-----------------------------------------------------------------------------------------------------------------------------------------------
Benchmark                                                                                          Time             CPU   Iterations      FLOPS
-----------------------------------------------------------------------------------------------------------------------------------------------
GemmMxNxKxBatchFixture<64, 48, 64, 1>/BM_matmul_64_48_64/min_warmup_time:1.000_mean             3104 ns         3093 ns           10 127.138G/s
GemmMxNxKxBatchFixture<64, 48, 64, 1>/BM_matmul_64_48_64/min_warmup_time:1.000_median           3102 ns         3092 ns           10  127.19G/s
GemmMxNxKxBatchFixture<64, 48, 64, 1>/BM_matmul_64_48_64/min_warmup_time:1.000_stddev           10.1 ns         8.08 ns           10 331.319M/s
GemmMxNxKxBatchFixture<64, 48, 64, 1>/BM_matmul_64_48_64/min_warmup_time:1.000_cv               0.33 %          0.26 %            10      0.26%
GemmMxNxKxBatchFixture<64, 48, 64, 16>/BM_matmul_64_48_64_16/min_warmup_time:1.000_mean        51072 ns        50890 ns           10 123.628G/s
GemmMxNxKxBatchFixture<64, 48, 64, 16>/BM_matmul_64_48_64_16/min_warmup_time:1.000_median      51027 ns        50840 ns           10 123.749G/s
GemmMxNxKxBatchFixture<64, 48, 64, 16>/BM_matmul_64_48_64_16/min_warmup_time:1.000_stddev        120 ns          119 ns           10 287.993M/s
GemmMxNxKxBatchFixture<64, 48, 64, 16>/BM_matmul_64_48_64_16/min_warmup_time:1.000_cv           0.24 %          0.23 %            10      0.23%
  • matmul_64_48_64 kernel: \(127.1\) GFLOPS

  • matmul_64_48_64_16 kernel: \(123.6\) GFLOPS

GEMM

1. Extend generate to support M-N-K combinations for column-major format \(1 \leq M,N \leq 1024, 1 \leq K \leq 2028\)

To support all combinations of M, N and K, we use one kernel as a base and dynamically generate the rest of the handling for numbers that are not multiples of M, N or K. As a base we took the matmul_16m_4n_k kernel, which reached around 130 GFLOPS as 64_48_64 kernel (i.e. the same as the kernel from the previous section, with a batch dimension of one.). The k dimension is always a multiple of 1 therefore we don’t need a special case for this dimension. To get full coverage on the remaining dimension, we implemented the following variations:

  • matmul_16m_lt4nRest_k:
    • M dimension must be multiple of 16

    • N dimension can be less than 4 or larger, multiple of 4 are processed at once, N mod 4 are processed at the end at once

  • matmul_16mRest_4n_k:
    • M dimension can be larger than 16, multiple of 16 are processed at once, M mod 16 are processed at the end at once

    • N dimension must be multiple of 4

  • matmul_16mRest_lt4nRest_k:
    • M dimension can be larger than 16, multiple of 16 are processed at once, M mod 16 are processed at the end at once

    • N dimension can be less than 4 or larger, multiple of 4 are processed at once, N mod 4 are processed at the end at once

  • matmul_lt16_4n_k:
    • M dimension must be less than 16

    • N dimension must be multiple of 4

  • matmul_lt16_lt4nRest_k:
    • M dimension must be less than 16

    • N dimension can be less than 4 or larger, multiple of 4 are processed at once, N mod 4 are processed at the end at once

Together with the matmul_16m_4n_k, we have 6 kernels to cover the complete dimension space.

../_images/matmul_coverage_light.svg ../_images/matmul_coverage_dark.svg

2. Verify all matrices for 1≤M≤64, 1≤N≤64, K∈[1,16,32,64,128],``lda=M``, ldb=K, and ldc=M

All GEMM generation and execution using this configuration works with counting upwards and random data.

3. Verify all matrices for 1≤M≤64, 1≤N≤64, K∈[1,16,32,64,128],``lda>M``, ldb>K, and ldc>M

All GEMM generation and execution using this configuration works with counting upwards and random data.

4. Benchmark for 1≤M≤64, 1≤N≤64, K∈[1,16,32,64,128],``lda=M``, ldb=K, and ldc=M.

The benchmark took approximately eight hours in total to run. The following results were produced: GEMM_benchmarks.csv

Batch-Reduce GEMM

1. Extend generate to support batch dimension 1≤batch_size≤1024

In order to support an additional batch dimension in our implemented kernels, we placed all kernels within an additional batch loop. Consequently, the logic in our Brgemm.cpp was extended to check whether the batch dimension is greater than one.

 1...
 2if (dtype != dtype_t::fp32)
 3{
 4  return error_t::err_wrong_dtype;
 5}
 6if (m == 0 || n == 0 || k == 0)
 7{
 8  return error_t::err_wrong_dimension;
 9}
10if ((trans_a + trans_b + trans_c) != 0)
11{
12  return error_t::err_row_major_order_not_supported;
13}
14
15if (br_size == 1 && (trans_a + trans_b + trans_c) == 0 && dtype == dtype_t::fp32)
16{
17  fill_with_matmuls_no_batch_dim_column_major_fp32(m, n, k);
18}
19else if (br_size > 1 && (trans_a + trans_b + trans_c) == 0 && dtype == dtype_t::fp32)
20{
21  fill_with_matmuls_batch_dim_column_major_fp32(m, n, k, br_size);
22}
23else
24{
25  throw std::logic_error(
26    std::format("Unhandled parameter combination found: m='{}', n='{}', k='{}', br_size='{}', trans_a='{}', trans_b='{}', "
27                "trans_c = '{}', dtype = '{}'",
28                m, n, k, br_size, trans_a, trans_b, trans_c, static_cast<int32_t>(dtype)));
29}
30...

This else if branch distributes to our extended br_matmul_* kernels with a larger batch dimension.

  • br_matmul_16m_lt4nRest_k

  • br_matmul_16mRest_4n_k

  • br_matmul_16mRest_lt4nRest_k

  • br_matmul_lt16_4n_k

  • br_matmul_lt16_lt4nRest_k

2. Verify against reference implementation

All kernels were tested. The tests are located in the file src/test/kernels/br_matmul_*.test.cpp.

The batched MatMul generation was tested for 1≤M≤64, 1≤N≤64, K∈[1,16,32,64,128], 1≤BatchSize≤16, lda=M, ldb=K, and ldc=M. The test is located in the file src/test/Brgemm.test.cpp.

3. Benchmark for 1≤M≤64, 1≤N≤64, K∈[1,16,32,64,128],lda=M, ldb=K,ldc=M , batch_size=16

The benchmark took approximately eight hours in total to run. The following results were produced: GEMM_benchmarks.csv