Submission 2025-05-08
SIMD Lanes
This section considers matrix-matrix multiplications that require instructions in which only a subset of the SIMD lanes is active.
1. Implement a kernel for M=14, N=6 and K=64 and wrap it in the matmul_14_6_64 function
File: neon_4_1.s
For this kernel matmul_14_6_64 we adapt the already implemented kernel matmul_16_6_64. The main change is that each column of C is now updated with three fmla instructions operating on 4 lanes and one fmla instruction operating on only 2 lanes: \(3 \cdot 4 + 1 \cdot 2 = 14\).
We load the full 16 floats and ignore the last 2:
...
// Load first column from the 14x6 matrix c - load full 16 entries - ignore last 2
ld1 {v25.4s, v26.4s, v27.4s, v28.4s}, [x2], x5
// Load second column from the 14x6 matrix c
ld1 {v17.4s, v18.4s, v19.4s, v20.4s}, [x2], x5
// Load third column from the 14x6 matrix c
ld1 {v21.4s, v22.4s, v23.4s, v24.4s}, [x2], x5
// Load fourth column from the 14x6 matrix c
ld1 {v5.4s, v6.4s, v7.4s, v8.4s}, [x2], x5
// Load fifth column from the 14x6 matrix c
ld1 {v9.4s, v10.4s, v11.4s, v12.4s}, [x2], x5
// Load sixth column from the 14x6 matrix c
ld1 {v13.4s, v14.4s, v15.4s, v16.4s}, [x2], x5
...
Next comes the loop over K:
...
    mov x9, #64 // x9 iterator for K loop
matmul_loop_over_K:
    sub x9, x9, #1

    // Load first column data from the 14x1 matrix a (again 16 floats, but only two lanes of v3 are used)
    ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], x3

    // run the known matmul_16_6_1_unrolled kernel with modification to matmul_14_6_1
    // Load first element from the 1x6 matrix b
    ldr s4, [x1]
    add x1, x1, x4

    // Calculate first column of c
    fmla v25.4s, v0.4s, v4.s[0] // 4 floats
    fmla v26.4s, v1.4s, v4.s[0] // 4 floats
    fmla v27.4s, v2.4s, v4.s[0] // 4 floats
    fmla v28.2s, v3.2s, v4.s[0] // 2 floats

    // Load second element from the 1x6 matrix b
    ldr s4, [x1]
    add x1, x1, x4

    // Calculate second column of c
    fmla v17.4s, v0.4s, v4.s[0]
    fmla v18.4s, v1.4s, v4.s[0]
    fmla v19.4s, v2.4s, v4.s[0]
    fmla v20.2s, v3.2s, v4.s[0]
...
We store the full 16 computed floats for each column but advance the pointer by only 14 floats, since the last two lanes are unused. Storing 16 floats with a 14-float stride temporarily overwrites the first two entries of the next column, but the following store rewrites them with the correct values; only the last column has no successor, so it is stored with exactly 14 floats (8 + 4 + 2). A small simulation of this store pattern follows the listing.
...
// Store first column back to memory
st1 {v25.4s, v26.4s, v27.4s, v28.4s}, [x2], x5 // offset of 14 floats
// Store second column back to memory
st1 {v17.4s, v18.4s, v19.4s, v20.4s}, [x2], x5 // offset of 14 floats
// Store third column back to memory
st1 {v21.4s, v22.4s, v23.4s, v24.4s}, [x2], x5 // offset of 14 floats
// Store fourth column back to memory
st1 {v5.4s, v6.4s, v7.4s, v8.4s}, [x2], x5 // offset of 14 floats
// Store fifth column back to memory
st1 {v9.4s, v10.4s, v11.4s, v12.4s}, [x2], x5 // offset of 14 floats
// Store sixth column back to memory (exactly the last 14 elements)
stp q13, q14, [x2] // 8 floats
str q15, [x2, #32] // 4 floats
str d16, [x2, #48] // 2 floats
...
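To make the overlap argument concrete, here is a small stand-alone C++ simulation of the store pattern (purely illustrative, not part of the kernel): columns 0-4 are written 16 floats wide with a 14-float stride, the last column exactly 14 floats wide, and the final memory image is checked against the expected 14x6 block.

#include <array>
#include <cstdio>

int main()
{
    constexpr int m = 14, n = 6;
    std::array<float, m * n> c{}; // final memory image of the 14x6 block
    auto value = [](int row, int col) { return float(row + 100 * col); };

    for (int col = 0; col < n; ++col)
    {
        const int width = (col == n - 1) ? 14 : 16; // last column: exact store
        for (int i = 0; i < width; ++i)
        {
            const int idx = col * m + i;
            // lanes 14 and 15 spill into the first two entries of column col+1,
            // which the store of column col+1 overwrites with the correct values
            if (idx < m * n) c[idx] = value(i, col);
        }
    }

    bool ok = true;
    for (int col = 0; col < n; ++col)
        for (int row = 0; row < m; ++row)
            ok = ok && (c[col * m + row] == value(row, col));
    std::printf("store pattern %s\n", ok ? "consistent" : "broken");
    return ok ? 0 : 1;
}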
2. Implement a kernel for M=15, N=6 and K=64 and wrap it in the matmul_15_6_64 function
File: neon_4_2.s
For this kernel matmul_15_6_64 we adapt the already implemented kernel matmul_16_6_64. All four fmla instructions still operate on 4 lanes; the only change is that the last computed value is ignored when the results are written back to memory.
We load the full 16 floats and ignore the last one:
...
// Load first column from the 15x6 matrix c - load full 16 entries - ignore last
ld1 {v25.4s, v26.4s, v27.4s, v28.4s}, [x2], x5
// Load second column from the 15x6 matrix c
ld1 {v17.4s, v18.4s, v19.4s, v20.4s}, [x2], x5
// Load third column from the 15x6 matrix c
ld1 {v21.4s, v22.4s, v23.4s, v24.4s}, [x2], x5
// Load fourth column from the 15x6 matrix c
ld1 {v5.4s, v6.4s, v7.4s, v8.4s}, [x2], x5
// Load fifth column from the 15x6 matrix c
ld1 {v9.4s, v10.4s, v11.4s, v12.4s}, [x2], x5
// Load sixth column from the 15x6 matrix c
ld1 {v13.4s, v14.4s, v15.4s, v16.4s}, [x2], x5
...
Next comes the loop over K:
...
    mov x9, #64 // x9 iterator for K loop
matmul_loop_over_K:
    sub x9, x9, #1

    // Load first column data from the 15x1 matrix a
    ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], x3
    // ldp q0, q1, [x0]  // 4 + 4 values
    // ldr q2, [x0, #32] // 4 values
    // ldr d3, [x0, #48] // 2 values

    // run the known matmul_16_6_1_unrolled kernel with modification to matmul_15_6_1
    // Load first element from the 1x6 matrix b
    ldr s4, [x1]
    add x1, x1, x4

    // Calculate first column of c
    fmla v25.4s, v0.4s, v4.s[0]
    fmla v26.4s, v1.4s, v4.s[0]
    fmla v27.4s, v2.4s, v4.s[0]
    fmla v28.4s, v3.4s, v4.s[0]

    // Load second element from the 1x6 matrix b
    ldr s4, [x1]
    add x1, x1, x4

    // Calculate second column of c
    fmla v17.4s, v0.4s, v4.s[0]
    fmla v18.4s, v1.4s, v4.s[0]
    fmla v19.4s, v2.4s, v4.s[0]
    fmla v20.4s, v3.4s, v4.s[0]
...
Again we store the full 16 computed floats per column but advance the pointer by only 15 floats, since the last lane is unused; the overwritten first entry of the next column is corrected by the following store. The last column is stored with exactly 15 values (8 + 4 + 2 + 1).
...
// Store first column back to memory
st1 {v25.4s, v26.4s, v27.4s, v28.4s}, [x2], x5 // offset of 15 floats
// Store second column back to memory
st1 {v17.4s, v18.4s, v19.4s, v20.4s}, [x2], x5 // offset of 15 floats
// Store third column back to memory
st1 {v21.4s, v22.4s, v23.4s, v24.4s}, [x2], x5 // offset of 15 floats
// Store fourth column back to memory
st1 {v5.4s, v6.4s, v7.4s, v8.4s}, [x2], x5 // offset of 15 floats
// Store fifth column back to memory
st1 {v9.4s, v10.4s, v11.4s, v12.4s}, [x2], x5 // offset of 15 floats
// Store sixth column back to memory (exactly the last 15 elements)
stp q13, q14, [x2] // 8 floats
str q15, [x2, #32] // 4 floats
str d16, [x2, #48] // 2 floats
mov w9, v16.s[2] // get the 15th element (lane 2 of v16)
str w9, [x2, #56] // 1 float
...
3. Test and optimize the kernels. Report your performance in GFLOPS
Optimized benchmark results:
--------------------------------------------------------------------------------------------------------------------------------------------
Benchmark Time CPU Iterations FLOPS
--------------------------------------------------------------------------------------------------------------------------------------------
GemmMxNxKFixture<14, 6, 64>/BM_matmul_14_6_64/min_warmup_time:1.000_mean 94.8 ns 94.5 ns 10 113.789G/s
GemmMxNxKFixture<14, 6, 64>/BM_matmul_14_6_64/min_warmup_time:1.000_median 94.8 ns 94.5 ns 10 113.775G/s
GemmMxNxKFixture<14, 6, 64>/BM_matmul_14_6_64/min_warmup_time:1.000_stddev 0.671 ns 0.659 ns 10 790.609M/s
GemmMxNxKFixture<14, 6, 64>/BM_matmul_14_6_64/min_warmup_time:1.000_cv 0.71 % 0.70 % 10 0.69%
GemmMxNxKFixture<15, 6, 64>/BM_matmul_15_6_64/min_warmup_time:1.000_mean 95.5 ns 95.1 ns 10 121.074G/s
GemmMxNxKFixture<15, 6, 64>/BM_matmul_15_6_64/min_warmup_time:1.000_median 95.4 ns 95.1 ns 10 121.09G/s
GemmMxNxKFixture<15, 6, 64>/BM_matmul_15_6_64/min_warmup_time:1.000_stddev 0.295 ns 0.293 ns 10 373.529M/s
GemmMxNxKFixture<15, 6, 64>/BM_matmul_15_6_64/min_warmup_time:1.000_cv 0.31 % 0.31 % 10 0.31%
matmul_14_6_64 kernel: \(113.8\) GFLOPS
matmul_15_6_64 kernel: \(121.1\) GFLOPS
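As a sanity check on these numbers, each kernel performs \(2 \cdot M \cdot N \cdot K\) floating-point operations per call: \(2 \cdot 14 \cdot 6 \cdot 64 = 10752\) FLOPs in roughly \(94.5\) ns gives \(\approx 113.8\) GFLOPS, and \(2 \cdot 15 \cdot 6 \cdot 64 = 11520\) FLOPs in roughly \(95.1\) ns gives \(\approx 121.1\) GFLOPS, matching the rates reported by the benchmark.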
Accumulator Block Shapes
This section considers a matrix-matrix multiplication where a high-performance implementation may require accumulator blocks with different shapes.
1. Implement a kernel for M=64, N=64 and K=64 and wrap it in the matmul_64_64_64 function
File: neon_5_1.s
For this kernel matmul_64_64_64 we adapt the already implemented kernel matmul_64_48_64. The first change is that we removed two fmla blocks from the inner loop, shrinking the accumulator block from 16x6 to 16x4:
...
    mov x15, #64 // x15 iterator for K loop
matmul_loop_over_K:
    sub x15, x15, #1

    // Load first column data from the 16x1 matrix a
    ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], x3

    // run the matmul_16_4_1_unrolled kernel
    // Load first element from the 1x4 matrix b
    ldr s4, [x1]
    add x1, x1, x4

    // Calculate first column of c
    fmla v25.4s, v0.4s, v4.s[0]
    fmla v26.4s, v1.4s, v4.s[0]
    fmla v27.4s, v2.4s, v4.s[0]
    fmla v28.4s, v3.4s, v4.s[0]


    // Load second element from the 1x4 matrix b
    ldr s4, [x1]
    add x1, x1, x4

    // Calculate second column of c
    fmla v17.4s, v0.4s, v4.s[0]
    fmla v18.4s, v1.4s, v4.s[0]
    fmla v19.4s, v2.4s, v4.s[0]
    fmla v20.4s, v3.4s, v4.s[0]


    // Load third element from the 1x4 matrix b
    ldr s4, [x1]
    add x1, x1, x4

    // Calculate third column of c
    fmla v21.4s, v0.4s, v4.s[0]
    fmla v22.4s, v1.4s, v4.s[0]
    fmla v23.4s, v2.4s, v4.s[0]
    fmla v24.4s, v3.4s, v4.s[0]


    // Load fourth element from the 1x4 matrix b
    ldr s4, [x1]
    add x1, x1, x4

    // Calculate fourth column of c
    fmla v5.4s, v0.4s, v4.s[0]
    fmla v6.4s, v1.4s, v4.s[0]
    fmla v7.4s, v2.4s, v4.s[0]
    fmla v8.4s, v3.4s, v4.s[0]


    // offset x6 to the next element in the column
    add x6, x6, #4 // #4 = sizeof(float)

    // Restore x1 to be incremented again
    mov x1, x6

    // Loop back to K
    cbnz x15, matmul_loop_over_K
...
Then we changed the number of iterations of the M loop to four (\(4 \cdot 16 = 64\)):
...
    mov x16, #4 // x16 iterator for M loop
matmul_loop_over_M:
    sub x16, x16, #1

    // Load first column of the current 16x4 block of c
    ld1 {v25.4s, v26.4s, v27.4s, v28.4s}, [x2], x5
    // Load second column of the current 16x4 block of c
    ld1 {v17.4s, v18.4s, v19.4s, v20.4s}, [x2], x5
    // Load third column of the current 16x4 block of c
    ld1 {v21.4s, v22.4s, v23.4s, v24.4s}, [x2], x5
    // Load fourth column of the current 16x4 block of c
    ld1 {v5.4s, v6.4s, v7.4s, v8.4s}, [x2], x5

    mov x15, #64 // x15 iterator for K loop
matmul_loop_over_K:
    sub x15, x15, #1
...
Finally, we changed the number of iterations of the N loop to 16 (\(16 \cdot 4 = 64\)):
...
    mov x17, #16 // x17 iterator for N loop
matmul_loop_over_N:
    sub x17, x17, #1

    mov x16, #4 // x16 iterator for M loop
matmul_loop_over_M:
    sub x16, x16, #1
...
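The 16x4 accumulator block also fits comfortably into the register file: the listing keeps \(4 \cdot 4 = 16\) vector registers for the block of C (v25-v28, v17-v20, v21-v24, v5-v8), four registers (v0-v3) for the current column of A and one register (v4) for the broadcast element of B, i.e. \(16 + 4 + 1 = 21\) of the 32 NEON registers.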
2. Test and optimize the kernel. Report your performance in GFLOPS
Optimized benchmark result:
--------------------------------------------------------------------------------------------------------------------------------------------
Benchmark Time CPU Iterations FLOPS
--------------------------------------------------------------------------------------------------------------------------------------------
GemmMxNxKFixture<64, 64, 64>/BM_matmul_64_64_64/min_warmup_time:1.000_mean 4111 ns 4097 ns 10 127.964G/s
GemmMxNxKFixture<64, 64, 64>/BM_matmul_64_64_64/min_warmup_time:1.000_median 4110 ns 4096 ns 10 127.988G/s
GemmMxNxKFixture<64, 64, 64>/BM_matmul_64_64_64/min_warmup_time:1.000_stddev 13.7 ns 13.8 ns 10 431.794M/s
GemmMxNxKFixture<64, 64, 64>/BM_matmul_64_64_64/min_warmup_time:1.000_cv 0.33 % 0.34 % 10 0.34%
matmul_64_64_64 kernel: \(128.0\) GFLOPS
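The rate is consistent with the operation count: \(2 \cdot 64^3 = 524288\) FLOPs in roughly \(4097\) ns gives \(\approx 128\) GFLOPS. For correctness testing, the kernel can be compared against a straightforward triple-loop reference. A minimal sketch, assuming column-major storage, leading dimensions passed in elements, and the argument order used by the assembly (a, b, c, lda, ldb, ldc); the extern "C" declaration is an assumption for illustration:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

// Assumed C++ declaration of the assembly kernel.
extern "C" void matmul_64_64_64(const float *a, const float *b, float *c,
                                int64_t lda, int64_t ldb, int64_t ldc);

// Naive column-major reference: C += A * B.
static void matmul_ref(const float *a, const float *b, float *c,
                       int m, int n, int k, int lda, int ldb, int ldc)
{
    for (int col = 0; col < n; ++col)
        for (int p = 0; p < k; ++p)
            for (int row = 0; row < m; ++row)
                c[row + col * ldc] += a[row + p * lda] * b[p + col * ldb];
}

int main()
{
    const int m = 64, n = 64, k = 64;
    std::vector<float> a(m * k), b(k * n), c(m * n, 0.0f);
    int v = 0;
    for (float &x : a) x = float((v++ % 7) - 3);
    v = 0;
    for (float &x : b) x = float((v++ % 5) - 2);
    std::vector<float> c_ref = c;

    matmul_64_64_64(a.data(), b.data(), c.data(), m, k, m);        // lda = m, ldb = k, ldc = m
    matmul_ref(a.data(), b.data(), c_ref.data(), m, n, k, m, k, m);

    float max_err = 0.0f;
    for (int i = 0; i < m * n; ++i)
        max_err = std::max(max_err, std::fabs(c[i] - c_ref[i]));
    std::printf("max abs error: %g\n", max_err);
    return max_err < 1e-3f ? 0 : 1;
}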
Microkernel
1. Implement a generate function that supports only an FP32 microkernel for C+=AB with M=16, N=6 and K=1, and test for errors
Each instruction we generate gets a wrapper that follows the same structure: first, asserts check that the instruction is used correctly, catching most usage errors; then the instruction word is built by masking each field and shifting it to the starting bit of its opcode block.
constexpr uint32_t ldrImmediatePost(const uint32_t Rt, const uint32_t Rn, const int32_t imm9, const bool is64bit)
{
    release_assert(((Rt & mask5) == Rt), "Rt is only allowed to have a size of 5 bit.");
    release_assert(((Rn & mask5) == Rn), "Rn is only allowed to have a size of 5 bit.");
    release_assert(imm9 <= 255, "imm9 has a Maximum of 255");
    release_assert(imm9 >= -256, "imm9 has a Minimum of -256");

    uint32_t ldr = 0;
    ldr |= 0b1 << 31; // size bit 31
    ldr |= (is64bit & mask1) << 30;
    ldr |= 0b111000010 << 21; // opc 29 - 21
    ldr |= (imm9 & mask9) << 12;
    ldr |= 0b01 << 10; // opc 11 - 10
    ldr |= (Rn & mask5) << 5;
    ldr |= (Rt & mask5) << 0;
    return ldr;
}
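The encoders rely on a few helpers (mask1, mask5, mask9 and release_assert) that are not shown above. A minimal sketch of how they could be defined (the actual project definitions may differ):

#include <cstdint>
#include <cstdio>
#include <cstdlib>

// Bit masks for the field widths used by the encoders (assumed definitions).
constexpr uint32_t mask1 = 0x1;   // 1-bit field
constexpr uint32_t mask5 = 0x1f;  // 5-bit field, e.g. register numbers
constexpr uint32_t mask9 = 0x1ff; // 9-bit field, e.g. imm9

// Assert that stays active in release builds, since a silently mis-encoded
// instruction would be very hard to debug. Reaching the failure branch during
// constant evaluation is a compile error, which is exactly what we want.
constexpr void release_assert(const bool condition, const char *message)
{
    if (!condition)
    {
        std::fputs(message, stderr);
        std::abort();
    }
}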
Each encoder is then wrapped in overloads that take our register enum classes, so the correct variant is selected by the register type.
constexpr uint32_t ldrPost(const R32Bit Wt, const R64Bit Xn, const int32_t imm9)
{
    return internal::ldrImmediatePost(static_cast<uint32_t>(Wt), static_cast<uint32_t>(Xn), imm9, false);
}

constexpr uint32_t ldrPost(const R64Bit Xt, const R64Bit Xn, const int32_t imm9)
{
    return internal::ldrImmediatePost(static_cast<uint32_t>(Xt), static_cast<uint32_t>(Xn), imm9, true);
}
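A small usage example (assuming the register enums expose w1 and x8 with their architectural register numbers, e.g. via a using declaration). Because the wrappers are constexpr, the resulting word can even be checked at compile time; the value below is what the encoder above produces and matches the architectural encoding of the instruction:

// Encodes "ldr w1, [x8], #4": 32-bit load, post-indexed by 4 bytes.
constexpr uint32_t instr = ldrPost(w1, x8, 4);
static_assert(instr == 0xb8404501, "unexpected encoding for ldr w1, [x8], #4");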
Note
All of these functions are marked constexpr. This has the benefit that most of the instruction encoding can be done at compile time, as the generated assembly shows:
If all inputs are fixed, a wrapper compiles into 2 mov instructions:
mov  w1, #38073           // =0x94b9
movk w1, #63557, lsl #16
If one input is only known at runtime, it compiles into 3 instructions:
mov   w1, #38048          // =0x94a0
movk  w1, #63557, lsl #16
bfxil x1, x8, #0, #5
Thus the code generation itself is sped up, as most instruction words are already known at compile time.
After writing wrappers for the required ARM instructions, we can translate our previously hand-written assembly kernel into C++ function calls and generate matmul_16_6_1 at runtime.
void mini_jit::kernels::matmul_16_6_1(mini_jit::Kernel &kernel)
{
    using namespace mini_jit::arm_instructions;

    kernel.add({
        // Offset the used leading dimension by the size of floats
        lsl(x3, x3, 2), // lsl x3, x3, #2
        lsl(x4, x4, 2), // lsl x4, x4, #2
        lsl(x5, x5, 2), // lsl x5, x5, #2

        // Load all data from the 16x1 matrix a
        ld1(v0, t4s, v1, t4s, v2, t4s, v3, t4s, x0) // ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0]
    });

    for (int i = 0; i < 2; i++)
    {
        kernel.add({
            // Load first element from the 1x6 matrix b
            ldr(s4, x1),     // ldr s4, [x1]
            add(x1, x1, x4), // add x1, x1, x4
            // Load first column from the 16x6 matrix c
            ld1(v25, t4s, v26, t4s, v27, t4s, v28, t4s, x2), // ld1 {v25.4s, v26.4s, v27.4s, v28.4s}, [x2]

            // Calculate first column of c
            fmla(v25, t4s, v0, t4s, v4, 0), // fmla v25.4s, v0.4s, v4.s[0]
            fmla(v26, t4s, v1, t4s, v4, 0), // fmla v26.4s, v1.4s, v4.s[0]
            fmla(v27, t4s, v2, t4s, v4, 0), // fmla v27.4s, v2.4s, v4.s[0]
            fmla(v28, t4s, v3, t4s, v4, 0), // fmla v28.4s, v3.4s, v4.s[0]

            // Store first column back to memory
            st1Post(v25, t4s, v26, t4s, v27, t4s, v28, t4s, x2, x5), // st1 {v25.4s, v26.4s, v27.4s, v28.4s}, [x2], x5

            // Load second element from the 1x6 matrix b
            ldr(s4, x1),     // ldr s4, [x1]
            add(x1, x1, x4), // add x1, x1, x4
            // Load second column from the 16x6 matrix c
            ld1(v17, t4s, v18, t4s, v19, t4s, v20, t4s, x2), // ld1 {v17.4s, v18.4s, v19.4s, v20.4s}, [x2]

            // Calculate second column of c
            fmla(v17, t4s, v0, t4s, v4, 0), // fmla v17.4s, v0.4s, v4.s[0]
            fmla(v18, t4s, v1, t4s, v4, 0), // fmla v18.4s, v1.4s, v4.s[0]
            fmla(v19, t4s, v2, t4s, v4, 0), // fmla v19.4s, v2.4s, v4.s[0]
            fmla(v20, t4s, v3, t4s, v4, 0), // fmla v20.4s, v3.4s, v4.s[0]

            // Store second column back to memory
            st1Post(v17, t4s, v18, t4s, v19, t4s, v20, t4s, x2, x5), // st1 {v17.4s, v18.4s, v19.4s, v20.4s}, [x2], x5

            // Load third element from the 1x6 matrix b
            ldr(s4, x1),     // ldr s4, [x1]
            add(x1, x1, x4), // add x1, x1, x4
            // Load third column from the 16x6 matrix c
            ld1(v21, t4s, v22, t4s, v23, t4s, v24, t4s, x2), // ld1 {v21.4s, v22.4s, v23.4s, v24.4s}, [x2]

            // Calculate third column of c
            fmla(v21, t4s, v0, t4s, v4, 0), // fmla v21.4s, v0.4s, v4.s[0]
            fmla(v22, t4s, v1, t4s, v4, 0), // fmla v22.4s, v1.4s, v4.s[0]
            fmla(v23, t4s, v2, t4s, v4, 0), // fmla v23.4s, v2.4s, v4.s[0]
            fmla(v24, t4s, v3, t4s, v4, 0), // fmla v24.4s, v3.4s, v4.s[0]

            // Store third column back to memory
            st1Post(v21, t4s, v22, t4s, v23, t4s, v24, t4s, x2, x5), // st1 {v21.4s, v22.4s, v23.4s, v24.4s}, [x2], x5
        });
    }

    kernel.add(ret()); // ret

    kernel.write("matmul_16_6_1.bin");
}
In the original assembly we used .rept 2; here we replicate that with a simple for loop.
Note
The Kernel class has two add functions: one that appends a single uint32_t and one that appends a vector<uint32_t>. The latter reduces the writing overhead of repeated kernel.add calls.
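A sketch of what these two overloads could look like (illustrative only, not the project's actual class definition):

#include <cstdint>
#include <vector>

class Kernel
{
public:
    // Append a single encoded instruction word.
    void add(const uint32_t instruction) { m_instructions.push_back(instruction); }

    // Append a whole block of encoded instructions, so that logically related
    // instructions can be passed in a single kernel.add({...}) call.
    void add(const std::vector<uint32_t> &instructions)
    {
        m_instructions.insert(m_instructions.end(), instructions.begin(), instructions.end());
    }

private:
    std::vector<uint32_t> m_instructions;
};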
2. Add support for k parameter by generating a K loop around the microkernel
Adding support for the k parameter requires a few more wrapped instructions, for example a compare-and-branch to implement the loop. With those in place, we can port our hand-written assembly kernel to C++ and JIT the K loop count.
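As an illustration of such a wrapper, here is a sketch of a cbnz encoder in the style of ldrImmediatePost above (the mask19 helper, the byte-offset signature and the exact field handling are assumptions; the project's actual wrapper may differ):

// CBNZ encoding: sf | 011010 | op=1 | imm19 | Rt, where imm19 is the
// PC-relative offset counted in instructions (bytes / 4).
constexpr uint32_t mask19 = 0x7ffff; // 19-bit field (assumed helper)

constexpr uint32_t cbnzBytes(const uint32_t Rt, const int32_t offsetBytes, const bool is64bit)
{
    release_assert(((Rt & mask5) == Rt), "Rt is only allowed to have a size of 5 bit.");
    release_assert((offsetBytes % 4) == 0, "offset must be a multiple of 4 bytes");
    const int32_t imm19 = offsetBytes / 4;
    release_assert(imm19 <= 262143 && imm19 >= -262144, "imm19 is out of range");

    uint32_t cbnz = 0;
    cbnz |= (is64bit & mask1) << 31;  // sf: 1 = 64-bit register
    cbnz |= 0b0110101u << 24;         // fixed opcode bits, op = 1 selects CBNZ
    cbnz |= (static_cast<uint32_t>(imm19) & mask19) << 5;
    cbnz |= (Rt & mask5) << 0;
    return cbnz;
}

The full ported kernel: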
void mini_jit::kernels::matmul_16_6_k(mini_jit::Kernel &kernel, const uint32_t k_loop)
{
    using namespace mini_jit::arm_instructions;

    // Procedural Call Standard
    // save frame pointer and link register
    kernel.add({

        stpPre(fp, lr, sp, -16), // stp fp, lr, [sp, #-16]!
        // update frame pointer to current stack pointer
        movSp(fp, sp), // mov fp, sp

        // save callee-saved registers
        stpPre(x19, x20, sp, -16), // stp x19, x20, [sp, #-16]!
        stpPre(x21, x22, sp, -16), // stp x21, x22, [sp, #-16]!
        stpPre(x23, x24, sp, -16), // stp x23, x24, [sp, #-16]!
        stpPre(x25, x26, sp, -16), // stp x25, x26, [sp, #-16]!
        stpPre(x27, x28, sp, -16), // stp x27, x28, [sp, #-16]!

        stpPre(d8, d9, sp, -16),   // stp d8, d9, [sp, #-16]!
        stpPre(d10, d11, sp, -16), // stp d10, d11, [sp, #-16]!
        stpPre(d12, d13, sp, -16), // stp d12, d13, [sp, #-16]!
        stpPre(d14, d15, sp, -16), // stp d14, d15, [sp, #-16]!

        // Offset the used leading dimension by the size of floats
        lsl(x3, x3, 2), // lsl x3, x3, #2
        lsl(x4, x4, 2), // lsl x4, x4, #2
        lsl(x5, x5, 2), // lsl x5, x5, #2

        mov(x6, x1), // mov x6, x1
        mov(x7, x2), // mov x7, x2

        // Load first column from the 16x6 matrix c
        ld1Post(v25, t4s, v26, t4s, v27, t4s, v28, t4s, x2, x5), // ld1 {v25.4s, v26.4s, v27.4s, v28.4s}, [x2], x5
        // Load second column from the 16x6 matrix c
        ld1Post(v17, t4s, v18, t4s, v19, t4s, v20, t4s, x2, x5), // ld1 {v17.4s, v18.4s, v19.4s, v20.4s}, [x2], x5
        // Load third column from the 16x6 matrix c
        ld1Post(v21, t4s, v22, t4s, v23, t4s, v24, t4s, x2, x5), // ld1 {v21.4s, v22.4s, v23.4s, v24.4s}, [x2], x5
        // Load fourth column from the 16x6 matrix c
        ld1Post(v5, t4s, v6, t4s, v7, t4s, v8, t4s, x2, x5), // ld1 {v5.4s, v6.4s, v7.4s, v8.4s}, [x2], x5
        // Load fifth column from the 16x6 matrix c
        ld1Post(v9, t4s, v10, t4s, v11, t4s, v12, t4s, x2, x5), // ld1 {v9.4s, v10.4s, v11.4s, v12.4s}, [x2], x5
        // Load sixth column from the 16x6 matrix c
        ld1Post(v13, t4s, v14, t4s, v15, t4s, v16, t4s, x2, x5), // ld1 {v13.4s, v14.4s, v15.4s, v16.4s}, [x2], x5

        movz(x9, k_loop), // mov x9, "iterator for K loop"

        // #############################
        // #### matmul_loop_over_K: ####
        // #############################
        sub(x9, x9, 1), // sub x9, x9, #1

        // Load first column data from the 16x1 matrix a
        ld1Post(v0, t4s, v1, t4s, v2, t4s, v3, t4s, x0, x3), // ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], x3

        // run the known matmul_16_6_1_unrolled kernel
        // Load first element from the 1x6 matrix b
        ldr(s4, x1),     // ldr s4, [x1]
        add(x1, x1, x4), // add x1, x1, x4

        // Calculate first column of c
        fmla(v25, t4s, v0, t4s, v4, 0), // fmla v25.4s, v0.4s, v4.s[0]
        fmla(v26, t4s, v1, t4s, v4, 0), // fmla v26.4s, v1.4s, v4.s[0]
        fmla(v27, t4s, v2, t4s, v4, 0), // fmla v27.4s, v2.4s, v4.s[0]
        fmla(v28, t4s, v3, t4s, v4, 0), // fmla v28.4s, v3.4s, v4.s[0]


        // Load second element from the 1x6 matrix b
        ldr(s4, x1),     // ldr s4, [x1]
        add(x1, x1, x4), // add x1, x1, x4

        // Calculate second column of c
        fmla(v17, t4s, v0, t4s, v4, 0), // fmla v17.4s, v0.4s, v4.s[0]
        fmla(v18, t4s, v1, t4s, v4, 0), // fmla v18.4s, v1.4s, v4.s[0]
        fmla(v19, t4s, v2, t4s, v4, 0), // fmla v19.4s, v2.4s, v4.s[0]
        fmla(v20, t4s, v3, t4s, v4, 0), // fmla v20.4s, v3.4s, v4.s[0]


        // Load third element from the 1x6 matrix b
        ldr(s4, x1),     // ldr s4, [x1]
        add(x1, x1, x4), // add x1, x1, x4

        // Calculate third column of c
        fmla(v21, t4s, v0, t4s, v4, 0), // fmla v21.4s, v0.4s, v4.s[0]
        fmla(v22, t4s, v1, t4s, v4, 0), // fmla v22.4s, v1.4s, v4.s[0]
        fmla(v23, t4s, v2, t4s, v4, 0), // fmla v23.4s, v2.4s, v4.s[0]
        fmla(v24, t4s, v3, t4s, v4, 0), // fmla v24.4s, v3.4s, v4.s[0]


        // Load fourth element from the 1x6 matrix b
        ldr(s4, x1),     // ldr s4, [x1]
        add(x1, x1, x4), // add x1, x1, x4

        // Calculate fourth column of c
        fmla(v5, t4s, v0, t4s, v4, 0), // fmla v5.4s, v0.4s, v4.s[0]
        fmla(v6, t4s, v1, t4s, v4, 0), // fmla v6.4s, v1.4s, v4.s[0]
        fmla(v7, t4s, v2, t4s, v4, 0), // fmla v7.4s, v2.4s, v4.s[0]
        fmla(v8, t4s, v3, t4s, v4, 0), // fmla v8.4s, v3.4s, v4.s[0]


        // Load fifth element from the 1x6 matrix b
        ldr(s4, x1),     // ldr s4, [x1]
        add(x1, x1, x4), // add x1, x1, x4

        // Calculate fifth column of c
        fmla(v9, t4s, v0, t4s, v4, 0),  // fmla v9.4s, v0.4s, v4.s[0]
        fmla(v10, t4s, v1, t4s, v4, 0), // fmla v10.4s, v1.4s, v4.s[0]
        fmla(v11, t4s, v2, t4s, v4, 0), // fmla v11.4s, v2.4s, v4.s[0]
        fmla(v12, t4s, v3, t4s, v4, 0), // fmla v12.4s, v3.4s, v4.s[0]


        // Load sixth element from the 1x6 matrix b
        ldr(s4, x1),     // ldr s4, [x1]
        add(x1, x1, x4), // add x1, x1, x4

        // Calculate sixth column of c
        fmla(v13, t4s, v0, t4s, v4, 0), // fmla v13.4s, v0.4s, v4.s[0]
        fmla(v14, t4s, v1, t4s, v4, 0), // fmla v14.4s, v1.4s, v4.s[0]
        fmla(v15, t4s, v2, t4s, v4, 0), // fmla v15.4s, v2.4s, v4.s[0]
        fmla(v16, t4s, v3, t4s, v4, 0), // fmla v16.4s, v3.4s, v4.s[0]


        // offset x6 to the next element in the column
        add(x6, x6, 4), // add x6, x6, #4 // #4 = sizeof(float)

        // Restore x1 to be incremented again
        mov(x1, x6), // mov x1, x6

        // Loop back
        cbnz(x9, -40*4), // cbnz x9, matmul_loop_over_K

        // Restore initial value of x2 that was changed by the loads
        mov(x2, x7), // mov x2, x7

        // Store first column back to memory
        st1Post(v25, t4s, v26, t4s, v27, t4s, v28, t4s, x2, x5), // st1 {v25.4s, v26.4s, v27.4s, v28.4s}, [x2], x5
        // Store second column back to memory
        st1Post(v17, t4s, v18, t4s, v19, t4s, v20, t4s, x2, x5), // st1 {v17.4s, v18.4s, v19.4s, v20.4s}, [x2], x5
        // Store third column back to memory
        st1Post(v21, t4s, v22, t4s, v23, t4s, v24, t4s, x2, x5), // st1 {v21.4s, v22.4s, v23.4s, v24.4s}, [x2], x5
        // Store fourth column back to memory
        st1Post(v5, t4s, v6, t4s, v7, t4s, v8, t4s, x2, x5), // st1 {v5.4s, v6.4s, v7.4s, v8.4s}, [x2], x5
        // Store fifth column back to memory
        st1Post(v9, t4s, v10, t4s, v11, t4s, v12, t4s, x2, x5), // st1 {v9.4s, v10.4s, v11.4s, v12.4s}, [x2], x5
        // Store sixth column back to memory
        st1Post(v13, t4s, v14, t4s, v15, t4s, v16, t4s, x2, x5), // st1 {v13.4s, v14.4s, v15.4s, v16.4s}, [x2], x5

        // Procedural Call Standard
        // restore callee-saved registers
        ldpPost(d14, d15, sp, 16), // ldp d14, d15, [sp], #16
        ldpPost(d12, d13, sp, 16), // ldp d12, d13, [sp], #16
        ldpPost(d10, d11, sp, 16), // ldp d10, d11, [sp], #16
        ldpPost(d8, d9, sp, 16),   // ldp d8, d9, [sp], #16

        ldpPost(x27, x28, sp, 16), // ldp x27, x28, [sp], #16
        ldpPost(x25, x26, sp, 16), // ldp x25, x26, [sp], #16
        ldpPost(x23, x24, sp, 16), // ldp x23, x24, [sp], #16
        ldpPost(x21, x22, sp, 16), // ldp x21, x22, [sp], #16
        ldpPost(x19, x20, sp, 16), // ldp x19, x20, [sp], #16

        // restore frame pointer and link register
        ldpPost(fp, lr, sp, 16), // ldp fp, lr, [sp], #16

        ret() // ret
    });

    kernel.write("matmul_16_6_k.bin");
}
The first highlighted instruction sets the K loop count at kernel-generation time from the given k_loop parameter:
movz(x9, k_loop), // mov x9, "iterator for K loop"
The second highlighted instruction is the backward branch of the K loop. Since we emit raw instruction words, the branch offset has to be calculated manually; in our case we jump back 40 instructions:
cbnz(x9, -40*4), // cbnz x9, matmul_loop_over_K
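The offset can be read off directly from the listing: between the sub at the loop head and the cbnz there are \(1 + 1 + 6 \cdot (2 + 4) + 1 + 1 = 40\) instructions (the sub, the ld1 of A, six times an ldr/add pair plus four fmla, the add of x6 and the mov of x1). Each AArch64 instruction is 4 bytes wide, so the branch offset is \(-40 \cdot 4 = -160\) bytes.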
3. Test the kernel generation. Report performance in GFLOPS
Testing our jitted kernel, we get the same performance as our previous implementation.
Note
The generation of the matmul kernel is done outside of the benchmarking loop, as one would do in a real-world scenario.
------------------------------------------------------------------------------------------------------------------------------------
Benchmark Time CPU Iterations FLOPS
------------------------------------------------------------------------------------------------------------------------------------
GemmJited16x6x1Fixture/BM_jited_matmul_16_6_1/min_warmup_time:1.000_mean 5.57 ns 5.56 ns 10 34.5601G/s
GemmJited16x6x1Fixture/BM_jited_matmul_16_6_1/min_warmup_time:1.000_median 5.56 ns 5.55 ns 10 34.6245G/s
GemmJited16x6x1Fixture/BM_jited_matmul_16_6_1/min_warmup_time:1.000_stddev 0.041 ns 0.040 ns 10 249.138M/s
GemmJited16x6x1Fixture/BM_jited_matmul_16_6_1/min_warmup_time:1.000_cv 0.73 % 0.72 % 10 0.72%
GemmJited16x6x128Fixture/BM_jited_matmul_16_6_128/min_warmup_time:1.000_mean 187 ns 187 ns 10 131.579G/s
GemmJited16x6x128Fixture/BM_jited_matmul_16_6_128/min_warmup_time:1.000_median 187 ns 186 ns 10 131.811G/s
GemmJited16x6x128Fixture/BM_jited_matmul_16_6_128/min_warmup_time:1.000_stddev 1.02 ns 1.01 ns 10 702.935M/s
GemmJited16x6x128Fixture/BM_jited_matmul_16_6_128/min_warmup_time:1.000_cv 0.54 % 0.54 % 10 0.53%
jited_matmul_16_6_1 kernel: \(34.6\) GFLOPS
jited_matmul_16_6_k(=128) kernel: \(131.6\) GFLOPS
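To make the note above concrete, the benchmark setup could look roughly like the following Google Benchmark fixture: the kernel is generated once in SetUp and only the call is timed. The fixture name, the member names and the way the callable is obtained from the Kernel object are illustrative assumptions, not the project's actual code.

#include <benchmark/benchmark.h>
#include <cstdint>
#include <vector>

// Illustrative fixture: the JIT step runs once in SetUp, so the timed loop
// below measures only the generated kernel.
class GemmJitedFixtureSketch : public benchmark::Fixture
{
public:
    using KernelFn = void (*)(const float *, const float *, float *,
                              int64_t, int64_t, int64_t);

    void SetUp(::benchmark::State &) override
    {
        // hypothetical generation step, not timed:
        // mini_jit::kernels::matmul_16_6_k(kernel, 128);
        // fn = kernel.get_callable<KernelFn>(); // assumed helper
        a.assign(16 * 128, 1.0f);
        b.assign(128 * 6, 1.0f);
        c.assign(16 * 6, 0.0f);
    }

    KernelFn fn = nullptr;
    std::vector<float> a, b, c;
};

BENCHMARK_DEFINE_F(GemmJitedFixtureSketch, BM_jited_matmul_16_6_128)(benchmark::State &state)
{
    for (auto _ : state)
    {
        // only the kernel call is timed, generation happened in SetUp
        if (fn) fn(a.data(), b.data(), c.data(), 16, 128, 16);
    }
    state.counters["FLOPS"] = benchmark::Counter(2.0 * 16 * 6 * 128 * state.iterations(),
                                                 benchmark::Counter::kIsRate);
}
BENCHMARK_REGISTER_F(GemmJitedFixtureSketch, BM_jited_matmul_16_6_128)->MinWarmUpTime(1.0);
BENCHMARK_MAIN();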