Submission 2025-05-29¶
This section implements a backend for binary tensor contractions and unary tensor permutations. The backend performs the provided tensor operation exactly as defined by the interface and does not optimize it. Contractions are executed as recursive loops over small GEMM or Batch-Reduce GEMM (BRGEMM) kernels. Permutations are executed as recursive loops over small transposition kernels.
User Interface¶
1. Begin implementing the setup function of the class einsum::backend::TensorOperation for binary tensor contractions¶
File: TensorOperation.cpp
Before generating any kernel, we make sure that all necessary conditions are met. Therefore, we have created a number of checks to verify that the input configuration for the tensor operation is correct and executable.
mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup(dtype_t dtype, prim_t prim_first_touch, prim_t prim_main,
                                                                    prim_t prim_last_touch, std::span<const dim_t> dim_types,
                                                                    std::span<const exec_t> exec_types, std::span<const int64_t> dim_sizes,
                                                                    std::span<const int64_t> strides_in0,
                                                                    std::span<const int64_t> strides_in1,
                                                                    std::span<const int64_t> strides_out)
{
  hasSetupError = false;
  indexPrimBatch = -1;
  indexPrimK = -1;
  indexPrimM = -1;
  indexPrimN = -1;

  // Validate dimensions
  if (dim_sizes.size() != dim_types.size() || dim_sizes.empty() || dim_types.empty()) {...}

  // Validate strides; strides_in1 may be empty if all primitives are unary or none
  if (!(strides_in0.size() == dim_sizes.size() && strides_out.size() == dim_sizes.size()
        && (strides_in1.size() == dim_sizes.size()
            || ((isUnary(prim_first_touch) || prim_first_touch == prim_t::none)
                && (isUnary(prim_main) || prim_main == prim_t::none)
                && (isUnary(prim_last_touch) || prim_last_touch == prim_t::none)
                && strides_in1.empty())))) {...}

  // Shared execution is not supported
  for (exec_t exec : exec_types) { if (exec == exec_t::shared) {...} }

  // Validate dtype - currently only fp32 is supported
  if (dtype != dtype_t::fp32) {...}

  if (!isSortedConfiguration(exec_types)) {...}
  if (!isValidPrimConfig(dim_types, exec_types, strides_in0, strides_out)) {...}
  if (!isValidKDim(dim_types, exec_types, strides_in1, prim_main)) {...}

  if (isUnary(prim_main))
  {
    if (!isValidStride(dim_types, strides_in0, stride_t::out) || !isValidStride(dim_types, strides_out, stride_t::out)) {...}
  }
  else if (isBrgemm(prim_main))
  {
    if (!isValidStride(dim_types, strides_in0, stride_t::in0)
        || !isValidStride(dim_types, strides_in1, stride_t::in1)
        || !isValidStride(dim_types, strides_out, stride_t::out)) {...}
  }
  else if (prim_main == prim_t::none) { /* Do nothing */ }
  else { release_assert(false, "Unexpected value for the main primitive"); }

  // isValidPrimConfig already verified that these indices exist
  indexPrimM = findMatch(dim_types, exec_types, dim_t::m, exec_t::prim);
  indexPrimN = findMatch(dim_types, exec_types, dim_t::n, exec_t::prim);
  release_assert(indexPrimM != -1, "Expected a valid index for the M dimension but found none.");
  release_assert(indexPrimN != -1, "Expected a valid index for the N dimension but found none.");
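Several of these checks rely on small helper functions. As an illustration, the following is a minimal sketch of how the findMatch helper used above could be implemented (an assumption, not our verbatim code): it returns the index of the first dimension whose dimension type and execution type both match, or -1 if no such dimension exists.

#include <cstddef>
#include <cstdint>
#include <span>

// Hypothetical sketch of findMatch; dim_t and exec_t are the backend's enums.
int64_t findMatch(std::span<const dim_t> dim_types, std::span<const exec_t> exec_types,
                  dim_t searched_dim, exec_t searched_exec)
{
  for (std::size_t i = 0; i < dim_types.size(); i++)
  {
    if (dim_types[i] == searched_dim && exec_types[i] == searched_exec)
    {
      return static_cast<int64_t>(i);
    }
  }
  return -1;
}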
Possible errors that can occur during the configuration verification step are:
err_wrong_dtype,
err_wrong_dimension,
err_wrong_primitive,
err_wrong_first_touch_primitive,
err_wrong_main_primitive,
err_wrong_last_touch_primitive,
err_execution_type_not_supported,
err_invalid_primitive_configuration,
err_invalid_first_touch_configuration,
err_invalid_main_configuration,
err_invalid_last_touch_configuration,
err_invalid_execution_order,
err_invalid_strides.
If the verification step is successful, we check whether prim_first_touch, prim_main, and prim_last_touch are defined. If so, we create the corresponding kernels. prim_first_touch and prim_last_touch are restricted to unary operations, while prim_main can be a unary, a GEMM, or a BRGEMM.
  if (prim_first_touch != prim_t::none) {...}
  if (prim_main != prim_t::none)
  {
    if (isBrgemm(prim_main)) {...}
    else if (isUnary(prim_main)) {...}
  }
  if (prim_last_touch != prim_t::none) {...}
  return error_t::success;
}
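To illustrate the interface, the following hedged sketch configures a single 64×64×64 GEMM (C += A · B) in which all three dimensions are mapped to the primitive kernel. The strides are in elements and assume column-major storage with unit stride in M; prim_t::gemm is an assumed enumerator name, and the container types are chosen only for the example.

#include <cstdint>
#include <vector>

// Hedged usage sketch (not a verbatim test): C[64][64] += A[64][64] * B[64][64]
// with every dimension executed by the primitive kernel. dim_t, exec_t, prim_t
// and dtype_t are the backend's enums (namespaces omitted).
mini_jit::TensorOperation op;

std::vector<dim_t>   dim_types   = {dim_t::m, dim_t::n, dim_t::k};
std::vector<exec_t>  exec_types  = {exec_t::prim, exec_t::prim, exec_t::prim};
std::vector<int64_t> dim_sizes   = {64, 64, 64};
std::vector<int64_t> strides_in0 = {1, 0, 64}; // A: unit stride in m, ld = 64
std::vector<int64_t> strides_in1 = {0, 64, 1}; // B: unit stride in k, ld = 64
std::vector<int64_t> strides_out = {1, 64, 0}; // C: unit stride in m, ld = 64

auto err = op.setup(dtype_t::fp32,
                    prim_t::none, // no first touch
                    prim_t::gemm, // main primitive (assumed enumerator name)
                    prim_t::none, // no last touch
                    dim_types, exec_types, dim_sizes,
                    strides_in0, strides_in1, strides_out);
release_assert(err == mini_jit::TensorOperation::error_t::success, "setup failed");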
Recursive Loops Over Primitives¶
1. Implement the execute function of the einsum::backend::TensorOperation class using recursive loops over primitives¶
The execute function performs the configured tensor operation on two or three input tensors. Since we also support tensor operations consisting only of a unary, the second input tensor is not always required. We parse the input tensors and call the actual executor function, execute_dimension.
void mini_jit::TensorOperation::execute(void const *tensor_in0, void const *tensor_in1, void *tensor_out)
{
  release_assert(!hasSetupError, "The setup resulted in an error; do not execute the operation.");
  release_assert(tensor_in0 != nullptr, "The tensor_in0 parameter is a nullptr, but should be a valid pointer to memory.");
  release_assert(tensor_out != nullptr, "The tensor_out parameter is a nullptr, but should be a valid pointer to memory.");
  if (isBrgemm(prim_main))
  {
    release_assert(tensor_in1 != nullptr, "The tensor_in1 parameter is a nullptr, but should be a valid pointer to memory.");
  }
  char const *ptr_in0 = static_cast<char const *>(tensor_in0);
  char const *ptr_in1 = static_cast<char const *>(tensor_in1);
  char *ptr_out = static_cast<char *>(tensor_out);
  execute_dimension(0, ptr_in0, ptr_in1, ptr_out, true, true);
}
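Continuing the sketch from the setup section, executing the configured operation on raw fp32 buffers could then look as follows; with all inputs set to one, every output entry accumulates the k-sum 64.

// Hedged continuation of the setup sketch: run the configured 64x64x64 GEMM.
std::vector<float> a(64 * 64, 1.0f); // in0
std::vector<float> b(64 * 64, 1.0f); // in1
std::vector<float> c(64 * 64, 0.0f); // out; the GEMM accumulates into it
op.execute(a.data(), b.data(), c.data());
// every entry of c is now 64.0f: a sum over k of 64 products of ones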
execute_dimension has three main tasks. First, if defined, it checks whether the prim_first_touch or prim_last_touch primitive should be executed on the output pointer. Second, if there are outer loops, i.e. the tensors have more dimensions than the primitive kernel covers, it loops over those dimensions until the primitive kernel inside the loop can be called. Third, if no higher dimensions are left to iterate over, it executes the primitive kernels in the correct order.
We compute first_access and last_access, and check whether higher dimensions are present. If so, we execute recursively:
void mini_jit::TensorOperation::execute_dimension(int64_t index_dim, char const *ptr_in0, char const *ptr_in1, char *ptr_out,
                                                  bool first_access, bool last_access)
{
  uint32_t dtype_bytes = 4; // only fp32 is supported
  int64_t dim_size = dim_sizes[index_dim];
  int64_t stride_in0 = strides_in0[index_dim];
  // strides_in1 is empty for unary operations; the dummy stride is never dereferenced
  int64_t stride_in1 = isUnary(prim_main) ? 1 : strides_in1[index_dim];
  int64_t stride_out = strides_out[index_dim];

  if (exec_types[index_dim] == exec_t::seq)
  {
    bool is_first = first_access;
    bool is_last = last_access;
    for (int64_t iDim = 0; iDim < dim_size; iDim++)
    {
      // only the first (last) iteration of a k loop keeps the first (last) access
      if (dim_types[index_dim] == dim_t::k)
      {
        is_first = first_access && (iDim == 0);
        is_last = last_access && (iDim == (dim_size - 1));
      }
      char const *rec_ptr_in0 = ptr_in0 + iDim * stride_in0 * dtype_bytes;
      char const *rec_ptr_in1 = ptr_in1 + iDim * stride_in1 * dtype_bytes;
      char *rec_ptr_out = ptr_out + iDim * stride_out * dtype_bytes;
      execute_dimension(index_dim + 1, rec_ptr_in0, rec_ptr_in1, rec_ptr_out, is_first, is_last);
    }
  }
If no higher dimension is left for iteration, call the primitive kernels:
  else
  {
    release_assert(exec_types[index_dim] == exec_t::prim, "Expected a primitive loop");

    // call the first touch kernel if necessary
    if (first_access && prim_first != prim_t::none) {...}

    // call the main kernel
    if (prim_main != prim_t::none)
    {
      if (std::holds_alternative<Unary>(main_kernel)) {...}
      else if (std::holds_alternative<Brgemm>(main_kernel)) {...}
      else {...} // error case
    }

    // call the last touch kernel if necessary
    if (last_access && prim_last != prim_t::none) {...}
  }
}
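To make the flag handling concrete, consider a configuration with a single sequential k loop in front of the primitives; the loop shown above forwards the access flags as follows:

// Worked example: one sequential k loop of size 3 before the primitive block.
// execute_dimension is entered with (first_access, last_access) = (true, true)
// and forwards to the recursion:
//   iDim == 0 -> (true,  false)  // the zero first touch runs only here
//   iDim == 1 -> (false, false)
//   iDim == 2 -> (false, true)   // the ReLU last touch runs only here
// Each output block is thus initialized exactly once and post-processed
// exactly once, while the main kernel accumulates across all iterations.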
2. Verify your implementation against a reference implementation¶
We implemented the following tests to verify the functionality of our TensorOperation.cpp when performing the first touch, main, and last touch primitives, comparing each against a naive reference implementation. The tests are located in the file TensorOperation.test.cpp.
// without outer dimensions
TEST_CASE("Test tensor operation with main kernel: unary (zero, relu, copy)", "[tensor_operation][unary][correctness]")
TEST_CASE("Test tensor operation with main kernel: gemm", "[tensor_operation][gemm][correctness]")
TEST_CASE("Test tensor operation with main kernel: brgemm", "[tensor_operation][brgemm][correctness]")
TEST_CASE("Test tensor operation with first touch: unary (zero, relu, copy)", "[tensor_operation][unary][correctness]")
TEST_CASE("Test tensor operation with last touch: unary (zero, relu, copy)", "[tensor_operation][unary][correctness]")
TEST_CASE("Test tensor operation with first touch: unary (zero, relu, copy) & main kernel: gemm", "[tensor_operation][unary][gemm][correctness]")
TEST_CASE("Test tensor operation with last touch: unary (zero, relu, copy) & main kernel: gemm", "[tensor_operation][unary][gemm][correctness]")
TEST_CASE("Test tensor operation with first touch: unary (zero, relu, copy) & main kernel: gemm & last touch: unary (zero, relu, copy)", "[tensor_operation][unary][gemm][correctness]")
TEST_CASE("Test tensor operation with first touch: unary (zero, relu, copy) & main kernel: brgemm", "[tensor_operation][unary][brgemm][correctness]")
TEST_CASE("Test tensor operation with last touch: unary (zero, relu, copy) & main kernel: brgemm", "[tensor_operation][unary][brgemm][correctness]")
TEST_CASE("Test tensor operation with first touch: unary (zero, relu, copy) & main kernel: brgemm & last touch: unary (zero, relu, copy)", "[tensor_operation][unary][brgemm][correctness]")
TEST_CASE("Test tensor operation with outer loop with main kernel: unary (zero, relu, copy)", "[tensor_operation][unary][correctness]")
// with outer dimensions
TEST_CASE("Test tensor operation with outer loop with main kernel: gemm", "[tensor_operation][gemm][correctness]")
TEST_CASE("Test tensor operation with outer loop with main kernel: brgemm", "[tensor_operation][brgemm][correctness]")
TEST_CASE("Test tensor operation with outer loop with first touch: unary (zero, relu, copy)", "[tensor_operation][unary][correctness]")
TEST_CASE("Test tensor operation with outer loop with last touch: unary (zero, relu, copy)", "[tensor_operation][unary][correctness]")
TEST_CASE("Test tensor operation with outer loop with first touch: unary (zero, relu, copy) & main kernel: gemm", "[tensor_operation][unary][gemm][correctness]")
TEST_CASE("Test tensor operation with outer loop with last touch: unary (zero, relu, copy) & main kernel: gemm", "[tensor_operation][unary][gemm][correctness]")
TEST_CASE("Test tensor operation with outer loop with first touch: unary (zero, relu, copy) & main kernel: gemm & last touch: unary (zero, relu, copy)", "[tensor_operation][unary][brgemm][correctness]")
TEST_CASE("Test tensor operation with outer loop with last touch: unary (zero, relu, copy) & main kernel: brgemm", "[tensor_operation][unary][brgemm][correctness]")
TEST_CASE("Test tensor operation with outer loop with first touch: unary (zero, relu, copy) & main kernel: brgemm & last touch: unary (zero, relu, copy)", "[tensor_operation][unary][brgemm][correctness]")
Performance Benchmarking¶
1. Benchmark the performance of your implementation and report the measured performance in GFLOPS¶
Tensor contraction using the GEMM primitive:
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
Benchmark Time CPU Iterations FLOPS
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
BM_tensor_GEMM/size_a:262144/size_b:262144/size_c:1048576/config:0/min_warmup_time:0.300_mean 4359838 ns 4343934 ns 10 123.593G/s
BM_tensor_GEMM/size_a:262144/size_b:262144/size_c:1048576/config:0/min_warmup_time:0.300_median 4361667 ns 4344882 ns 10 123.564G/s
BM_tensor_GEMM/size_a:262144/size_b:262144/size_c:1048576/config:0/min_warmup_time:0.300_stddev 17304 ns 17543 ns 10 500.82M/s
BM_tensor_GEMM/size_a:262144/size_b:262144/size_c:1048576/config:0/min_warmup_time:0.300_cv 0.40 % 0.40 % 10 0.41%
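As a plausibility check: size_a = 262144, size_b = 262144, and size_c = 1048576 are consistent with M = N = 1024 and K = 256, so one contraction performs 2 · M · N · K = 2^29 ≈ 0.537 GFLOP. At the measured mean time of roughly 4.34 ms this yields 0.537 GFLOP / 4.34 ms ≈ 124 GFLOPS, which matches the reported rate.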
Tensor contraction using the BRGEMM primitive:
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
Benchmark Time CPU Iterations FLOPS
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
BM_tensor_BRGEMM/size_a:262144/size_b:262144/size_c:1048576/config:1/min_warmup_time:0.300_mean 4365885 ns 4350242 ns 10 123.413G/s
BM_tensor_BRGEMM/size_a:262144/size_b:262144/size_c:1048576/config:1/min_warmup_time:0.300_median 4361928 ns 4346152 ns 10 123.528G/s
BM_tensor_BRGEMM/size_a:262144/size_b:262144/size_c:1048576/config:1/min_warmup_time:0.300_stddev 14186 ns 14016 ns 10 396.45M/s
BM_tensor_BRGEMM/size_a:262144/size_b:262144/size_c:1048576/config:1/min_warmup_time:0.300_cv 0.32 % 0.32 % 10 0.32%
Tensor contraction using the Zero, BRGEMM and ReLU primitives:
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
Benchmark Time CPU Iterations FLOPS
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
BM_tensor_Zero+BRGEMM+RELU/size_a:262144/size_b:262144/size_c:1048576/config:2/min_warmup_time:0.300_mean 4464672 ns 4448666 ns 10 120.682G/s
BM_tensor_Zero+BRGEMM+RELU/size_a:262144/size_b:262144/size_c:1048576/config:2/min_warmup_time:0.300_median 4461153 ns 4444776 ns 10 120.787G/s
BM_tensor_Zero+BRGEMM+RELU/size_a:262144/size_b:262144/size_c:1048576/config:2/min_warmup_time:0.300_stddev 14498 ns 14307 ns 10 387.2M/s
BM_tensor_Zero+BRGEMM+RELU/size_a:262144/size_b:262144/size_c:1048576/config:2/min_warmup_time:0.300_cv 0.32 % 0.32 % 10 0.32%
2. Design your own setups. Which setups achieve high performance and which are slow?¶
First touch: Zero & Main: BRGEMM
A: 262144, B: 262144, C: 1048576
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
Benchmark Time CPU Iterations FLOPS
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
BM_tensor_Zero+BRGEMM/size_a:262144/size_b:262144/size_c:1048576/config:3/min_warmup_time:0.300_mean 4449301 ns 4433374 ns 10 121.098G/s
BM_tensor_Zero+BRGEMM/size_a:262144/size_b:262144/size_c:1048576/config:3/min_warmup_time:0.300_median 4448818 ns 4433182 ns 10 121.103G/s
BM_tensor_Zero+BRGEMM/size_a:262144/size_b:262144/size_c:1048576/config:3/min_warmup_time:0.300_stddev 8350 ns 7959 ns 10 217.4M/s
BM_tensor_Zero+BRGEMM/size_a:262144/size_b:262144/size_c:1048576/config:3/min_warmup_time:0.300_cv 0.19 % 0.18 % 10 0.18%
Last touch: ReLU
A: 8388608, B: 8192, C: 8388608
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
Benchmark Time CPU Iterations FLOPS
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
BM_tensor_Relu/size_a:8388608/size_b:8192/size_c:8388608/config:4/min_warmup_time:0.300_mean 1694290 ns 1685602 ns 10 9.95364G/s
BM_tensor_Relu/size_a:8388608/size_b:8192/size_c:8388608/config:4/min_warmup_time:0.300_median 1693287 ns 1685075 ns 10 9.95636G/s
BM_tensor_Relu/size_a:8388608/size_b:8192/size_c:8388608/config:4/min_warmup_time:0.300_stddev 11637 ns 11124 ns 10 65.7127M/s
BM_tensor_Relu/size_a:8388608/size_b:8192/size_c:8388608/config:4/min_warmup_time:0.300_cv 0.69 % 0.66 % 10 0.66%
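This is the slow setup: a standalone ReLU performs almost no arithmetic per byte of data. Touching 8388608 fp32 elements in roughly 1.69 ms corresponds to about 5 G elements/s, i.e. roughly 20 GB/s of reads plus 20 GB/s of writes (the reported ~10 G/s apparently counts two operations per element), so the kernel is limited by memory bandwidth rather than by compute.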
Main: BRGEMM & Last touch: ReLU
A: 262144, B: 262144, C: 1048576
Performance drops slightly relative to plain BRGEMM because the additional ReLU last touch is memory-bound.
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
Benchmark Time CPU Iterations FLOPS
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
BM_tensor_BRGEMM+RELU/size_a:262144/size_b:262144/size_c:1048576/config:5/min_warmup_time:0.300_mean 4474456 ns 4458350 ns 10 120.42G/s
BM_tensor_BRGEMM+RELU/size_a:262144/size_b:262144/size_c:1048576/config:5/min_warmup_time:0.300_median 4476878 ns 4460413 ns 10 120.364G/s
BM_tensor_BRGEMM+RELU/size_a:262144/size_b:262144/size_c:1048576/config:5/min_warmup_time:0.300_stddev 9309 ns 9001 ns 10 243.248M/s
BM_tensor_BRGEMM+RELU/size_a:262144/size_b:262144/size_c:1048576/config:5/min_warmup_time:0.300_cv 0.21 % 0.20 % 10 0.20%
Main: BRGEMM & Last touch: ReLU
A: 524288, B: 524288, C: 1048576
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
Benchmark Time CPU Iterations FLOPS
-----------------------------------------------------------------------------------------------------------------------------------------------------------------
BM_tensor_BRGEMM+RELU/size_a:524288/size_b:524288/size_c:1048576/config:6/min_warmup_time:0.300_mean 8660603 ns 8629735 ns 10 124.424G/s
BM_tensor_BRGEMM+RELU/size_a:524288/size_b:524288/size_c:1048576/config:6/min_warmup_time:0.300_median 8651362 ns 8620884 ns 10 124.551G/s
BM_tensor_BRGEMM+RELU/size_a:524288/size_b:524288/size_c:1048576/config:6/min_warmup_time:0.300_stddev 15382 ns 15092 ns 10 217.397M/s
BM_tensor_BRGEMM+RELU/size_a:524288/size_b:524288/size_c:1048576/config:6/min_warmup_time:0.300_cv 0.18 % 0.17 % 10 0.17%
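The larger setup is consistent with M = N = 1024 and K = 512, i.e. 2 · M · N · K = 2^30 ≈ 1.07 GFLOP per contraction, which at the measured 8.63 ms again matches the reported ~124 GFLOPS. Doubling K doubles the FLOPs per element of C while the memory-bound ReLU last touch still runs only once per element, so its relative overhead halves and the plain-BRGEMM rate is recovered.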