Submission 2025-06-12¶
Unary Operation¶
Support for non-transposed unary operations was already added in the last submission, so we only needed to add the transpose operation.
We added transpose support to the TensorConfig parsing in TensorOperation.cpp and validated it with additional tests in TensorOperation.test.cpp.
bool mini_jit::TensorOperation::isValidPrimStrides(const std::span<const TensorConfig::dim_t> &dim,
                                                   const std::span<const TensorConfig::exec_t> &exec,
                                                   const std::span<const int64_t> &strides_in0,
                                                   const std::span<const int64_t> &strides_out,
                                                   const TensorConfig::prim_t main_prim)
{
    // ...
    // No transpose: the M dimension has unit stride in both input and output
    if (isExpectedStride(1, indexM, strides_in0) && isExpectedStride(1, indexM, strides_out))
    {
        return true;
    }
    // Check transpose in unary op: input is M-contiguous, output is N-contiguous
    if (isUnary(main_prim) && isExpectedStride(1, indexM, strides_in0) && isExpectedStride(1, indexN, strides_out))
    {
        isTranspose = true;
        return true;
    }
    // ...
}
Einsum Trees - Lowering¶
This section expands the capabilities of our tensor compiler by adding support for einsum trees. Specifically, we execute einsum trees by mapping them to a tree of unary and binary tensor operations. These operations can then be executed by our tensor operation backend.
1. Function that parses string of tree and sorted dimension sizes¶
First, we implemented a struct called EinsumNode to parse the string representation of a tree and the numerically sorted dimension sizes. This structure holds one node of the tree, pointers to its possible children, its dimension ids, and a tensor representing an intermediate or the final (root node) result.
struct EinsumNode
{
    NodeType type;
    float *tensor;

    // Always filled — dims of the output tensor
    std::vector<int64_t> output_dim_ids;

    // Pointers to children
    EinsumNode *left = nullptr;
    EinsumNode *right = nullptr;

    /**
     * Gets a string representation of the einsum tree.
     */
    std::string to_string() const;

    /**
     * Gets the size of the tensor represented by this node.
     *
     * @param dim_sizes A vector of dimension sizes corresponding to the output dimensions.
     */
    int64_t get_size(const std::vector<int64_t> dim_sizes) const;

private:
    /**
     * Recursively formats the node and its children into a string.
     *
     * @param depth       The current depth in the tree, used for indentation.
     * @param connection  A string representing the connection type.
     * @param depthString A string representation of the current depth.
     * @return A formatted string representing the einsum tree.
     */
    std::string _to_string(uint depth, std::string connection, std::string depthString) const;
};
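The size of a node's tensor follows directly from its dimensions. A plausible sketch of what `get_size` computes, assuming it is simply the product of the output dimension sizes (the actual implementation may differ):

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Assumed sketch: the number of elements in a node's tensor is the
// product of its output dimension sizes.
int64_t tensor_size(const std::vector<int64_t> &dim_sizes)
{
    return std::accumulate(dim_sizes.begin(), dim_sizes.end(),
                           int64_t{1}, std::multiplies<int64_t>());
}
```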
Then, we implemented the logic to parse the string into a set of nodes in the parse_tree_no_optimization() method. This method also reports whether parsing was successful via an ErrorParse return value.
ErrorParse parse_tree_no_optimization();
// AND
EinsumNode *parse_node(size_t &pos, const std::string &str);
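To illustrate the recursive-descent style of parse_node, here is a hedged sketch that parses one bracketed dimension-id list such as `[0,12,3]`. The real grammar of our tree strings is richer (nested child nodes, the contraction arrow), so `parse_dim_list` is a hypothetical helper showing only the leaf-level mechanics:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical leaf-level parser: consumes "[id,id,...]" starting at pos
// and advances pos past the closing bracket.
std::vector<int64_t> parse_dim_list(std::size_t &pos, const std::string &str)
{
    std::vector<int64_t> ids;
    ++pos; // consume '['
    while (str[pos] != ']')
    {
        std::size_t consumed = 0;
        ids.push_back(std::stoll(str.substr(pos), &consumed));
        pos += consumed;
        if (str[pos] == ',')
            ++pos; // consume separator
    }
    ++pos; // consume ']'
    return ids;
}
```

Keeping `pos` as a reference parameter, as in the `parse_node` signature above, lets the recursion resume exactly where the previous call stopped.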
2. Function that lowers the contraction and permutation to the tensor operation backend¶
To lower our tree to the tensor operation backend, each EinsumNode is lowered to one TensorConfig. This configuration can then be passed to the TensorOperation. The main method for doing so is lower_node.
TensorConfig lower_node(const EinsumNode *node);
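One ingredient of this lowering is deriving strides for a node's tensor from its dimension ids and the global dimension sizes. A hedged sketch, assuming a row-major layout (`strides_for` and the map-based size table are illustrative names, not our actual API):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <vector>

// Hypothetical helper: row-major strides for a tensor whose dimensions
// are given by dim_ids, with sizes looked up in dim_sizes.
std::vector<int64_t> strides_for(const std::vector<int64_t> &dim_ids,
                                 const std::map<int64_t, int64_t> &dim_sizes)
{
    std::vector<int64_t> strides(dim_ids.size());
    int64_t stride = 1;
    // Walk from the innermost (last) dimension outwards.
    for (std::size_t i = dim_ids.size(); i-- > 0;)
    {
        strides[i] = stride;
        stride *= dim_sizes.at(dim_ids[i]);
    }
    return strides;
}
```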
3. Run your optimization passes on the lowered tensor operations¶
Our EinsumTree has an execute() method, which recursively executes one tensor operation per EinsumNode. The TensorConfig of the current node is used as input for the TensorOperation. Since our TensorOperation receives a TensorConfig as input, it runs all optimization passes on the config before executing the operation. Therefore, no additional step is needed to run optimization passes on the lowered tensor operations.
To ensure the success of all tensor operations, the methods return an ErrorExecute.
ErrorExecute execute(std::vector<void *> tensors);
// AND
ErrorExecute execute_node(EinsumNode *node);
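The recursion in execute_node is a post-order traversal: both children must be executed before the parent so that every intermediate tensor exists when its consumer runs. A minimal sketch of that order (DemoNode and the id-recording are illustrative stand-ins for our real node type):

```cpp
#include <cassert>
#include <vector>

// Minimal stand-in for an einsum node: two children and an id that lets
// us record the execution order.
struct DemoNode
{
    DemoNode *left = nullptr;
    DemoNode *right = nullptr;
    int id = 0;
};

// Post-order traversal: children first, then the node's own operation.
void execute_order(const DemoNode *node, std::vector<int> &order)
{
    if (node == nullptr)
        return;
    execute_order(node->left, order);  // produce the left operand
    execute_order(node->right, order); // produce the right operand
    order.push_back(node->id);         // now this node's op can run
}
```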
4. Benchmark the performance¶
------------------------------------------------------------------------------------------------------------------------------
Benchmark Time CPU Iterations FLOPS
------------------------------------------------------------------------------------------------------------------------------
BM_einsum_tree_first_example/einsum_tree:0/min_warmup_time:0.300_mean 280607767 ns 279227060 ns 10 142.03G/s
BM_einsum_tree_first_example/einsum_tree:0/min_warmup_time:0.300_median 277448741 ns 276113901 ns 10 143.454G/s
BM_einsum_tree_first_example/einsum_tree:0/min_warmup_time:0.300_stddev 10891315 ns 10817141 ns 10 5.02424G/s
BM_einsum_tree_first_example/einsum_tree:0/min_warmup_time:0.300_cv 3.88 % 3.87 % 10 3.54%
BM_einsum_tree_second_example/einsum_tree:1/min_warmup_time:0.300_mean 12415368 ns 12304609 ns 10 249.808G/s
BM_einsum_tree_second_example/einsum_tree:1/min_warmup_time:0.300_median 12389493 ns 12296296 ns 10 249.965G/s
BM_einsum_tree_second_example/einsum_tree:1/min_warmup_time:0.300_stddev 98826 ns 90496 ns 10 1.83123G/s
BM_einsum_tree_second_example/einsum_tree:1/min_warmup_time:0.300_cv 0.80 % 0.74 % 10 0.73%
First Example: \(143.4\) GFLOPS
Second Example: \(249.9\) GFLOPS