Submission 2025-06-05

Shared Memory Parallelization

In the shared memory domain, loops can be parallelized at any point within the nested loop structure. However, to simplify the implementation, we only parallelize the outermost loops. In other words, we do not parallelize loops that are nested inside sequential loops.

1. Implement the function execute_iter_parallel

File: TensorOperation.cpp

To enable our tensor operations to be processed in parallel, we now accept shared as an execution type. During setup, we check whether any execution type is shared and set a parallel flag. Additionally, we ensure that no k-dimension is shared.

// Check if shared exists and set parallel flag
for (exec_t exec : exec_types)
{
    if (exec == exec_t::shared)
    {
        isParallel = true;
    }
}

if (isParallel)
{
    // K dimension must not be shared
    int32_t kDimExecType = findMatch(dim_types, exec_types, dim_t::k, exec_t::shared);
    if (kDimExecType != -1)
    {
        hasSetupError = true;
        return error_t::err_k_dimension_must_not_be_shared;
    }
}

Lastly, we check if the execution types are sorted in the correct order: first shared, then sequential, and finally primitive.

bool mini_jit::TensorOperation::isSortedConfiguration(const std::span<const exec_t> &exec)
{
    bool seenSequential = false;
    bool seenPrimitive = false;
    for (exec_t exec_type : exec)
    {
        if (exec_type == exec_t::shared && (seenSequential || seenPrimitive))
        {
            // shared must come before all seq and prim entries
            return false;
        }
        else if (exec_type == exec_t::seq)
        {
            if (seenPrimitive)
            {
                // seq must come before all prim entries
                return false;
            }
            seenSequential = true;
        }
        else if (exec_type == exec_t::prim)
        {
            seenPrimitive = true;
        }
    }

    return true;
}
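The parallel execution itself maps the collapsed shared iteration space onto OpenMP threads. The following is a minimal sketch of execute_iter_parallel under stated assumptions: numSharedDims (the number of leading shared dimensions), dtype_size (bytes per element), and a sequential execute_iter handling the remaining loops are hypothetical names, and the exact signature is illustrative rather than our literal implementation.

void mini_jit::TensorOperation::execute_iter_parallel(char const *ptr_in0, char const *ptr_in1, char *ptr_out)
{
    // collapse all leading shared dimensions into one flat iteration space
    int64_t numParallelIters = 1;
    for (int64_t i = 0; i < numSharedDims; i++)
    {
        numParallelIters *= config.dim_sizes[i];
    }

#pragma omp parallel for
    for (int64_t it = 0; it < numParallelIters; it++)
    {
        // decode the flat index into per-dimension offsets (last shared dimension is fastest)
        int64_t remainder = it;
        int64_t offsetIn0 = 0;
        int64_t offsetIn1 = 0;
        int64_t offsetOut = 0;
        for (int64_t i = numSharedDims - 1; i >= 0; i--)
        {
            int64_t index = remainder % config.dim_sizes[i];
            remainder /= config.dim_sizes[i];
            offsetIn0 += index * config.strides_in0[i];
            offsetIn1 += index * config.strides_in1[i];
            offsetOut += index * config.strides_out[i];
        }

        // hand the remaining sequential and primitive loops to the serial path
        execute_iter(numSharedDims,
                     ptr_in0 + offsetIn0 * dtype_size,
                     ptr_in1 + offsetIn1 * dtype_size,
                     ptr_out + offsetOut * dtype_size);
    }
}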

The serial tensor operations peaked at around \(120\) GFLOPS. Benchmarking with OpenMP and 4 threads now yields around \(420\) GFLOPS, a speedup of roughly \(3.5\times\).

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Benchmark                                                                                                                               Time             CPU   Iterations      FLOPS
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
BM_parallel_tensor_GEMM/size_a:262144/size_b:262144/size_c:1048576/config:7/min_warmup_time:0.300/threads:4_mean                  5201950 ns      1292865 ns           10 415.261G/s
BM_parallel_tensor_GEMM/size_a:262144/size_b:262144/size_c:1048576/config:7/min_warmup_time:0.300/threads:4_median                5193611 ns      1291863 ns           10 415.579G/s
BM_parallel_tensor_GEMM/size_a:262144/size_b:262144/size_c:1048576/config:7/min_warmup_time:0.300/threads:4_stddev                  32185 ns         4344 ns           10 1.39347G/s
BM_parallel_tensor_GEMM/size_a:262144/size_b:262144/size_c:1048576/config:7/min_warmup_time:0.300/threads:4_cv                       0.62 %          0.34 %            10      0.34%
BM_parallel_tensor_BRGEMM/size_a:262144/size_b:262144/size_c:1048576/config:8/min_warmup_time:0.300/threads:4_mean                5195357 ns      1287333 ns           10 417.045G/s
BM_parallel_tensor_BRGEMM/size_a:262144/size_b:262144/size_c:1048576/config:8/min_warmup_time:0.300/threads:4_median              5167859 ns      1287433 ns           10 417.009G/s
BM_parallel_tensor_BRGEMM/size_a:262144/size_b:262144/size_c:1048576/config:8/min_warmup_time:0.300/threads:4_stddev                80687 ns         4319 ns           10 1.39959G/s
BM_parallel_tensor_BRGEMM/size_a:262144/size_b:262144/size_c:1048576/config:8/min_warmup_time:0.300/threads:4_cv                     1.55 %          0.34 %            10      0.34%
BM_parallel_tensor_Zero+BRGEMM+RELU/size_a:262144/size_b:262144/size_c:1048576/config:9/min_warmup_time:0.300/threads:4_mean      5577549 ns      1313489 ns           10 408.757G/s
BM_parallel_tensor_Zero+BRGEMM+RELU/size_a:262144/size_b:262144/size_c:1048576/config:9/min_warmup_time:0.300/threads:4_median    5491313 ns      1310114 ns           10 409.789G/s
BM_parallel_tensor_Zero+BRGEMM+RELU/size_a:262144/size_b:262144/size_c:1048576/config:9/min_warmup_time:0.300/threads:4_stddev     353091 ns         9804 ns           10 3.03171G/s
BM_parallel_tensor_Zero+BRGEMM+RELU/size_a:262144/size_b:262144/size_c:1048576/config:9/min_warmup_time:0.300/threads:4_cv           6.33 %          0.75 %            10      0.74%
BM_parallel_tensor_Zero+BRGEMM/size_a:262144/size_b:262144/size_c:1048576/config:10/min_warmup_time:0.300/threads:4_mean          5336436 ns      1295288 ns           10 414.481G/s
BM_parallel_tensor_Zero+BRGEMM/size_a:262144/size_b:262144/size_c:1048576/config:10/min_warmup_time:0.300/threads:4_median        5306927 ns      1295453 ns           10 414.427G/s
BM_parallel_tensor_Zero+BRGEMM/size_a:262144/size_b:262144/size_c:1048576/config:10/min_warmup_time:0.300/threads:4_stddev          95431 ns         1975 ns           10  632.06M/s
BM_parallel_tensor_Zero+BRGEMM/size_a:262144/size_b:262144/size_c:1048576/config:10/min_warmup_time:0.300/threads:4_cv               1.79 %          0.15 %            10      0.15%
BM_parallel_tensor_Relu/size_a:8388608/size_b:8192/size_c:8388608/config:11/min_warmup_time:0.300/threads:4_mean                  2954501 ns       735408 ns           10 22.8172G/s
BM_parallel_tensor_Relu/size_a:8388608/size_b:8192/size_c:8388608/config:11/min_warmup_time:0.300/threads:4_median                2947921 ns       735807 ns           10 22.8011G/s
BM_parallel_tensor_Relu/size_a:8388608/size_b:8192/size_c:8388608/config:11/min_warmup_time:0.300/threads:4_stddev                  55255 ns         9959 ns           10 307.823M/s
BM_parallel_tensor_Relu/size_a:8388608/size_b:8192/size_c:8388608/config:11/min_warmup_time:0.300/threads:4_cv                       1.87 %          1.35 %            10      1.35%
BM_parallel_tensor_BRGEMM+RELU/size_a:262144/size_b:262144/size_c:1048576/config:12/min_warmup_time:0.300/threads:4_mean          5243909 ns      1301545 ns           10 412.507G/s
BM_parallel_tensor_BRGEMM+RELU/size_a:262144/size_b:262144/size_c:1048576/config:12/min_warmup_time:0.300/threads:4_median        5239656 ns      1299425 ns           10 413.161G/s
BM_parallel_tensor_BRGEMM+RELU/size_a:262144/size_b:262144/size_c:1048576/config:12/min_warmup_time:0.300/threads:4_stddev          35856 ns         9430 ns           10 2.98182G/s
BM_parallel_tensor_BRGEMM+RELU/size_a:262144/size_b:262144/size_c:1048576/config:12/min_warmup_time:0.300/threads:4_cv               0.68 %          0.72 %            10      0.72%
BM_parallel_tensor_BRGEMM+RELU/size_a:524288/size_b:524288/size_c:1048576/config:13/min_warmup_time:0.300/threads:4_mean         10136019 ns      2524142 ns           10 425.392G/s
BM_parallel_tensor_BRGEMM+RELU/size_a:524288/size_b:524288/size_c:1048576/config:13/min_warmup_time:0.300/threads:4_median       10143290 ns      2523724 ns           10 425.459G/s
BM_parallel_tensor_BRGEMM+RELU/size_a:524288/size_b:524288/size_c:1048576/config:13/min_warmup_time:0.300/threads:4_stddev          59898 ns         7583 ns           10 1.27538G/s
BM_parallel_tensor_BRGEMM+RELU/size_a:524288/size_b:524288/size_c:1048576/config:13/min_warmup_time:0.300/threads:4_cv               0.59 %          0.30 %            10      0.30%

Optimization Passes

1. IR that supports transformations

We created a struct TensorConfig in TensorConfig.h to support transformations and optimization passes on our tensor operation. This configuration contains all the input data for our tensor operation. Before handing the configuration over to our tensor operation setup, we run our optimization passes over it. We also added equals(const TensorConfig &config1, const TensorConfig &config2) and to_string() methods for testing purposes.
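For reference, the struct looks roughly as follows; the exact field types and enum members are assumptions inferred from the initializers in the test case at the end of this section (requires <cstdint>, <string>, <vector>).

namespace mini_jit
{
    struct TensorConfig
    {
        // enum members beyond those used in this document are assumptions
        enum class prim_t { none, zero, relu, gemm, brgemm };
        enum class dim_t { m, n, k, c };
        enum class exec_t { seq, shared, prim };
        enum class dtype_t { fp32, fp64 };

        prim_t first_touch;                // optional first-touch primitive
        prim_t main;                       // main primitive (e.g. GEMM, BRGEMM)
        prim_t last_touch;                 // optional last-touch primitive
        std::vector<dim_t> dim_types;      // type of each loop dimension
        std::vector<exec_t> exec_types;    // shared / seq / prim per dimension
        std::vector<int64_t> dim_sizes;    // extent of each dimension
        std::vector<int64_t> strides_in0;  // strides of the first input
        std::vector<int64_t> strides_in1;  // strides of the second input
        std::vector<int64_t> strides_out;  // strides of the output
        dtype_t dtype;                     // element data type

        static bool equals(const TensorConfig &config1, const TensorConfig &config2);
        std::string to_string() const;
    };
}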

2. Implement optimization passes

Dimension Reordering Fusing

We added dimension reordering to our optimization passes to improve dimension fusion. The idea is to move a dimension X directly next to a dimension Y if both have the same dimension type and satisfy \(\text{stride}(X) = |Y| \cdot \text{stride}(Y)\), where \(|Y|\) is the size of Y.

void mini_jit::TensorOptimization::_dimension_reordering_fusing(TensorConfig &config)
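To make the condition concrete, a check along the following lines could decide whether dimension X (index x) may be moved in front of dimension Y (index y). The helper name is hypothetical, and we assume the stride condition must hold for both inputs and the output.

bool can_reorder_for_fusing(const TensorConfig &config, std::size_t x, std::size_t y)
{
    // same dimension type, and X's stride spans exactly |Y| iterations of Y
    // in every tensor
    return config.dim_types[x] == config.dim_types[y] &&
           config.strides_in0[x] == config.dim_sizes[y] * config.strides_in0[y] &&
           config.strides_in1[x] == config.dim_sizes[y] * config.strides_in1[y] &&
           config.strides_out[x] == config.dim_sizes[y] * config.strides_out[y];
}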

Dimension Splitting

We added dimension splitting to our optimization passes. The idea is to check whether any dimension size is greater than or equal to 256. If so, we try to split the dimension in two: starting at the floor of the square root of the dimension size, we test whether the candidate divides the size evenly, decrementing the candidate until 2. If a divisor is found, the dimension is split.

void mini_jit::TensorOptimization::_dimension_splitting(TensorConfig &config)
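A sketch of the divisor search for a single dimension i, assuming the vector-based TensorConfig fields sketched earlier (the helper name is hypothetical, and how the two new dimensions are placed is a design choice; requires <cmath> for std::sqrt):

void split_dimension(mini_jit::TensorConfig &config, std::size_t i)
{
    int64_t size = config.dim_sizes[i];
    if (size < 256)
    {
        return;  // only dimensions >= 256 are split
    }

    // search for a divisor, starting at floor(sqrt(size)) and counting down to 2
    for (int64_t candidate = static_cast<int64_t>(std::sqrt(size)); candidate >= 2; candidate--)
    {
        if (size % candidate != 0)
        {
            continue;
        }

        // shrink the existing (inner) dimension to `candidate` iterations ...
        config.dim_sizes[i] = candidate;

        // ... and insert a new outer dimension with size / candidate iterations;
        // the inner dimension keeps the original strides, the outer one gets
        // them scaled by the divisor
        config.dim_types.insert(config.dim_types.begin() + i, config.dim_types[i]);
        config.exec_types.insert(config.exec_types.begin() + i, config.exec_types[i]);
        config.dim_sizes.insert(config.dim_sizes.begin() + i, size / candidate);
        config.strides_in0.insert(config.strides_in0.begin() + i, config.strides_in0[i] * candidate);
        config.strides_in1.insert(config.strides_in1.begin() + i, config.strides_in1[i] * candidate);
        config.strides_out.insert(config.strides_out.begin() + i, config.strides_out[i] * candidate);
        return;
    }
}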

Dimension Fusing

We added dimension fusion to our optimization passes. The idea is to check whether two neighboring dimensions have the same dimension type and whether the product of their sizes is less than or equal to 256. We also check that \(\text{stride}(X) = |Y| \cdot \text{stride}(Y)\) holds. If so, we fuse the two dimensions.

void mini_jit::TensorOptimization::_dimension_fusing(TensorConfig &config)
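Assuming the checks above have passed for the neighboring dimensions i (outer, X) and i + 1 (inner, Y), the fusion itself reduces to growing the inner dimension and erasing the outer one; the helper name below is hypothetical.

void fuse_dimensions(mini_jit::TensorConfig &config, std::size_t i)
{
    // the fused dimension iterates |X| * |Y| times with Y's (smaller) strides
    config.dim_sizes[i + 1] *= config.dim_sizes[i];

    // drop the outer dimension X
    config.dim_types.erase(config.dim_types.begin() + i);
    config.exec_types.erase(config.exec_types.begin() + i);
    config.dim_sizes.erase(config.dim_sizes.begin() + i);
    config.strides_in0.erase(config.strides_in0.begin() + i);
    config.strides_in1.erase(config.strides_in1.begin() + i);
    config.strides_out.erase(config.strides_out.begin() + i);
}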

Dimension Reordering Shared

We added dimension reordering to our optimization passes for better shared identification. We reorder sequential loops with other sequential loops and shared loops with other shared loops. We sort by a stride-based key but penalize k-dimensions and repeating dimensions: for each dimension we sum the squared strides of both inputs and the output, divide the value by eight for a k-dimension, and divide it by two for a dimension that repeats the previous dimension type, excluding the c-dimension.

void mini_jit::TensorOptimization::_dimension_reordering_shared(TensorConfig &config)
{
    ...
    uint64_t value = (*jStrideIn0 * *jStrideIn0) + (*jStrideIn1 * *jStrideIn1) + (*jStrideOut * *jStrideOut);

    // value / 8 if we have a k-dimension
    value >>= (*jDim == TensorConfig::dim_t::k) * 3;

    // value / 2 if we have the same dimension type as the previous dimension, but not for the c-dimension
    value >>= (*jDim == previous_dim && *jDim != TensorConfig::dim_t::c) * 1;
    ...
}

Primitive Identification

We added primitive identification support to our optimization passes. The following rules are applied based on the dimension type:

- m-dimension: search for an m-dimension with unit stride in the first input
- n-dimension: search the second input and the output for the smallest stride
- k-dimension: only applies to GEMM and BRGEMM; search for unit stride in the second input
- second k-dimension: only applies to BRGEMM; search for the smallest stride in the first or second input, excluding the already selected k-dimension

Additionally, we do not modify any primitives that the user has already chosen.

void mini_jit::TensorOptimization::_primitive_identification(TensorConfig &config)
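As an illustration, the m-dimension rule alone could look like the following fragment; the surrounding pass and the skip logic for user-chosen primitives are omitted.

// search for an m-dimension with unit stride in the first input and
// mark it as a primitive dimension
for (std::size_t i = 0; i < config.dim_types.size(); i++)
{
    if (config.dim_types[i] == TensorConfig::dim_t::m && config.strides_in0[i] == 1)
    {
        config.exec_types[i] = TensorConfig::exec_t::prim;
        break;
    }
}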

Shared Identification

We added shared identification support to our optimization passes. Dimensions can be converted to shared at most up to the first primitive or the first sequential k-dimension. We also only tag as many dimensions as needed: if the first dimension is already evenly divisible by the number of OpenMP threads, we do not convert any further dimensions to shared. Additionally, we only convert a dimension if the resulting load imbalance stays below 1%, i.e. \((\text{shared dimensions size} \bmod \text{thread count}) / \text{shared dimensions size} < 1\%\).

void mini_jit::TensorOptimization::_shared_identification(TensorConfig &config)
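A sketch of the core loop follows, written as a free-standing fragment with hypothetical names; whether a too unbalanced dimension stops the search (as below) or is merely skipped is an implementation detail (requires <omp.h> for omp_get_max_threads).

int64_t threadCount = omp_get_max_threads();
int64_t sharedSize = 1;
for (std::size_t i = 0; i < config.dim_types.size(); i++)
{
    // convert at most up to the first primitive or sequential k-dimension
    if (config.exec_types[i] == TensorConfig::exec_t::prim ||
        config.dim_types[i] == TensorConfig::dim_t::k)
    {
        break;
    }

    sharedSize *= config.dim_sizes[i];

    // only convert if the load imbalance stays below 1%
    double imbalance = static_cast<double>(sharedSize % threadCount) / static_cast<double>(sharedSize);
    if (imbalance >= 0.01)
    {
        break;
    }

    config.exec_types[i] = TensorConfig::exec_t::shared;

    if (sharedSize % threadCount == 0)
    {
        break;  // already evenly divisible: no further dimensions needed
    }
}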

3. Lower the optimized IR code to your tensor operation backend

Since our IR is the struct TensorConfig, we only need to pass the configuration first through our optimization and then to our tensor operation setup. This order ensures that the optimizer hands a valid configuration to the tensor operation.

mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup(const TensorConfig &config)
{
    mini_jit::TensorOptimization optimization;
    TensorOperation::config = optimization.optimize(config);

    return setup_no_optimization(TensorOperation::config.dtype, TensorOperation::config.first_touch, TensorOperation::config.main,
                                 TensorOperation::config.last_touch, TensorOperation::config.dim_types, TensorOperation::config.exec_types,
                                 TensorOperation::config.dim_sizes, TensorOperation::config.strides_in0, TensorOperation::config.strides_in1,
                                 TensorOperation::config.strides_out);
}

Our TensorOptimization's optimize method executes the individual optimization passes on the config struct.

mini_jit::TensorConfig mini_jit::TensorOptimization::optimize(TensorConfig config)
{
    _dimension_reordering_fusing(config);

    _dimension_splitting(config);

    _dimension_fusing(config);

    _primitive_identification(config);

    _dimension_reordering_shared(config);

    // Only identify shared dimensions after reordering: we parallelize only the
    // leading loops, at most up to the first sequential k-loop
    _shared_identification(config);

    return config;
}

4. Benchmark the performance of your implementation

File: TensorOptimization.bench.cpp

Matrix multiplication example

-------------------------------------------------------------------------------------------------------------------------------------------------------------------
Benchmark                                                                                                              Time             CPU   Iterations      FLOPS
-------------------------------------------------------------------------------------------------------------------------------------------------------------------
BM_optimized_tensor_GEMM/size_a:2560000/size_b:2560000/size_c:2560000/config:0/min_warmup_time:0.300_mean        1316172 ns      1303763 ns           10 411.786G/s
BM_optimized_tensor_GEMM/size_a:2560000/size_b:2560000/size_c:2560000/config:0/min_warmup_time:0.300_median      1313935 ns      1303515 ns           10 411.864G/s
BM_optimized_tensor_GEMM/size_a:2560000/size_b:2560000/size_c:2560000/config:0/min_warmup_time:0.300_stddev         7770 ns         1120 ns           10   353.7M/s
BM_optimized_tensor_GEMM/size_a:2560000/size_b:2560000/size_c:2560000/config:0/min_warmup_time:0.300_cv             0.59 %          0.09 %            10      0.09%

Tensor contraction example

-------------------------------------------------------------------------------------------------------------------------------------------------------------------
Benchmark                                                                                                              Time             CPU   Iterations      FLOPS
-------------------------------------------------------------------------------------------------------------------------------------------------------------------
BM_optimized_tensor_BRGEMM/size_a:2560000/size_b:2560000/size_c:2560000/config:1/min_warmup_time:0.300_mean      1310327 ns      1295379 ns           10 414.451G/s
BM_optimized_tensor_BRGEMM/size_a:2560000/size_b:2560000/size_c:2560000/config:1/min_warmup_time:0.300_median    1307359 ns      1295362 ns           10 414.456G/s
BM_optimized_tensor_BRGEMM/size_a:2560000/size_b:2560000/size_c:2560000/config:1/min_warmup_time:0.300_stddev       8579 ns         1229 ns           10 393.184M/s
BM_optimized_tensor_BRGEMM/size_a:2560000/size_b:2560000/size_c:2560000/config:1/min_warmup_time:0.300_cv           0.65 %          0.09 %            10      0.09%

5. Demonstrate the capabilities of your optimization passes

We tested our optimization passes in TensorOptimization.test.cpp. One exhaustive test case is shown below. It involves dimension reordering, fusing, primitive identification, and shared identification. In addition to checking the tensor configuration after the optimization passes, we also verify the correctness of the executed tensor operation.

TEST_CASE("Test tensor operation with optimization dimension test reordering and fusing", "[tensor_optimization][gemm][correctness]")
{
    using namespace mini_jit;

    mini_jit::TensorConfig config{
        mini_jit::TensorConfig::prim_t::none,  // first_touch
        mini_jit::TensorConfig::prim_t::gemm,  // main
        mini_jit::TensorConfig::prim_t::none,  // last_touch
        {mini_jit::TensorConfig::dim_t::n, mini_jit::TensorConfig::dim_t::k, mini_jit::TensorConfig::dim_t::m,
         mini_jit::TensorConfig::dim_t::n, mini_jit::TensorConfig::dim_t::n, mini_jit::TensorConfig::dim_t::k},  // dim_types
        {mini_jit::TensorConfig::exec_t::seq, mini_jit::TensorConfig::exec_t::seq, mini_jit::TensorConfig::exec_t::seq,
         mini_jit::TensorConfig::exec_t::seq, mini_jit::TensorConfig::exec_t::seq, mini_jit::TensorConfig::exec_t::seq},  // exec_types
        {32, 8, 32, 5, 32, 32},                 // dim_sizes
        {0, 1024, 1, 0, 0, 32},                 // strides_in0
        {8192, 1024, 0, 8192 * 32, 32, 1},      // strides_in1
        {1024, 0, 1, 32768, 32, 0},             // strides_out
        mini_jit::TensorConfig::dtype_t::fp32,  // dtype
    };

    mini_jit::TensorConfig expected{
        mini_jit::TensorConfig::prim_t::none,  // first_touch
        mini_jit::TensorConfig::prim_t::gemm,  // main
        mini_jit::TensorConfig::prim_t::none,  // last_touch
        {mini_jit::TensorConfig::dim_t::n, mini_jit::TensorConfig::dim_t::k, mini_jit::TensorConfig::dim_t::m,
         mini_jit::TensorConfig::dim_t::n, mini_jit::TensorConfig::dim_t::k},  // dim_types
        {mini_jit::TensorConfig::exec_t::shared, mini_jit::TensorConfig::exec_t::seq, mini_jit::TensorConfig::exec_t::prim,
         mini_jit::TensorConfig::exec_t::prim, mini_jit::TensorConfig::exec_t::prim},  // exec_types
        {5 * 32, 8, 32, 32, 32},                // dim_sizes
        {0, 1024, 1, 0, 32},                    // strides_in0
        {8192, 1024, 0, 32, 1},                 // strides_in1
        {1024, 0, 1, 32, 0},                    // strides_out
        mini_jit::TensorConfig::dtype_t::fp32,  // dtype
    };

    mini_jit::TensorOperation tensor_op;
    TensorOperation::error_t err = tensor_op.setup(config);

    INFO(tensor_op.get_config().to_string());

    REQUIRE(err == TensorOperation::error_t::success);
    REQUIRE_FALSE(mini_jit::TensorConfig::equals(config, tensor_op.get_config()));
    REQUIRE(mini_jit::TensorConfig::equals(expected, tensor_op.get_config()));

    GenerationTest test(32, 32, 32, 32 * 1 * 32 * 8 * 1 * 1, 32 * 32 * 1 * 8 * 32 * 5, 1 * 32 * 32 * 1 * 32 * 5);
    test.SetUp(TestInfill::Random);

    tensor_op.execute(test.matrix_a.data(), test.matrix_b.data(), test.matrix_c.data());

    for (int64_t i0 = 0; i0 < expected.dim_sizes[0]; i0++)
    {
        for (int64_t i1 = 0; i1 < expected.dim_sizes[1]; i1++)
        {
            uint64_t offset_a = i0 * expected.strides_in0[0] + i1 * expected.strides_in0[1];
            uint64_t offset_b = i0 * expected.strides_in1[0] + i1 * expected.strides_in1[1];
            uint64_t offset_c = i0 * expected.strides_out[0] + i1 * expected.strides_out[1];
            test.naive_matmul_M_N_K_Batch(test.matrix_a.data() + offset_a, test.matrix_b.data() + offset_b,
                                          test.matrix_c_verify.data() + offset_c, 32, 32, 32, 32 * 32, 32 * 32);
        }
    }

    test.verify_matmul(test.matrix_c_verify.data(), test.matrix_c.data(), test.matrix_c.size());
}