Submission 2025-05-29

This section implements a backend for binary tensor contractions and unary tensor permutations. The backend performs the provided tensor operation exactly as defined by the interface and does not optimize it. Contractions are executed as recursive loops over small GEMM or Batch-Reduce GEMM (BRGEMM) kernels; an outer tensor dimension simply becomes a sequential loop around the kernel. Permutations are executed as recursive loops over small transposition kernels.

User Interface

1. Begin implementing the setup function of the class einsum::backend::TensorOperation for binary tensor contractions

File: TensorOperation.cpp

Before generating any kernel, we make sure that all necessary preconditions are met: a series of checks verifies that the input configuration for the tensor operation is correct and executable.

mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup(dtype_t dtype, prim_t prim_first_touch, prim_t prim_main,
                                                                    prim_t prim_last_touch, std::span<const dim_t> dim_types,
                                                                    std::span<const exec_t> exec_types, std::span<const int64_t> dim_sizes,
                                                                    std::span<const int64_t> strides_in0,
                                                                    std::span<const int64_t> strides_in1,
                                                                    std::span<const int64_t> strides_out)
{
    hasSetupError = false;
    indexPrimBatch = -1;
    indexPrimK = -1;
    indexPrimM = -1;
    indexPrimN = -1;

    // Validate dimensions
    if (dim_sizes.size() != dim_types.size() || dim_sizes.empty() || dim_types.empty()) {...}

    // Validate strides; strides_in1 can be empty for unary operations
    bool is_unary_only = (isUnary(prim_first_touch) || prim_first_touch == prim_t::none)
                         && (isUnary(prim_main) || prim_main == prim_t::none)
                         && (isUnary(prim_last_touch) || prim_last_touch == prim_t::none);
    if (!(strides_in0.size() == dim_sizes.size() && strides_out.size() == dim_sizes.size()
          && (strides_in1.size() == dim_sizes.size() || (is_unary_only && strides_in1.empty())))) {...}

    // Shared execution is not supported
    for (exec_t exec : exec_types) { if (exec == exec_t::shared) {...} }

    // Validate the data type: currently only fp32 is supported
    if (dtype != dtype_t::fp32) {...}

    if (!isSortedConfiguration(exec_types)) {...}

    if (!isValidPrimConfig(dim_types, exec_types, strides_in0, strides_out)) {...}

    if (!isValidKDim(dim_types, exec_types, strides_in1, prim_main)) {...}

    if (isUnary(prim_main))
    {
        if (!isValidStride(dim_types, strides_in0, stride_t::out) || !isValidStride(dim_types, strides_out, stride_t::out)) {...}
    }
    else if (isBrgemm(prim_main))
    {
        if (!isValidStride(dim_types, strides_in0, stride_t::in0)
            || !isValidStride(dim_types, strides_in1, stride_t::in1)
            || !isValidStride(dim_types, strides_out, stride_t::out)) {...}
    }
    else if (prim_main == prim_t::none) { /* Do nothing */ }
    else { release_assert(false, "Unexpected value for the main primitive"); }

    // isValidPrimConfig already verified that these indices exist
    indexPrimM = findMatch(dim_types, exec_types, dim_t::m, exec_t::prim);
    indexPrimN = findMatch(dim_types, exec_types, dim_t::n, exec_t::prim);

    release_assert(indexPrimM != -1, "Expected a valid index for the M dimension but found none.");
    release_assert(indexPrimN != -1, "Expected a valid index for the N dimension but found none.");
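
The findMatch helper used above returns the index of the first dimension whose dim type and exec type both match the requested pair. A minimal self-contained sketch of its assumed semantics follows; the local enum definitions are illustrative stand-ins, not the actual mini_jit types:

#include <cstddef>
#include <cstdint>
#include <span>

enum class dim_t { m, n, k, c };         // illustrative subset
enum class exec_t { seq, prim, shared }; // illustrative subset

int64_t findMatch(std::span<const dim_t> dim_types, std::span<const exec_t> exec_types,
                  dim_t dim, exec_t exec)
{
    for (std::size_t i = 0; i < dim_types.size(); i++)
    {
        if (dim_types[i] == dim && exec_types[i] == exec)
        {
            return static_cast<int64_t>(i);
        }
    }
    return -1; // no matching dimension found
}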

Possible errors that can occur during the configuration verification step are:

err_wrong_dtype,
err_wrong_dimension,
err_wrong_primitive,
err_wrong_first_touch_primitive,
err_wrong_main_primitive,
err_wrong_last_touch_primitive,
err_execution_type_not_supported,
err_invalid_primitive_configuration,
err_invalid_first_touch_configuration,
err_invalid_main_configuration,
err_invalid_last_touch_configuration,
err_invalid_execution_order,
err_invalid_strides,

If the verification step succeeds, we check whether prim_first_touch, prim_main, and prim_last_touch are defined and, if so, generate the corresponding kernels. prim_first_touch and prim_last_touch are restricted to unary operations, while prim_main can be a unary operation, a GEMM, or a BRGEMM.

    if (prim_first_touch != prim_t::none) {...}

    if (prim_main != prim_t::none)
    {
        if (isBrgemm(prim_main)) {...}
        else if (isUnary(prim_main)) {...}
    }

    if (prim_last_touch != prim_t::none) {...}

    return error_t::success;
}
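
The generated main kernel is stored in a std::variant and dispatched with std::holds_alternative in execute_dimension (shown below). The following self-contained toy illustrates this dispatch pattern; the two kernel structs are stand-ins, not the actual mini_jit Unary and Brgemm classes:

#include <iostream>
#include <variant>

// Stand-in kernel types; the real classes wrap JIT-generated kernels.
struct Unary  { void run() const { std::cout << "unary kernel\n"; } };
struct Brgemm { void run() const { std::cout << "brgemm kernel\n"; } };

int main()
{
    std::variant<Unary, Brgemm> main_kernel = Brgemm{};

    // dispatch on the stored alternative, as in execute_dimension
    if (std::holds_alternative<Unary>(main_kernel))
        std::get<Unary>(main_kernel).run();
    else if (std::holds_alternative<Brgemm>(main_kernel))
        std::get<Brgemm>(main_kernel).run();
}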

Recursive Loops Over Primitives

1. Implement the execute function of the einsum::backend::TensorOperation class using recursive loops over primitives

The execute function performs the configured tensor operation on two or three input tensors. Since we also support tensor operations consisting of only a unary operation, the second input tensor is not always required. We cast the input tensors to byte pointers and call the actual executor function, execute_dimension.

void mini_jit::TensorOperation::execute(void const *tensor_in0, void const *tensor_in1, void *tensor_out)
{
    release_assert(!hasSetupError, "The setup resulted in an error; do not execute the operation.");
    release_assert(tensor_in0 != nullptr, "The tensor_in0 parameter is a nullptr, but should be a valid pointer to memory.");
    release_assert(tensor_out != nullptr, "The tensor_out parameter is a nullptr, but should be a valid pointer to memory.");

    // a second input tensor is only required for contractions
    if (isBrgemm(prim_main))
    {
        release_assert(tensor_in1 != nullptr, "The tensor_in1 parameter is a nullptr, but should be a valid pointer to memory.");
    }

    // byte pointers allow stride arithmetic at dtype granularity in execute_dimension
    char const *ptr_in0 = static_cast<char const *>(tensor_in0);
    char const *ptr_in1 = static_cast<char const *>(tensor_in1);
    char *ptr_out = static_cast<char *>(tensor_out);

    execute_dimension(0, ptr_in0, ptr_in1, ptr_out, true, true);
}

execute_dimension has three main tasks. First, determine whether the prim_first_touch or prim_last_touch primitive, if defined, should be executed on the output pointer. Second, if there are outer loops, i.e., the tensors have more dimensions than the primitive kernel covers, loop over those dimensions and recurse until the primitive kernel can be called. Third, if no higher dimensions are left to iterate over, execute the primitive kernels in the correct order.

Compute the first_access and last_access flags and check whether higher dimensions are present. If so, recurse:

void mini_jit::TensorOperation::execute_dimension(int64_t index_dim, char const *ptr_in0, char const *ptr_in1, char *ptr_out,
                                                  bool first_access, bool last_access)
{
    uint32_t dtype_bytes = 4; // only fp32 is supported
    int64_t dim_size = dim_sizes[index_dim];
    int64_t stride_in0 = strides_in0[index_dim];
    int64_t stride_in1 = isUnary(prim_main) ? 1 : strides_in1[index_dim]; // strides_in1 may be empty for unary operations
    int64_t stride_out = strides_out[index_dim];

    if (exec_types[index_dim] == exec_t::seq)
    {
        bool is_first = first_access;
        bool is_last = last_access;

        for (int64_t iDim = 0; iDim < dim_size; iDim++)
        {
            // only a K dimension revisits the same output block, so only its
            // first and last iterations count as first/last access
            if (dim_types[index_dim] == dim_t::k)
            {
                is_first = first_access && (iDim == 0);
                is_last = last_access && (iDim == (dim_size - 1));
            }

            char const *rec_ptr_in0 = ptr_in0 + iDim * stride_in0 * dtype_bytes;
            char const *rec_ptr_in1 = ptr_in1 + iDim * stride_in1 * dtype_bytes;
            char *rec_ptr_out = ptr_out + iDim * stride_out * dtype_bytes;
            execute_dimension(index_dim + 1, rec_ptr_in0, rec_ptr_in1, rec_ptr_out, is_first, is_last);
        }
    }

If no higher dimension is left for iteration, call the primitive kernels:

    else
    {
        release_assert(exec_types[index_dim] == exec_t::prim, "Expected a primitive loop");

        // call the first-touch kernel if necessary
        if (first_access && prim_first != prim_t::none) {...}

        // call the main kernel
        if (prim_main != prim_t::none)
        {
            if (std::holds_alternative<Unary>(main_kernel)) {...}
            else if (std::holds_alternative<Brgemm>(main_kernel)) {...}
            else {...} // error case
        }

        // call the last-touch kernel if necessary
        if (last_access && prim_last != prim_t::none) {...}
    }
}
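
To make the first/last-access logic concrete, the following standalone toy (not the backend itself) shows how the flags behave over a sequential K dimension of size 3: the first touch fires only on the first iteration, the last touch only on the final one, and the main kernel accumulates on every iteration.

#include <cstdio>

int main()
{
    const int dim_size = 3; // a sequential K dimension of size 3
    bool first_access = true, last_access = true;

    for (int iDim = 0; iDim < dim_size; iDim++)
    {
        // K iterations accumulate into the same output block
        bool is_first = first_access && (iDim == 0);
        bool is_last = last_access && (iDim == dim_size - 1);

        if (is_first) std::printf("iDim=%d: first touch (e.g. zero the output)\n", iDim);
        std::printf("iDim=%d: main kernel (accumulate)\n", iDim);
        if (is_last) std::printf("iDim=%d: last touch (e.g. apply ReLU)\n", iDim);
    }
}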

2. Verify your implementation against a reference implementation

We implemented the following tests to verify the functionality of TensorOperation.cpp; they exercise the first-touch, main, and last-touch primitives in various combinations and compare the results against a naive reference implementation. The tests are located in TensorOperation.test.cpp.

// without outer dimensions
TEST_CASE("Test tensor operation with main kernel: unary (zero, relu, copy)", "[tensor_operation][unary][correctness]")
TEST_CASE("Test tensor operation with main kernel: gemm", "[tensor_operation][gemm][correctness]")
TEST_CASE("Test tensor operation with main kernel: brgemm", "[tensor_operation][brgemm][correctness]")

TEST_CASE("Test tensor operation with first touch: unary (zero, relu, copy)", "[tensor_operation][unary][correctness]")
TEST_CASE("Test tensor operation with last touch: unary (zero, relu, copy)", "[tensor_operation][unary][correctness]")

TEST_CASE("Test tensor operation with first touch: unary (zero, relu, copy) & main kernel: gemm", "[tensor_operation][unary][gemm][correctness]")
TEST_CASE("Test tensor operation with last touch: unary (zero, relu, copy) & main kernel: gemm", "[tensor_operation][unary][gemm][correctness]")
TEST_CASE("Test tensor operation with first touch: unary (zero, relu, copy) & main kernel: gemm & last touch: unary (zero, relu, copy)", "[tensor_operation][unary][gemm][correctness]")
TEST_CASE("Test tensor operation with first touch: unary (zero, relu, copy) & main kernel: brgemm", "[tensor_operation][unary][brgemm][correctness]")
TEST_CASE("Test tensor operation with last touch: unary (zero, relu, copy) & main kernel: brgemm", "[tensor_operation][unary][brgemm][correctness]")
TEST_CASE("Test tensor operation with first touch: unary (zero, relu, copy) & main kernel: brgemm & last touch: unary (zero, relu, copy)", "[tensor_operation][unary][brgemm][correctness]")
TEST_CASE("Test tensor operation with outer loop with main kernel: unary (zero, relu, copy)", "[tensor_operation][unary][correctness]")

// with outer dimensions
TEST_CASE("Test tensor operation with outer loop with main kernel: gemm", "[tensor_operation][gemm][correctness]")
TEST_CASE("Test tensor operation with outer loop with main kernel: brgemm", "[tensor_operation][brgemm][correctness]")

TEST_CASE("Test tensor operation with outer loop with first touch: unary (zero, relu, copy)", "[tensor_operation][unary][correctness]")
TEST_CASE("Test tensor operation with outer loop with last touch: unary (zero, relu, copy)", "[tensor_operation][unary][correctness]")
TEST_CASE("Test tensor operation with outer loop with first touch: unary (zero, relu, copy) & main kernel: gemm", "[tensor_operation][unary][gemm][correctness]")
TEST_CASE("Test tensor operation with outer loop with last touch: unary (zero, relu, copy) & main kernel: gemm", "[tensor_operation][unary][gemm][correctness]")

TEST_CASE("Test tensor operation with outer loop with first touch: unary (zero, relu, copy) & main kernel: gemm & last touch: unary (zero, relu, copy)", "[tensor_operation][unary][brgemm][correctness]")
TEST_CASE("Test tensor operation with outer loop with last touch: unary (zero, relu, copy) & main kernel: brgemm", "[tensor_operation][unary][brgemm][correctness]")
TEST_CASE("Test tensor operation with outer loop with first touch: unary (zero, relu, copy) & main kernel: brgemm & last touch: unary (zero, relu, copy)", "[tensor_operation][unary][brgemm][correctness]")

Performance Benchmarking

1. Benchmark the performance of your implementation and report the measured performance in GFLOPS

Tensor contraction using the GEMM primitive:

-----------------------------------------------------------------------------------------------------------------------------------------------------------------
Benchmark                                                                                                            Time             CPU   Iterations      FLOPS
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
BM_tensor_GEMM/size_a:262144/size_b:262144/size_c:1048576/config:0/min_warmup_time:0.300_mean                  4359838 ns      4343934 ns           10 123.593G/s
BM_tensor_GEMM/size_a:262144/size_b:262144/size_c:1048576/config:0/min_warmup_time:0.300_median                4361667 ns      4344882 ns           10 123.564G/s
BM_tensor_GEMM/size_a:262144/size_b:262144/size_c:1048576/config:0/min_warmup_time:0.300_stddev                  17304 ns        17543 ns           10  500.82M/s
BM_tensor_GEMM/size_a:262144/size_b:262144/size_c:1048576/config:0/min_warmup_time:0.300_cv                       0.40 %          0.40 %            10      0.41%
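
As a plausibility check on the reported rate (assuming M = N = 1024 and K = 256, which is consistent with size_a = M*K = 262144, size_b = K*N = 262144, and size_c = M*N = 1048576): one contraction performs 2*M*N*K = 2^29 ≈ 0.537 GFLOP, which at the mean CPU time of 4.344 ms works out to roughly 123.6 GFLOPS, matching the measured value.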

Tensor contraction using the BRGEMM primitive:

-----------------------------------------------------------------------------------------------------------------------------------------------------------------
Benchmark                                                                                                            Time             CPU   Iterations      FLOPS
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
BM_tensor_BRGEMM/size_a:262144/size_b:262144/size_c:1048576/config:1/min_warmup_time:0.300_mean                4365885 ns      4350242 ns           10 123.413G/s
BM_tensor_BRGEMM/size_a:262144/size_b:262144/size_c:1048576/config:1/min_warmup_time:0.300_median              4361928 ns      4346152 ns           10 123.528G/s
BM_tensor_BRGEMM/size_a:262144/size_b:262144/size_c:1048576/config:1/min_warmup_time:0.300_stddev                14186 ns        14016 ns           10  396.45M/s
BM_tensor_BRGEMM/size_a:262144/size_b:262144/size_c:1048576/config:1/min_warmup_time:0.300_cv                     0.32 %          0.32 %            10      0.32%

Tensor contraction using the Zero, BRGEMM and ReLU primitives:

-----------------------------------------------------------------------------------------------------------------------------------------------------------------
Benchmark                                                                                                            Time             CPU   Iterations      FLOPS
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
BM_tensor_Zero+BRGEMM+RELU/size_a:262144/size_b:262144/size_c:1048576/config:2/min_warmup_time:0.300_mean      4464672 ns      4448666 ns           10 120.682G/s
BM_tensor_Zero+BRGEMM+RELU/size_a:262144/size_b:262144/size_c:1048576/config:2/min_warmup_time:0.300_median    4461153 ns      4444776 ns           10 120.787G/s
BM_tensor_Zero+BRGEMM+RELU/size_a:262144/size_b:262144/size_c:1048576/config:2/min_warmup_time:0.300_stddev      14498 ns        14307 ns           10   387.2M/s
BM_tensor_Zero+BRGEMM+RELU/size_a:262144/size_b:262144/size_c:1048576/config:2/min_warmup_time:0.300_cv           0.32 %          0.32 %            10      0.32%

2. Design your own setups. Which setups achieve high performance and which are slow?

  • First: Zero & Main: BRGEMM

  • A: 262144, B: 262144, C: 1048576

  • Fast: only slightly slower than the pure BRGEMM benchmark above, since the zero first touch adds a single cheap pass over the output

-----------------------------------------------------------------------------------------------------------------------------------------------------------------
Benchmark                                                                                                            Time             CPU   Iterations      FLOPS
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
BM_tensor_Zero+BRGEMM/size_a:262144/size_b:262144/size_c:1048576/config:3/min_warmup_time:0.300_mean           4449301 ns      4433374 ns           10 121.098G/s
BM_tensor_Zero+BRGEMM/size_a:262144/size_b:262144/size_c:1048576/config:3/min_warmup_time:0.300_median         4448818 ns      4433182 ns           10 121.103G/s
BM_tensor_Zero+BRGEMM/size_a:262144/size_b:262144/size_c:1048576/config:3/min_warmup_time:0.300_stddev            8350 ns         7959 ns           10   217.4M/s
BM_tensor_Zero+BRGEMM/size_a:262144/size_b:262144/size_c:1048576/config:3/min_warmup_time:0.300_cv                0.19 %          0.18 %            10      0.18%

  • Last: Relu

  • A: 8388608, B: 8192, C: 8388608

  • Slow: a standalone ReLU performs almost no arithmetic per byte and is memory-bound, yielding roughly an order of magnitude fewer GFLOPS than the contraction setups

-----------------------------------------------------------------------------------------------------------------------------------------------------------------
Benchmark                                                                                                            Time             CPU   Iterations      FLOPS
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
BM_tensor_Relu/size_a:8388608/size_b:8192/size_c:8388608/config:4/min_warmup_time:0.300_mean                   1694290 ns      1685602 ns           10 9.95364G/s
BM_tensor_Relu/size_a:8388608/size_b:8192/size_c:8388608/config:4/min_warmup_time:0.300_median                 1693287 ns      1685075 ns           10 9.95636G/s
BM_tensor_Relu/size_a:8388608/size_b:8192/size_c:8388608/config:4/min_warmup_time:0.300_stddev                   11637 ns        11124 ns           10 65.7127M/s
BM_tensor_Relu/size_a:8388608/size_b:8192/size_c:8388608/config:4/min_warmup_time:0.300_cv                        0.69 %          0.66 %            10      0.66%

  • Main: BRGEMM & Last: RELU

  • A: 262144, B: 262144, C: 1048576

  • Slightly slower than the pure BRGEMM setup: the memory-bound ReLU last touch adds an extra pass over the output

-----------------------------------------------------------------------------------------------------------------------------------------------------------------
Benchmark                                                                                                            Time             CPU   Iterations      FLOPS
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
BM_tensor_BRGEMM+RELU/size_a:262144/size_b:262144/size_c:1048576/config:5/min_warmup_time:0.300_mean           4474456 ns      4458350 ns           10  120.42G/s
BM_tensor_BRGEMM+RELU/size_a:262144/size_b:262144/size_c:1048576/config:5/min_warmup_time:0.300_median         4476878 ns      4460413 ns           10 120.364G/s
BM_tensor_BRGEMM+RELU/size_a:262144/size_b:262144/size_c:1048576/config:5/min_warmup_time:0.300_stddev            9309 ns         9001 ns           10 243.248M/s
BM_tensor_BRGEMM+RELU/size_a:262144/size_b:262144/size_c:1048576/config:5/min_warmup_time:0.300_cv                0.21 %          0.20 %            10      0.20%

  • Main: BRGEMM & Last: RELU

  • A: 524288, B: 524288, C: 1048576

  • Fastest setup measured: the larger K dimension amortizes the memory-bound ReLU pass over more arithmetic

-----------------------------------------------------------------------------------------------------------------------------------------------------------------
Benchmark                                                                                                            Time             CPU   Iterations      FLOPS
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
BM_tensor_BRGEMM+RELU/size_a:524288/size_b:524288/size_c:1048576/config:6/min_warmup_time:0.300_mean           8660603 ns      8629735 ns           10 124.424G/s
BM_tensor_BRGEMM+RELU/size_a:524288/size_b:524288/size_c:1048576/config:6/min_warmup_time:0.300_median         8651362 ns      8620884 ns           10 124.551G/s
BM_tensor_BRGEMM+RELU/size_a:524288/size_b:524288/size_c:1048576/config:6/min_warmup_time:0.300_stddev           15382 ns        15092 ns           10 217.397M/s
BM_tensor_BRGEMM+RELU/size_a:524288/size_b:524288/size_c:1048576/config:6/min_warmup_time:0.300_cv                0.18 %          0.17 %            10      0.17%