Machine Learning Compilers¶
GETTING STARTED
SUBMISSIONS
- Submission 2025-04-17
- Hello Assembly
- Assembly Function
- 1. Assemble the file and use the name ``add_values.o`` for the output.
- 2. Generate the following from ``add_values.o``:
- 3. Find the size of the ``.text`` section in the generated output and explain it.
- 4. Write a C++ driver that calls the ``add_values`` function and illustrate it with an example.
- 5. Use the GNU Project Debugger (GDB) to step through an example call to the ``add_values`` function. Display the contents of the general-purpose registers after each of the executed instructions.
- 1. Assemble the file and use the name
- Submission 2025-04-24
- Submission 2025-05-01
- Execution Throughput and Latency
- Microkernel
- Loops
- 1. Loop over K: Implement a kernel that computes C+=AB for M=16, N=6 and K=64.
- 2. Loop over M: Implement a kernel that computes C+=AB for M=64, N=6 and K=64.
- 3. Loop over N: Implement a kernel that computes C+=AB for M=64, N=48 and K=64.
- 1. Test and optimize the kernels. Report your performance in GFLOPS.
- Submission 2025-05-08
- Submission 2025-05-15
- Neon Batch-Reduce GEMM
- GEMM
- 1. Extend ``generate`` to support all M-N-K combinations for column-major format with \(1 \leq M, N \leq 1024\) and \(1 \leq K \leq 2048\)
- 2. Verify all matrices for \(1 \leq M \leq 64\), \(1 \leq N \leq 64\), \(K \in \{1, 16, 32, 64, 128\}\), ``lda=M``, ``ldb=K``, and ``ldc=M``
- 3. Verify all matrices for \(1 \leq M \leq 64\), \(1 \leq N \leq 64\), \(K \in \{1, 16, 32, 64, 128\}\), ``lda>M``, ``ldb>K``, and ``ldc>M``
- 4. Benchmark for \(1 \leq M \leq 64\), \(1 \leq N \leq 64\), \(K \in \{1, 16, 32, 64, 128\}\), ``lda=M``, ``ldb=K``, and ``ldc=M``.
- Batch-Reduce GEMM
- Submission 2025-05-22
- Transposition
- Unary Primitives
- Zero Primitive
- 1. mini_jit::Unary::generate function to support the zero primitive
- 2. Test and optimize
- Identity Primitive
- 1. mini_jit::Unary::generate function to support the identity primitive
- 2. Test and optimize
- ReLu Primitive
- 1. mini_jit::Unary::generate function to support the ReLu primitive
- 2. Test and optimize
- Submission 2025-05-29
- Submission 2025-06-05
- Submission 2025-06-12
- Submission 2025-06-19
- Individual Phase
API
- mini_jit
mini_jit::Brgemm
mini_jit::EinsumTree
ErrorParse
ErrorExecute
None
InvalidRoot
NotEnoughInputTensors
TooManyInputTensors
NullPtrAsInputTensor
err_wrong_dtype
err_wrong_dimension
err_wrong_primitive
err_wrong_first_touch_primitive
err_wrong_main_primitive
err_wrong_last_touch_primitive
err_execution_type_not_supported
err_invalid_primitive_configuration
err_invalid_first_touch_configuration
err_invalid_main_configuration
err_invalid_last_touch_configuration
err_invalid_execution_order
err_invalid_strides
err_k_dimension_must_not_be_shared
err_shared_required_for_parallel_execution
NodeType
EinsumTree()
EinsumTree()
~EinsumTree()
set_sorted_dim_sizes()
parse_tree_no_optimization()
parse_tree()
get_root()
optimize()
conditional_swap()
reorder_left_node()
reorder_right_node()
execute()
parse_node()
lower_node()
get_config_dim_types_and_sizes()
get_config_strides()
execute_node()
assign_tensor_indices()
is_unit_stride_n()
parse_dim_list()
compute_strides()
get_output_dims()
parse_setup_error()
delete_tree()
findKDim()
findNDim()
findMDim()
tensorIndex
root
tree_str
error_parse
dim_sizes
mini_jit::EinsumTree::EinsumNode
mini_jit::Kernel
mini_jit::TensorConfig
mini_jit::TensorOperation
error_t
success
err_wrong_dtype
err_wrong_dimension
err_wrong_primitive
err_wrong_first_touch_primitive
err_wrong_main_primitive
err_wrong_last_touch_primitive
err_execution_type_not_supported
err_invalid_primitive_configuration
err_invalid_first_touch_configuration
err_invalid_main_configuration
err_invalid_last_touch_configuration
err_invalid_execution_order
err_invalid_strides
err_k_dimension_must_not_be_shared
err_shared_required_for_parallel_execution
stride_t
setup()
setup_no_optimization()
execute()
execute_dimension()
get_config()
write_kernel_to_file()
isExpectedStride()
isValidStride()
isUnary()
isBrgemm()
findMatch()
isValidPrimConfig()
isValidPrimStrides()
isValidKDim()
isSortedConfiguration()
generateUnary()
config
dtype
prim_first
prim_main
prim_last
dim_types
exec_types
dim_sizes
strides_in0
strides_in1
strides_out
indexPrimM
indexPrimN
indexPrimK
indexPrimBatch
first_touch
main_kernel
last_touch
isParallel
isTranspose
hasSetupError
mini_jit::TensorOptimization
optimize()
optimize_primitive_identification()
optimize_shared_identification()
optimize_dimension_reordering_shared()
optimize_dimension_reordering_fusing()
optimize_dimension_splitting()
optimize_dimension_fusing()
_reorder_helper_adjust_index()
_primitive_identification()
_shared_identification()
_dimension_reordering_shared()
_dimension_reordering_fusing()
_swap_elements()
_move_elements()
_dimension_splitting()
_dimension_fusing()
thread_count
maximum_inbalanced_parallel_precentage
fuse_split_dimension_size
mini_jit::Unary
- arm_instructions
R32Bit
R64Bit
ShiftLSL
ShiftLSR
ShiftASR
ShiftROR
V8Bit
V16Bit
V32Bit
V64Bit
V128Bit
VGeneral
VType8x8Bit
VType16x8Bit
VType4x16Bit
VType8x16Bit
VType2x32Bit
VType4x32Bit
VType1x64Bit
VType2x64Bit
mask1
mask2
mask3
mask4
mask5
mask6
mask7
mask8
mask9
mask10
mask11
mask12
mask13
mask14
mask15
mask16
mask17
mask18
mask19
w0
w1
w2
w3
w4
w5
w6
w7
w8
w9
w10
w11
w12
w13
w14
w15
w16
w17
w19
w20
w21
w22
w23
w24
w25
w26
w27
w28
w29
w30
wsp
wzr
x0
x1
x2
x3
x4
x5
x6
x7
x8
x9
x10
x11
x12
x13
x14
x15
x16
x17
x19
x20
x21
x22
x23
x24
x25
x26
x27
x28
x29
x30
fp
lr
sp
xzr
LSL
LSR
ASR
ROR
b0
b1
b2
b3
b4
b5
b6
b7
b8
b9
b10
b11
b12
b13
b14
b15
b16
b17
b18
b19
b20
b21
b22
b23
b24
b25
b26
b27
b28
b29
b30
b31
h0
h1
h2
h3
h4
h5
h6
h7
h8
h9
h10
h11
h12
h13
h14
h15
h16
h17
h18
h19
h20
h21
h22
h23
h24
h25
h26
h27
h28
h29
h30
h31
s0
s1
s2
s3
s4
s5
s6
s7
s8
s9
s10
s11
s12
s13
s14
s15
s16
s17
s18
s19
s20
s21
s22
s23
s24
s25
s26
s27
s28
s29
s30
s31
d0
d1
d2
d3
d4
d5
d6
d7
d8
d9
d10
d11
d12
d13
d14
d15
d16
d17
d18
d19
d20
d21
d22
d23
d24
d25
d26
d27
d28
d29
d30
d31
q0
q1
q2
q3
q4
q5
q6
q7
q8
q9
q10
q11
q12
q13
q14
q15
q16
q17
q18
q19
q20
q21
q22
q23
q24
q25
q26
q27
q28
q29
q30
q31
v0
v1
v2
v3
v4
v5
v6
v7
v8
v9
v10
v11
v12
v13
v14
v15
v16
v17
v18
v19
v20
v21
v22
v23
v24
v25
v26
v27
v28
v29
v30
v31
t8b
t16b
t4h
t8h
t2s
t4s
t1d
t2d
add()
add()
add()
add()
add()
add()
add()
add()
cbnz()
cbnz()
ldpPost()
ldpPost()
ldpPre()
ldpPre()
ldp()
ldp()
ldpOffset()
ldpOffset()
ldrPost()
ldrPost()
ldrPre()
ldrPre()
ldr()
ldr()
ldrOffset()
ldrOffset()
lsl()
lsl()
madd()
madd()
mov()
mov()
mov()
mov()
mov()
mov()
movSp()
movSp()
movn()
movn()
movn()
movn()
movz()
movz()
movz()
movz()
orr()
orr()
orr()
orr()
ret()
ret()
stpPost()
stpPost()
stpPre()
stpPre()
stp()
stp()
stpOffset()
stpOffset()
sub()
sub()
sub()
sub()
eor()
eor()
fmax()
fmax()
fmax()
fmax()
fmax()
fmax()
fmla()
fmla()
fmla()
fmla()
fmla< VType4x16Bit >()
fmla< VType8x16Bit >()
ld1()
ld1()
ld1()
ld1()
ld1Post()
ld1Post()
ld1Post()
ld1Post()
ld1Post()
ld1Post()
ld1Post()
ld1Post()
ld1()
ld1()
ld1()
ld1()
ld1Post()
ld1Post()
ld1Post()
ld1Post()
ld1Post()
ld1Post()
ld1Post()
ld1Post()
ldpPost()
ldpPost()
ldpPost()
ldpPre()
ldpPre()
ldpPre()
ldp()
ldp()
ldp()
ldpOffset()
ldpOffset()
ldpOffset()
ldrPost()
ldrPost()
ldrPost()
ldrPost()
ldrPost()
ldrPre()
ldrPre()
ldrPre()
ldrPre()
ldrPre()
ldr()
ldr()
ldr()
ldr()
ldr()
ldrOffset()
ldrOffset()
ldrOffset()
ldrOffset()
ldrOffset()
st1()
st1()
st1()
st1()
st1Post()
st1Post()
st1Post()
st1Post()
st1Post()
st1Post()
st1Post()
st1Post()
st1()
st1()
st1()
st1()
st1Post()
st1Post()
st1Post()
st1Post()
st1Post()
st1Post()
st1Post()
st1Post()
stpPost()
stpPost()
stpPost()
stpPre()
stpPre()
stpPre()
stp()
stp()
stp()
stpOffset()
stpOffset()
stpOffset()
strPost()
strPost()
strPost()
strPost()
strPost()
strPre()
strPre()
strPre()
strPre()
strPre()
str()
str()
str()
str()
str()
strOffset()
strOffset()
strOffset()
strOffset()
strOffset()
trn1()
trn1()
trn1()
trn1()
trn1()
trn1()
trn1()
trn2()
trn2()
trn2()
trn2()
trn2()
trn2()
trn2()
zip1()
zip1()
zip1()
zip1()
zip1()
zip1()
zip1()
zip2()
zip2()
zip2()
zip2()
zip2()
zip2()
zip2()
- internal
addShiftType
orrShiftType
subShiftType
eorSimdTypes
fmaxSzType
fmaxQType
fmaxFType
fmlaHalfPrecisionTypes
fmlaSingleDoublePrecisionTypes
ld1DataTypes
ld1Types
ldpSimdFpDataTypes
ldrSimdFpDataTypes
st1DataTypes
st1Types
stpSimdFpDataTypes
strSimdFpDataTypes
trn1SizeType
trn1QType
trn2SizeType
trn2QType
zip1SizeType
zip1QType
zip2SizeType
zip2QType
ld1ImmediateRm
st1ImmediateRm
_addParseShiftType()
_addParseShiftType< ShiftLSL >()
_addParseShiftType< ShiftLSR >()
_addParseShiftType< ShiftASR >()
addShiftedRegister()
addImmediate()
cbnz()
_ldpPostPreOffset()
ldpPost()
ldpPre()
ldpOffset()
ldrImmediatePost()
ldrImmediatePre()
ldrImmediateOffset()
lslImmediate()
madd()
movn()
movz()
_orrParseShiftType()
_orrParseShiftType< ShiftLSL >()
_orrParseShiftType< ShiftLSR >()
_orrParseShiftType< ShiftASR >()
_orrParseShiftType< ShiftROR >()
orrShiftedRegister()
ret()
_stpPostPreOffset()
stpPost()
stpPre()
stpOffset()
subImmediate()
eorVector()
fmaxVector()
fmaxScalar()
_fmlaIsDouble()
_fmlaIsDouble< VType2x32Bit >()
_fmlaIsDouble< VType4x32Bit >()
_fmlaIsDouble< VType2x64Bit >()
_fmlaParseSingleDoubleType()
_fmlaParseSingleDoubleType< VType2x32Bit >()
_fmlaParseSingleDoubleType< VType4x32Bit >()
_fmlaParseSingleDoubleType< VType2x64Bit >()
fmlaByElementScalarHalfPrecision()
fmlaByElementScalarSingleDoublePrecision()
fmlaByElementVectorHalfPrecision()
fmlaByElementVectorSingleDoublePrecision()
_ld1ParseType()
_ld1ParseType< VType8x8Bit >()
_ld1ParseType< VType16x8Bit >()
_ld1ParseType< VType4x16Bit >()
_ld1ParseType< VType8x16Bit >()
_ld1ParseType< VType2x32Bit >()
_ld1ParseType< VType4x32Bit >()
_ld1ParseType< VType1x64Bit >()
_ld1ParseType< VType2x64Bit >()
_ld1GetQAndSize()
_ld1GetOpCode()
ld1MultipleStructures()
ld1MultipleStructuresPost()
ld1SingleStructures()
ld1SingleStructuresPost()
_ldpSimdFpPostPreOffset()
ldpPost()
ldpPre()
ldpOffset()
_ldrSimdFpGetOpCode()
ldrSimdFpImmediatePost()
ldrSimdFpImmediatePre()
ldrSimdFpImmediateOffset()
_st1ParseType()
_st1ParseType< VType8x8Bit >()
_st1ParseType< VType16x8Bit >()
_st1ParseType< VType4x16Bit >()
_st1ParseType< VType8x16Bit >()
_st1ParseType< VType2x32Bit >()
_st1ParseType< VType4x32Bit >()
_st1ParseType< VType1x64Bit >()
_st1ParseType< VType2x64Bit >()
_st1GetQAndSize()
_st1GetOpCode()
st1MultipleStructures()
st1MultipleStructuresPost()
st1SingleStructures()
st1SingleStructuresPost()
_stpSimdFpPostPreOffset()
stpPost()
stpPre()
stpOffset()
_strSimdFpGetOpCode()
strSimdFpImmediatePost()
strSimdFpImmediatePre()
strSimdFpImmediateOffset()
_trn1()
_trn2()
_zip1()
_zip2()
- kernels
br_matmul_16m_4n_k()
br_matmul_16m_lt4nRest_k()
br_matmul_16mRest_4n_k()
br_matmul_16mRest_lt4nRest_k()
br_matmul_lt16_4n_k()
br_matmul_lt16_lt4nRest_k()
matmul_16_6_1()
matmul_16_6_k()
matmul_16m_4n_k()
matmul_16m_lt4nRest_k()
matmul_16mRest_4n_k()
matmul_16mRest_lt4nRest_k()
matmul_lt16_4n_k()
matmul_lt16_lt4nRest_k()
unary_identity()
unary_identity_transpose()
unary_relu()
unary_relu_transpose()
unary_zero()
unary_zero_16m_n()