Submission 2025-06-05
Optimization Passes
1. IR that supports transformations
We created a struct TensorConfig in TensorConfig.h to support transformations and optimization passes on our tensor operation. This configuration contains all the input data for our tensor operation. Before handing this configuration over to our tensor operation setup, we run our optimization passes over it. We also added an equals(const TensorConfig &config1, const TensorConfig &config2) and a to_string() method for testing purposes.
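For reference, a minimal sketch of the struct could look like the following (the exact field types and some enum members, e.g. the additional first/last-touch primitives and data types, are assumptions; the actual definition is in TensorConfig.h):
#include <cstdint>
#include <string>
#include <vector>

namespace mini_jit
{
    // Minimal sketch of the IR struct; the actual definition lives in TensorConfig.h.
    struct TensorConfig
    {
        enum class prim_t { none, gemm, brgemm /* further first/last-touch primitives assumed */ };
        enum class dim_t { m, n, k, c };
        enum class exec_t { seq, shared, prim };
        enum class dtype_t { fp32 /* further data types assumed */ };

        prim_t first_touch; // primitive applied on the first touch of an output block
        prim_t main;        // main primitive, e.g. GEMM or BRGEMM
        prim_t last_touch;  // primitive applied on the last touch of an output block

        std::vector<dim_t> dim_types;     // type of each loop dimension
        std::vector<exec_t> exec_types;   // seq / shared / prim for each dimension
        std::vector<int64_t> dim_sizes;   // size of each dimension
        std::vector<int64_t> strides_in0; // strides of the first input
        std::vector<int64_t> strides_in1; // strides of the second input
        std::vector<int64_t> strides_out; // strides of the output

        dtype_t dtype; // element data type

        static bool equals(const TensorConfig &config1, const TensorConfig &config2);
        std::string to_string() const;
    };
}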
2. Implement optimization passes
Dimension Reordering Fusing
We added dimension reordering to our optimization passes to improve dimension fusion. The idea is to move a dimension X directly next to a dimension Y if both have the same type and the condition Stride(X) = |Y| * Stride(Y) is met, where |Y| denotes the size of dimension Y.
void mini_jit::TensorOptimization::_dimension_reordering_fusing(TensorConfig &config)
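A minimal sketch of the adjacency check this pass searches for could look like this (the helper name is hypothetical, and checking the condition on all three stride arrays is our simplification):
#include <cstddef>
#include "TensorConfig.h"

// Hypothetical helper: can dimension x be moved directly next to dimension y
// so that a later fusing pass may merge them?
static bool can_fuse_if_adjacent(const mini_jit::TensorConfig &config, std::size_t x, std::size_t y)
{
    return config.dim_types[x] == config.dim_types[y] &&
           config.strides_in0[x] == config.dim_sizes[y] * config.strides_in0[y] &&
           config.strides_in1[x] == config.dim_sizes[y] * config.strides_in1[y] &&
           config.strides_out[x] == config.dim_sizes[y] * config.strides_out[y];
}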
Dimension Splitting
We added dimension splitting to our optimization passes. The idea is to check whether any dimension size is greater than or equal to 256. If so, we try to split the dimension into two: starting at the floor of the square root of the dimension size, we check whether the candidate divides the size evenly; otherwise we decrement the candidate and test again, down to 2. If a divisor is found, the dimension is split.
void mini_jit::TensorOptimization::_dimension_splitting(TensorConfig &config)
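A sketch of the divisor search described above could look like this (the helper is hypothetical and only illustrates the search order):
#include <cmath>
#include <cstdint>

// Hypothetical helper: find a divisor of a dimension size >= 256, starting at
// floor(sqrt(size)) and counting down to 2. Returns 0 if no such divisor exists.
static int64_t find_split_divisor(int64_t size)
{
    int64_t candidate = static_cast<int64_t>(std::floor(std::sqrt(static_cast<double>(size))));
    for (; candidate >= 2; --candidate)
    {
        if (size % candidate == 0)
        {
            return candidate; // split the dimension into candidate x (size / candidate)
        }
    }
    return 0; // no divisor found, the dimension stays unsplit
}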
Dimension Fusing
We added dimension fusion to our optimization passes. The idea is to check whether two neighboring dimensions X and Y have the same dimension type and whether the product of both dimension sizes is less than or equal to 256. We also check that the condition Stride(X) = |Y| * Stride(Y) holds. If so, we fuse the two dimensions.
void mini_jit::TensorOptimization::_dimension_fusing(TensorConfig &config)
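The fusing step itself could be sketched as follows (hypothetical helper; only a single stride array is shown, whereas the actual pass has to update all three stride arrays consistently):
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: fuse dimension x with its inner neighbour y = x + 1.
// Preconditions (checked by the pass): same dim_t, |X| * |Y| <= 256 and
// Stride(X) == |Y| * Stride(Y).
static void fuse_with_next(std::vector<int64_t> &dim_sizes, std::vector<int64_t> &strides, std::size_t x)
{
    const std::size_t y = x + 1;
    dim_sizes[x] *= dim_sizes[y]; // fused size |X| * |Y|
    strides[x] = strides[y];      // the fused dimension keeps the inner stride
    dim_sizes.erase(dim_sizes.begin() + static_cast<std::ptrdiff_t>(y));
    strides.erase(strides.begin() + static_cast<std::ptrdiff_t>(y));
}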
Dimension Reordering Shared
We added dimension reordering to our optimization passes for better shared identification. We reorder sequential loops only among other sequential loops and shared loops only among other shared loops. We sort by a stride-based key but penalize k-dimensions and repeated dimension types: as in the snippet below, the key is the sum of the squared strides, divided by eight for a k-dimension and divided by two if the dimension type repeats the previous one, excluding the c-dimension.
void mini_jit::TensorOptimization::_dimension_reordering_shared(TensorConfig &config)
{
    ...
    uint64_t value = (*jStrideIn0 * *jStrideIn0) + (*jStrideIn1 * *jStrideIn1) + (*jStrideOut * *jStrideOut);
    // value/8 if we have a k-dimension
    value >>= (*jDim == TensorConfig::dim_t::k) * 3;
    // value/2 if we have the same dimension type as the last dimension, but not for c dimension
    value >>= (*jDim == previous_dim && *jDim != TensorConfig::dim_t::c) * 1;
    ...
}
Primitive Identification
We added primitive identification support to our optimization pass. The following rules are applied based on the dimension type (the m-rule is sketched after the function signature below):
- m-dimension: search for an m-dimension with unit stride in the first input
- n-dimension: search the second input and the output for the smallest stride
- k-dimension: only applies to GEMM and BRGEMM; search for unit stride in the second input
- second k-dimension: only applies to BRGEMM; search for the smallest stride in the first or second input, excluding the already chosen k-dimension
Additionally, we do not override any primitives the user has already chosen.
void mini_jit::TensorOptimization::_primitive_identification(TensorConfig &config)
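As an illustration, the m-rule alone could be sketched like this (hypothetical helper; the actual pass implements all four rules and skips dimensions the user already marked as primitives):
#include <cstddef>
#include "TensorConfig.h"

// Hypothetical helper: pick the m-dimension with unit stride in the first input
// as the primitive m-dimension. Returns its index, or -1 if no such dimension exists.
static std::ptrdiff_t find_primitive_m(const mini_jit::TensorConfig &config)
{
    for (std::size_t i = 0; i < config.dim_types.size(); ++i)
    {
        if (config.dim_types[i] == mini_jit::TensorConfig::dim_t::m && config.strides_in0[i] == 1)
        {
            return static_cast<std::ptrdiff_t>(i);
        }
    }
    return -1;
}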
Shared Identification
We added shared identification support to our optimization pass. We convert dimensions to shared only up to the first primitive dimension or the first sequential k-dimension. We also tag only as many dimensions as needed: if the first dimension is already evenly divisible by the number of OpenMP threads in use, we do not convert any further dimensions to shared. Additionally, we only convert to shared if the imbalance ratio of the shared dimensions is below 1%:
(shared_dimensions_size % thread_count) / shared_dimensions_size < 1%
void mini_jit::TensorOptimization::_shared_identification(TensorConfig &config)
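The balance check could be sketched as follows (the helper name is hypothetical; we assume the thread count comes from OpenMP's omp_get_max_threads()):
#include <cstdint>
#include <omp.h>

// Hypothetical helper: only mark loops as shared if distributing
// shared_dimensions_size iterations over the OpenMP threads leaves a
// remainder of less than 1% of the total work.
static bool is_balanced_enough(uint64_t shared_dimensions_size)
{
    const uint64_t thread_count = static_cast<uint64_t>(omp_get_max_threads());
    const uint64_t remainder = shared_dimensions_size % thread_count;
    return remainder * 100 < shared_dimensions_size; // (remainder / size) < 1%
}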
3. Lower the optimized IR code to your tensor operation backend
Since our IR is the struct TensorConfig, we only need to provide the configuration to our optimization and then to our tensor operation setup. This order ensures that the optimizer creates a valid configuration for the tensor operation.
mini_jit::TensorOperation::error_t mini_jit::TensorOperation::setup(const TensorConfig &config)
{
    mini_jit::TensorOptimization optimization;
    TensorOperation::config = optimization.optimize(config);
    return setup_no_optimization(TensorOperation::config.dtype, TensorOperation::config.first_touch, TensorOperation::config.main,
                                 TensorOperation::config.last_touch, TensorOperation::config.dim_types, TensorOperation::config.exec_types,
                                 TensorOperation::config.dim_sizes, TensorOperation::config.strides_in0, TensorOperation::config.strides_in1,
                                 TensorOperation::config.strides_out);
}
Our TensorOptimization's optimize method executes the individual optimization passes on the config struct.
mini_jit::TensorConfig mini_jit::TensorOptimization::optimize(TensorConfig config)
{
    _dimension_reordering_fusing(config);
    _dimension_splitting(config);
    _dimension_fusing(config);
    _primitive_identification(config);
    _dimension_reordering_shared(config);
    // Only call shared identification after reordering: it parallelizes only the leading loops, at most up to the first sequential k-loop.
    _shared_identification(config);
    return config;
}
4. Benchmark the performance of your implementation
File: TensorOptimization.bench.cpp
Matrix multiplication example
-------------------------------------------------------------------------------------------------------------------------------------------------------------------
Benchmark Time CPU Iterations FLOPS
-------------------------------------------------------------------------------------------------------------------------------------------------------------------
BM_optimized_tensor_GEMM/size_a:2560000/size_b:2560000/size_c:2560000/config:0/min_warmup_time:0.300_mean 1316172 ns 1303763 ns 10 411.786G/s
BM_optimized_tensor_GEMM/size_a:2560000/size_b:2560000/size_c:2560000/config:0/min_warmup_time:0.300_median 1313935 ns 1303515 ns 10 411.864G/s
BM_optimized_tensor_GEMM/size_a:2560000/size_b:2560000/size_c:2560000/config:0/min_warmup_time:0.300_stddev 7770 ns 1120 ns 10 353.7M/s
BM_optimized_tensor_GEMM/size_a:2560000/size_b:2560000/size_c:2560000/config:0/min_warmup_time:0.300_cv 0.59 % 0.09 % 10 0.09%
Tensor contraction example
-------------------------------------------------------------------------------------------------------------------------------------------------------------------
Benchmark Time CPU Iterations FLOPS
-------------------------------------------------------------------------------------------------------------------------------------------------------------------
BM_optimized_tensor_BRGEMM/size_a:2560000/size_b:2560000/size_c:2560000/config:1/min_warmup_time:0.300_mean 1310327 ns 1295379 ns 10 414.451G/s
BM_optimized_tensor_BRGEMM/size_a:2560000/size_b:2560000/size_c:2560000/config:1/min_warmup_time:0.300_median 1307359 ns 1295362 ns 10 414.456G/s
BM_optimized_tensor_BRGEMM/size_a:2560000/size_b:2560000/size_c:2560000/config:1/min_warmup_time:0.300_stddev 8579 ns 1229 ns 10 393.184M/s
BM_optimized_tensor_BRGEMM/size_a:2560000/size_b:2560000/size_c:2560000/config:1/min_warmup_time:0.300_cv 0.65 % 0.09 % 10 0.09%
5. Demonstrate the capabilities of your optimization passes
We tested our optimization passes in TensorOptimization.test.cpp. One exhaustive test case is shown below. This optimization involves dimension reordering, fusing, primitive identification, and shared identification. In addition to testing the correctness of the tensor configuration after the optimization passes, we also test the correctness of the tensor operation.
TEST_CASE("Test tensor operation with optimization dimension test reordering and fusing", "[tensor_optimization][gemm][correctness]")
{
    using namespace mini_jit;
    mini_jit::TensorConfig config{
        mini_jit::TensorConfig::prim_t::none, // first_touch
        mini_jit::TensorConfig::prim_t::gemm, // main
        mini_jit::TensorConfig::prim_t::none, // last touch
        {mini_jit::TensorConfig::dim_t::n, mini_jit::TensorConfig::dim_t::k, mini_jit::TensorConfig::dim_t::m, mini_jit::TensorConfig::dim_t::n,
         mini_jit::TensorConfig::dim_t::n, mini_jit::TensorConfig::dim_t::k}, // dim_types
        {mini_jit::TensorConfig::exec_t::seq, mini_jit::TensorConfig::exec_t::seq, mini_jit::TensorConfig::exec_t::seq,
         mini_jit::TensorConfig::exec_t::seq, mini_jit::TensorConfig::exec_t::seq, mini_jit::TensorConfig::exec_t::seq}, // exec_types
        {32, 8, 32, 5, 32, 32},                // dim_sizes
        {0, 1024, 1, 0, 0, 32},                // strides_in0
        {8192, 1024, 0, 8192 * 32, 32, 1},     // strides_in1
        {1024, 0, 1, 32768, 32, 0},            // strides_out
        mini_jit::TensorConfig::dtype_t::fp32, // dtype_t
    };
    mini_jit::TensorConfig expected{
        mini_jit::TensorConfig::prim_t::none, // first_touch
        mini_jit::TensorConfig::prim_t::gemm, // main
        mini_jit::TensorConfig::prim_t::none, // last touch
        {mini_jit::TensorConfig::dim_t::n, mini_jit::TensorConfig::dim_t::k, mini_jit::TensorConfig::dim_t::m, mini_jit::TensorConfig::dim_t::n,
         mini_jit::TensorConfig::dim_t::k}, // dim_types
        {mini_jit::TensorConfig::exec_t::shared, mini_jit::TensorConfig::exec_t::seq, mini_jit::TensorConfig::exec_t::prim,
         mini_jit::TensorConfig::exec_t::prim, mini_jit::TensorConfig::exec_t::prim}, // exec_types
        {5 * 32, 8, 32, 32, 32},               // dim_sizes
        {0, 1024, 1, 0, 32},                   // strides_in0
        {8192, 1024, 0, 32, 1},                // strides_in1
        {1024, 0, 1, 32, 0},                   // strides_out
        mini_jit::TensorConfig::dtype_t::fp32, // dtype_t
    };
    mini_jit::TensorOperation tensor_op;
    TensorOperation::error_t err = tensor_op.setup(config);
    INFO(tensor_op.get_config().to_string());
    REQUIRE(err == TensorOperation::error_t::success);
    REQUIRE_FALSE(mini_jit::TensorConfig::equals(config, tensor_op.get_config()));
    REQUIRE(mini_jit::TensorConfig::equals(expected, tensor_op.get_config()));
    GenerationTest test(32, 32, 32, 32 * 1 * 32 * 8 * 1 * 1, 32 * 32 * 1 * 8 * 32 * 5, 1 * 32 * 32 * 1 * 32 * 5);
    test.SetUp(TestInfill::Random);
    tensor_op.execute(test.matrix_a.data(), test.matrix_b.data(), test.matrix_c.data());
    for (int64_t i0 = 0; i0 < expected.dim_sizes[0]; i0++)
    {
        for (int64_t i1 = 0; i1 < expected.dim_sizes[1]; i1++)
        {
            uint64_t offset_a = i0 * expected.strides_in0[0] + i1 * expected.strides_in0[1];
            uint64_t offset_b = i0 * expected.strides_in1[0] + i1 * expected.strides_in1[1];
            uint64_t offset_c = i0 * expected.strides_out[0] + i1 * expected.strides_out[1];
            test.naive_matmul_M_N_K_Batch(test.matrix_a.data() + offset_a, test.matrix_b.data() + offset_b,
                                          test.matrix_c_verify.data() + offset_c, 32, 32, 32, 32 * 32, 32 * 32);
        }
    }
    test.verify_matmul(test.matrix_c_verify.data(), test.matrix_c.data(), test.matrix_c.size());
}