Vectorize This Loop

Hi

I've seen a really cool package over at Julia's discourse (see here) and tried to replicate their gemmavx! function in Rust using the ndarray crate:

/// Writes the matrix product `a * b` into `c`, overwriting every element of `c`.
///
/// # Panics
///
/// Panics if `c` does not have as many rows as `a` and as many columns as `b`,
/// or if any of the three matrices is not square. Taken together, the checks
/// mean all three matrices must share the same `n x n` shape.
pub fn gemm(c: &mut Array2<f64>, a: &Array2<f64>, b: &Array2<f64>) {
    // The output shape has to agree with the operands.
    assert!(c.len_of(Axis(0)) == a.len_of(Axis(0)));
    assert!(c.len_of(Axis(1)) == b.len_of(Axis(1)));

    // Only square matrices are accepted.
    assert!(a.len_of(Axis(0)) == a.len_of(Axis(1)));
    assert!(b.len_of(Axis(0)) == b.len_of(Axis(1)));
    assert!(c.len_of(Axis(0)) == c.len_of(Axis(1)));

    let rows = a.len_of(Axis(0));
    let cols = b.len_of(Axis(1));
    let inner = a.len_of(Axis(1));

    for row in 0..rows {
        for col in 0..cols {
            // Dot product of row `row` of `a` with column `col` of `b`,
            // accumulated left to right starting from 0.0.
            c[[row, col]] =
                (0..inner).fold(0.0_f64, |acc, k| acc + a[[row, k]] * b[[k, col]]);
        }
    }
}

I wanted to see whether Rust would also vectorize this loop, but unfortunately it seems to be 10 times slower than the Julia version!

julia> n = 100
100
julia> @btime gemmavx!($(Array{Float64, 2}(undef, n, n)), $(rand(n, n)), $(rand(n, n)))
  72.696 μs (0 allocations: 0 bytes)

Versus:

$ RUSTFLAGS="-C opt-level=3 -C target-cpu=native" cargo bench
[...]
Running target/release/deps/bench_gemm-0e256c20673cf1b5
gemm                    time:   [743.45 us 752.20 us 762.68 us]

I've investigated what code is generated using:

RUSTFLAGS="-C opt-level=3 -C target-cpu=native" cargo rustc --release --lib -- --emit asm

In the generated assembly, some instructions, that I believe are a sign of vectorization, are used:

vmovsd  (%rbp), %xmm1
vmulsd  (%rcx), %xmm1, %xmm1

But most instructions aren't vector instructions:

addq    %rdx, %rcx
addq    %rsi, %rbp
decq    %rbx

I assumed bounds checks may be hindering optimization, which is why I've added extra asserts that should give the compiler extra information; however, I'm not sure they suffice!

In the full assembly (see below) it jumps to .LBB8_19 after a comparison quite often.

.LBB8_19 calls core..ops..index..Index of the ndarray::ArrayBase, which then calls *_ZN7ndarray11arraytraits19array_out_of_bounds17h072d69751b863d8dE@GOTPCREL (possibly panicking after this?).

So it seems like there still are a bunch of bounds checks remaining.
Unfortunately, this is where I don't know how to further analyze the assembly - whose bounds are being checked? Which part of the Rust code corresponds to each of the labels?

How could I make this code run as fast as the Julia version?
What's hindering the loop vectorization from taking place?

Here's the full asm generated by Rust:

Rust ASM

The gemm function starts at line 138. (Search for "_ZN3lib4gemm17h246dcfd4100df471E")

	.text
	.file	"lib.ezkstiu3-cgu.0"
	.section	".text._ZN36_$LT$T$u20$as$u20$core..any..Any$GT$7type_id17h3d618e91576dd15fE","ax",@progbits
	.p2align	4, 0x90
	.type	_ZN36_$LT$T$u20$as$u20$core..any..Any$GT$7type_id17h3d618e91576dd15fE,@function
_ZN36_$LT$T$u20$as$u20$core..any..Any$GT$7type_id17h3d618e91576dd15fE:
	movabsq	$7549865886324542212, %rax
	retq
.Lfunc_end0:
	.size	_ZN36_$LT$T$u20$as$u20$core..any..Any$GT$7type_id17h3d618e91576dd15fE, .Lfunc_end0-_ZN36_$LT$T$u20$as$u20$core..any..Any$GT$7type_id17h3d618e91576dd15fE

	.section	".text._ZN36_$LT$T$u20$as$u20$core..any..Any$GT$7type_id17hf8423819df65ab72E","ax",@progbits
	.p2align	4, 0x90
	.type	_ZN36_$LT$T$u20$as$u20$core..any..Any$GT$7type_id17hf8423819df65ab72E,@function
_ZN36_$LT$T$u20$as$u20$core..any..Any$GT$7type_id17hf8423819df65ab72E:
	movabsq	$1229646359891580772, %rax
	retq
.Lfunc_end1:
	.size	_ZN36_$LT$T$u20$as$u20$core..any..Any$GT$7type_id17hf8423819df65ab72E, .Lfunc_end1-_ZN36_$LT$T$u20$as$u20$core..any..Any$GT$7type_id17hf8423819df65ab72E

	.section	.text._ZN3std9panicking11begin_panic17h5d87dff5f73669a9E,"ax",@progbits
	.p2align	4, 0x90
	.type	_ZN3std9panicking11begin_panic17h5d87dff5f73669a9E,@function
_ZN3std9panicking11begin_panic17h5d87dff5f73669a9E:
	.cfi_startproc
	subq	$24, %rsp
	.cfi_def_cfa_offset 32
	movq	%rdi, 8(%rsp)
	movq	%rsi, %rcx
	leaq	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.0(%rip), %rsi
	leaq	8(%rsp), %rdi
	xorl	%edx, %edx
	movq	$56, 16(%rsp)
	callq	*_ZN3std9panicking20rust_panic_with_hook17hbe174577402a475dE@GOTPCREL(%rip)
	ud2
.Lfunc_end2:
	.size	_ZN3std9panicking11begin_panic17h5d87dff5f73669a9E, .Lfunc_end2-_ZN3std9panicking11begin_panic17h5d87dff5f73669a9E
	.cfi_endproc

	.section	.text._ZN4core3ptr18real_drop_in_place17h7ac74da8cab3cb50E,"ax",@progbits
	.p2align	4, 0x90
	.type	_ZN4core3ptr18real_drop_in_place17h7ac74da8cab3cb50E,@function
_ZN4core3ptr18real_drop_in_place17h7ac74da8cab3cb50E:
	retq
.Lfunc_end3:
	.size	_ZN4core3ptr18real_drop_in_place17h7ac74da8cab3cb50E, .Lfunc_end3-_ZN4core3ptr18real_drop_in_place17h7ac74da8cab3cb50E

	.section	.text._ZN4core3ptr18real_drop_in_place17hcddde482e0ab8a07E,"ax",@progbits
	.p2align	4, 0x90
	.type	_ZN4core3ptr18real_drop_in_place17hcddde482e0ab8a07E,@function
_ZN4core3ptr18real_drop_in_place17hcddde482e0ab8a07E:
	retq
.Lfunc_end4:
	.size	_ZN4core3ptr18real_drop_in_place17hcddde482e0ab8a07E, .Lfunc_end4-_ZN4core3ptr18real_drop_in_place17hcddde482e0ab8a07E

	.section	".text._ZN7ndarray11arraytraits94_$LT$impl$u20$core..ops..index..Index$LT$I$GT$$u20$for$u20$ndarray..ArrayBase$LT$S$C$D$GT$$GT$5index28_$u7b$$u7b$closure$u7d$$u7d$17h9b147b7c648892b0E","ax",@progbits
	.p2align	4, 0x90
	.type	_ZN7ndarray11arraytraits94_$LT$impl$u20$core..ops..index..Index$LT$I$GT$$u20$for$u20$ndarray..ArrayBase$LT$S$C$D$GT$$GT$5index28_$u7b$$u7b$closure$u7d$$u7d$17h9b147b7c648892b0E,@function
_ZN7ndarray11arraytraits94_$LT$impl$u20$core..ops..index..Index$LT$I$GT$$u20$for$u20$ndarray..ArrayBase$LT$S$C$D$GT$$GT$5index28_$u7b$$u7b$closure$u7d$$u7d$17h9b147b7c648892b0E:
	pushq	%rax
	callq	*_ZN7ndarray11arraytraits19array_out_of_bounds17h072d69751b863d8dE@GOTPCREL(%rip)
	ud2
.Lfunc_end5:
	.size	_ZN7ndarray11arraytraits94_$LT$impl$u20$core..ops..index..Index$LT$I$GT$$u20$for$u20$ndarray..ArrayBase$LT$S$C$D$GT$$GT$5index28_$u7b$$u7b$closure$u7d$$u7d$17h9b147b7c648892b0E, .Lfunc_end5-_ZN7ndarray11arraytraits94_$LT$impl$u20$core..ops..index..Index$LT$I$GT$$u20$for$u20$ndarray..ArrayBase$LT$S$C$D$GT$$GT$5index28_$u7b$$u7b$closure$u7d$$u7d$17h9b147b7c648892b0E

	.section	".text._ZN91_$LT$std..panicking..begin_panic..PanicPayload$LT$A$GT$$u20$as$u20$core..panic..BoxMeUp$GT$3get17h70c1aee6dcd93f31E","ax",@progbits
	.p2align	4, 0x90
	.type	_ZN91_$LT$std..panicking..begin_panic..PanicPayload$LT$A$GT$$u20$as$u20$core..panic..BoxMeUp$GT$3get17h70c1aee6dcd93f31E,@function
_ZN91_$LT$std..panicking..begin_panic..PanicPayload$LT$A$GT$$u20$as$u20$core..panic..BoxMeUp$GT$3get17h70c1aee6dcd93f31E:
	cmpq	$0, (%rdi)
	movl	$1, %eax
	leaq	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.2(%rip), %rcx
	leaq	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.1(%rip), %rdx
	cmovneq	%rdi, %rax
	cmoveq	%rcx, %rdx
	retq
.Lfunc_end6:
	.size	_ZN91_$LT$std..panicking..begin_panic..PanicPayload$LT$A$GT$$u20$as$u20$core..panic..BoxMeUp$GT$3get17h70c1aee6dcd93f31E, .Lfunc_end6-_ZN91_$LT$std..panicking..begin_panic..PanicPayload$LT$A$GT$$u20$as$u20$core..panic..BoxMeUp$GT$3get17h70c1aee6dcd93f31E

	.section	".text._ZN91_$LT$std..panicking..begin_panic..PanicPayload$LT$A$GT$$u20$as$u20$core..panic..BoxMeUp$GT$9box_me_up17h7c1567326ae745e4E","ax",@progbits
	.p2align	4, 0x90
	.type	_ZN91_$LT$std..panicking..begin_panic..PanicPayload$LT$A$GT$$u20$as$u20$core..panic..BoxMeUp$GT$9box_me_up17h7c1567326ae745e4E,@function
_ZN91_$LT$std..panicking..begin_panic..PanicPayload$LT$A$GT$$u20$as$u20$core..panic..BoxMeUp$GT$9box_me_up17h7c1567326ae745e4E:
	.cfi_startproc
	pushq	%r14
	.cfi_def_cfa_offset 16
	pushq	%rbx
	.cfi_def_cfa_offset 24
	pushq	%rax
	.cfi_def_cfa_offset 32
	.cfi_offset %rbx, -24
	.cfi_offset %r14, -16
	movq	(%rdi), %rbx
	movq	8(%rdi), %r14
	movq	$0, (%rdi)
	testq	%rbx, %rbx
	je	.LBB7_1
	movl	$16, %edi
	movl	$8, %esi
	callq	*__rust_alloc@GOTPCREL(%rip)
	testq	%rax, %rax
	je	.LBB7_5
	leaq	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.1(%rip), %rdx
	movq	%rbx, (%rax)
	movq	%r14, 8(%rax)
	addq	$8, %rsp
	.cfi_def_cfa_offset 24
	popq	%rbx
	.cfi_def_cfa_offset 16
	popq	%r14
	.cfi_def_cfa_offset 8
	retq
.LBB7_1:
	.cfi_def_cfa_offset 32

	movl	$1, %eax
	leaq	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.2(%rip), %rdx
	addq	$8, %rsp
	.cfi_def_cfa_offset 24
	popq	%rbx
	.cfi_def_cfa_offset 16
	popq	%r14
	.cfi_def_cfa_offset 8
	retq
.LBB7_5:
	.cfi_def_cfa_offset 32
	movl	$16, %edi
	movl	$8, %esi
	callq	*_ZN5alloc5alloc18handle_alloc_error17hf59328f931d332cdE@GOTPCREL(%rip)
	ud2
.Lfunc_end7:
	.size	_ZN91_$LT$std..panicking..begin_panic..PanicPayload$LT$A$GT$$u20$as$u20$core..panic..BoxMeUp$GT$9box_me_up17h7c1567326ae745e4E, .Lfunc_end7-_ZN91_$LT$std..panicking..begin_panic..PanicPayload$LT$A$GT$$u20$as$u20$core..panic..BoxMeUp$GT$9box_me_up17h7c1567326ae745e4E
	.cfi_endproc

	.section	.text._ZN3lib4gemm17h246dcfd4100df471E,"ax",@progbits
	.globl	_ZN3lib4gemm17h246dcfd4100df471E
	.p2align	4, 0x90
	.type	_ZN3lib4gemm17h246dcfd4100df471E,@function
_ZN3lib4gemm17h246dcfd4100df471E:
	.cfi_startproc
	pushq	%rbp
	.cfi_def_cfa_offset 16
	pushq	%r15
	.cfi_def_cfa_offset 24
	pushq	%r14
	.cfi_def_cfa_offset 32
	pushq	%r13
	.cfi_def_cfa_offset 40
	pushq	%r12
	.cfi_def_cfa_offset 48
	pushq	%rbx
	.cfi_def_cfa_offset 56
	pushq	%rax
	.cfi_def_cfa_offset 64
	.cfi_offset %rbx, -56
	.cfi_offset %r12, -48
	.cfi_offset %r13, -40
	.cfi_offset %r14, -32
	.cfi_offset %r15, -24
	.cfi_offset %rbp, -16
	movq	32(%rdi), %r15
	cmpq	32(%rsi), %r15
	jne	.LBB8_20
	movq	40(%rdi), %rax
	cmpq	40(%rdx), %rax
	jne	.LBB8_21
	cmpq	40(%rsi), %r15
	jne	.LBB8_22
	cmpq	%rax, 32(%rdx)
	jne	.LBB8_7
	cmpq	%rax, %r15
	jne	.LBB8_8
	testq	%r15, %r15
	je	.LBB8_23
	movq	48(%rsi), %r8
	movq	24(%rsi), %r14
	movq	56(%rsi), %rsi
	movq	56(%rdx), %r9
	movq	24(%rdx), %rax
	movq	48(%rdx), %rdx
	xorl	%r11d, %r11d
	movq	%rax, (%rsp)
	shlq	$3, %r9
	shlq	$3, %rdx
	shlq	$3, %r8
	shlq	$3, %rsi
	.p2align	4, 0x90
.LBB8_11:
	movq	(%rsp), %r10
	movq	%r11, %r12
	incq	%r11
	xorl	%eax, %eax
	.p2align	4, 0x90
.LBB8_12:
	cmpq	%r12, %r15
	jbe	.LBB8_19
	cmpq	%r15, %rax
	je	.LBB8_19
	leaq	1(%rax), %r13
	vxorpd	%xmm0, %xmm0, %xmm0
	movq	%r15, %rbx
	movq	%r14, %rbp
	movq	%r10, %rcx
	.p2align	4, 0x90
.LBB8_15:
	testq	%rbx, %rbx
	je	.LBB8_19
	vmovsd	(%rbp), %xmm1
	vmulsd	(%rcx), %xmm1, %xmm1
	addq	%rdx, %rcx
	addq	%rsi, %rbp
	decq	%rbx
	vaddsd	%xmm1, %xmm0, %xmm0
	jne	.LBB8_15
	cmpq	%r12, 32(%rdi)
	jbe	.LBB8_19
	cmpq	%rax, 40(%rdi)
	jbe	.LBB8_19
	movq	48(%rdi), %rcx
	movq	24(%rdi), %rbx
	imulq	56(%rdi), %rax
	addq	%r9, %r10
	imulq	%r12, %rcx
	addq	%rcx, %rax
	vmovsd	%xmm0, (%rbx,%rax,8)
	movq	%r13, %rax
	cmpq	%r15, %r13
	jne	.LBB8_12
	addq	%r8, %r14
	cmpq	%r15, %r11
	jne	.LBB8_11
.LBB8_23:
	addq	$8, %rsp
	.cfi_def_cfa_offset 56
	popq	%rbx
	.cfi_def_cfa_offset 48
	popq	%r12
	.cfi_def_cfa_offset 40
	popq	%r13
	.cfi_def_cfa_offset 32
	popq	%r14
	.cfi_def_cfa_offset 24
	popq	%r15
	.cfi_def_cfa_offset 16
	popq	%rbp
	.cfi_def_cfa_offset 8
	retq
.LBB8_19:
	.cfi_def_cfa_offset 64
	callq	_ZN7ndarray11arraytraits94_$LT$impl$u20$core..ops..index..Index$LT$I$GT$$u20$for$u20$ndarray..ArrayBase$LT$S$C$D$GT$$GT$5index28_$u7b$$u7b$closure$u7d$$u7d$17h9b147b7c648892b0E
	ud2
.LBB8_20:
	leaq	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.5(%rip), %rdi
	leaq	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.4(%rip), %rsi
	callq	_ZN3std9panicking11begin_panic17h5d87dff5f73669a9E
	ud2
.LBB8_21:
	leaq	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.7(%rip), %rdi
	leaq	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.6(%rip), %rsi
	callq	_ZN3std9panicking11begin_panic17h5d87dff5f73669a9E
	ud2
.LBB8_22:
	leaq	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.9(%rip), %rdi
	leaq	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.8(%rip), %rsi
	callq	_ZN3std9panicking11begin_panic17h5d87dff5f73669a9E
	ud2
.LBB8_7:
	leaq	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.11(%rip), %rdi
	leaq	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.10(%rip), %rsi
	callq	_ZN3std9panicking11begin_panic17h5d87dff5f73669a9E
	ud2
.LBB8_8:
	leaq	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.13(%rip), %rdi
	leaq	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.12(%rip), %rsi
	callq	_ZN3std9panicking11begin_panic17h5d87dff5f73669a9E
	ud2
.Lfunc_end8:
	.size	_ZN3lib4gemm17h246dcfd4100df471E, .Lfunc_end8-_ZN3lib4gemm17h246dcfd4100df471E
	.cfi_endproc

	.type	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.0,@object
	.section	.data.rel.ro..Lanon.72f66a15dc35360d0e71560ae91f6ce8.0,"aw",@progbits
	.p2align	3
.Lanon.72f66a15dc35360d0e71560ae91f6ce8.0:
	.quad	_ZN4core3ptr18real_drop_in_place17h7ac74da8cab3cb50E
	.quad	16
	.quad	8
	.quad	_ZN91_$LT$std..panicking..begin_panic..PanicPayload$LT$A$GT$$u20$as$u20$core..panic..BoxMeUp$GT$9box_me_up17h7c1567326ae745e4E
	.quad	_ZN91_$LT$std..panicking..begin_panic..PanicPayload$LT$A$GT$$u20$as$u20$core..panic..BoxMeUp$GT$3get17h70c1aee6dcd93f31E
	.size	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.0, 40

	.type	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.1,@object
	.section	.data.rel.ro..Lanon.72f66a15dc35360d0e71560ae91f6ce8.1,"aw",@progbits
	.p2align	3
.Lanon.72f66a15dc35360d0e71560ae91f6ce8.1:
	.quad	_ZN4core3ptr18real_drop_in_place17h7ac74da8cab3cb50E
	.quad	16
	.quad	8
	.quad	_ZN36_$LT$T$u20$as$u20$core..any..Any$GT$7type_id17hf8423819df65ab72E
	.size	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.1, 32

	.type	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.2,@object
	.section	.data.rel.ro..Lanon.72f66a15dc35360d0e71560ae91f6ce8.2,"aw",@progbits
	.p2align	3
.Lanon.72f66a15dc35360d0e71560ae91f6ce8.2:
	.quad	_ZN4core3ptr18real_drop_in_place17hcddde482e0ab8a07E
	.quad	0
	.quad	1
	.quad	_ZN36_$LT$T$u20$as$u20$core..any..Any$GT$7type_id17h3d618e91576dd15fE
	.size	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.2, 32

	.type	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.3,@object
	.section	.rodata..Lanon.72f66a15dc35360d0e71560ae91f6ce8.3,"a",@progbits
.Lanon.72f66a15dc35360d0e71560ae91f6ce8.3:
	.ascii	"lib/lib.rs"
	.size	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.3, 10

	.type	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.4,@object
	.section	.data.rel.ro..Lanon.72f66a15dc35360d0e71560ae91f6ce8.4,"aw",@progbits
	.p2align	3
.Lanon.72f66a15dc35360d0e71560ae91f6ce8.4:
	.quad	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.3
	.asciz	"\n\000\000\000\000\000\000\000\017\000\000\000\005\000\000"
	.size	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.4, 24

	.type	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.5,@object
	.section	.rodata..Lanon.72f66a15dc35360d0e71560ae91f6ce8.5,"a",@progbits
.Lanon.72f66a15dc35360d0e71560ae91f6ce8.5:
	.ascii	"assertion failed: c.len_of(Axis(0)) == a.len_of(Axis(0))"
	.size	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.5, 56

	.type	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.6,@object
	.section	.data.rel.ro..Lanon.72f66a15dc35360d0e71560ae91f6ce8.6,"aw",@progbits
	.p2align	3
.Lanon.72f66a15dc35360d0e71560ae91f6ce8.6:
	.quad	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.3
	.asciz	"\n\000\000\000\000\000\000\000\020\000\000\000\005\000\000"
	.size	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.6, 24

	.type	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.7,@object
	.section	.rodata..Lanon.72f66a15dc35360d0e71560ae91f6ce8.7,"a",@progbits
.Lanon.72f66a15dc35360d0e71560ae91f6ce8.7:
	.ascii	"assertion failed: c.len_of(Axis(1)) == b.len_of(Axis(1))"
	.size	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.7, 56

	.type	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.8,@object
	.section	.data.rel.ro..Lanon.72f66a15dc35360d0e71560ae91f6ce8.8,"aw",@progbits
	.p2align	3
.Lanon.72f66a15dc35360d0e71560ae91f6ce8.8:
	.quad	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.3
	.asciz	"\n\000\000\000\000\000\000\000\022\000\000\000\005\000\000"
	.size	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.8, 24

	.type	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.9,@object
	.section	.rodata..Lanon.72f66a15dc35360d0e71560ae91f6ce8.9,"a",@progbits
.Lanon.72f66a15dc35360d0e71560ae91f6ce8.9:
	.ascii	"assertion failed: a.len_of(Axis(0)) == a.len_of(Axis(1))"
	.size	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.9, 56

	.type	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.10,@object
	.section	.data.rel.ro..Lanon.72f66a15dc35360d0e71560ae91f6ce8.10,"aw",@progbits
	.p2align	3
.Lanon.72f66a15dc35360d0e71560ae91f6ce8.10:
	.quad	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.3
	.asciz	"\n\000\000\000\000\000\000\000\023\000\000\000\005\000\000"
	.size	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.10, 24

	.type	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.11,@object
	.section	.rodata..Lanon.72f66a15dc35360d0e71560ae91f6ce8.11,"a",@progbits
.Lanon.72f66a15dc35360d0e71560ae91f6ce8.11:
	.ascii	"assertion failed: b.len_of(Axis(0)) == b.len_of(Axis(1))"
	.size	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.11, 56

	.type	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.12,@object
	.section	.data.rel.ro..Lanon.72f66a15dc35360d0e71560ae91f6ce8.12,"aw",@progbits
	.p2align	3
.Lanon.72f66a15dc35360d0e71560ae91f6ce8.12:
	.quad	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.3
	.asciz	"\n\000\000\000\000\000\000\000\024\000\000\000\005\000\000"
	.size	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.12, 24

	.type	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.13,@object
	.section	.rodata..Lanon.72f66a15dc35360d0e71560ae91f6ce8.13,"a",@progbits
.Lanon.72f66a15dc35360d0e71560ae91f6ce8.13:
	.ascii	"assertion failed: c.len_of(Axis(0)) == c.len_of(Axis(1))"
	.size	.Lanon.72f66a15dc35360d0e71560ae91f6ce8.13, 56


	.section	".note.GNU-stack","",@progbits

Here's the complete project for reproducing this locally:

Complete Project
`benches/bench_gemm.rs`
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use ndarray::*;
use rand::prelude::*;
use lib::gemm;

/// Builds an `n x n` matrix whose entries are drawn from the thread-local RNG.
///
/// Elements are filled in logical (row-major) order, one `rng.gen()` call per
/// element.
fn rand_array(n: usize) -> Array2<f64> {
    let mut rng = rand::thread_rng();
    let mut matrix = Array2::<f64>::zeros((n, n));
    // `iter_mut` walks the array with the last index varying fastest, which is
    // the same visiting order as nested row/column index loops.
    for slot in matrix.iter_mut() {
        *slot = rng.gen();
    }
    matrix
}

/// Registers the "gemm" benchmark: multiplies two random 100 x 100 matrices
/// into a preallocated output matrix on every iteration.
fn criterion_benchmark(c: &mut Criterion) {
    const SIZE: usize = 100;

    let lhs = rand_array(SIZE);
    let rhs = rand_array(SIZE);
    // The output buffer is allocated once and reused across iterations.
    let mut product = Array2::<f64>::zeros((SIZE, SIZE));

    c.bench_function("gemm", |bencher| {
        // `black_box` keeps the optimizer from specializing on the inputs.
        bencher.iter(|| gemm(black_box(&mut product), black_box(&lhs), black_box(&rhs)))
    });
}


criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
`lib/lib.rs`
use ndarray::*;

// function gemmavx!(C, A, B)
//    @avx for i ∈ 1:size(A,1), j ∈ 1:size(B,2)
//        Cᵢⱼ = 0.0
//        for k ∈ 1:size(A,2)
//            Cᵢⱼ += A[i,k] * B[k,j]
//        end
//        C[i,j] = Cᵢⱼ
//    end
// end


/// Writes the matrix product `a * b` into `c`, overwriting every element of `c`.
///
/// # Panics
///
/// Panics if `c` does not have as many rows as `a` and as many columns as `b`,
/// or if any of the three matrices is not square. Taken together, the checks
/// mean all three matrices must share the same `n x n` shape.
pub fn gemm(c: &mut Array2<f64>, a: &Array2<f64>, b: &Array2<f64>) {
    // The output shape has to agree with the operands.
    assert!(c.len_of(Axis(0)) == a.len_of(Axis(0)));
    assert!(c.len_of(Axis(1)) == b.len_of(Axis(1)));

    // Only square matrices are accepted.
    assert!(a.len_of(Axis(0)) == a.len_of(Axis(1)));
    assert!(b.len_of(Axis(0)) == b.len_of(Axis(1)));
    assert!(c.len_of(Axis(0)) == c.len_of(Axis(1)));

    let rows = a.len_of(Axis(0));
    let cols = b.len_of(Axis(1));
    let inner = a.len_of(Axis(1));

    for row in 0..rows {
        for col in 0..cols {
            // Dot product of row `row` of `a` with column `col` of `b`,
            // accumulated left to right starting from 0.0.
            c[[row, col]] =
                (0..inner).fold(0.0_f64, |acc, k| acc + a[[row, k]] * b[[k, col]]);
        }
    }
}
`Cargo.toml`
# Manifest for the reproduction project: one library crate (lib/lib.rs) plus a
# Criterion benchmark (benches/bench_gemm.rs).
[package]
name = "gemmavx"
version = "0.1.0"
authors = ["ambiso <ambiso@invalid>"]
edition = "2018"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

# NOTE(review): "*" accepts any published version, so the build is not pinned
# to the crate versions the benchmark numbers were measured with.
[dependencies]
ndarray = "*"

[dev-dependencies]
criterion = "*"
rand = "*"

# The benchmark file supplies its own entry point via criterion_main!, so the
# default test harness is turned off.
[[bench]]
name = "bench_gemm"
harness = false

# Non-default layout: the library is named "lib" (imported as `lib::gemm` in
# the benchmark) and lives at lib/lib.rs.
[lib]
name = "lib"
path = "lib/lib.rs"

# Release settings: full optimization, a single codegen unit, abort on panic
# instead of unwinding, and link-time optimization.
[profile.release]
opt-level = 3
codegen-units = 1
panic = "abort"
lto = true
`src/main.rs`
/// Placeholder binary entry point; prints a fixed greeting to stdout.
fn main() {
    let greeting = "Hello, world!";
    println!("{}", greeting);
}

Thanks in advance for any hints!

Best,
ambiso

I've packaged the project up so you can experiment locally with it instantly:

1 Like

Out of curiosity, why not use the ndarray::ArrayBase::dot method in the Rust case? An ndarray array is not even a native Rust data structure, so it will be really hard for the compiler to reason about it.

As far as I understand, Julia's loop is not exactly a native language loop either, since the macro applies a transformation; I mean it's not a native Julia compiler optimization. So it seems it would be equivalent to use optimized library methods rather than rely on compiler optimization in this case.

ndarray should be included with features=["blas"]. Actually, all of the above is as advised in the crate's performance section: https://docs.rs/ndarray/0.13.0/ndarray/

Dot seems to be allocating a new output vector/matrix depending on the input?
Could you elaborate how you'd use dot?

Yes, I was hoping Rust + LLVM could obtain the same optimization.

This will only change something if I actually use a function that can be implemented in terms of BLAS, but I don't think I do?
And the point would be that you don't have to write a function in terms of BLAS functions, but can write your own.

Note that the "sd"-suffixed SSE/AVX instructions only operate on a single scalar f64 value at a time (that's what the penultimate "s" stands for: Scalar; the SIMD versions are suffixed with "pd" for Packed Double-precision), so this part of the code isn't vectorized either.

Bounds checks are easier to elide for a JIT-compiled language like Julia than for an AoT-compiled language like Rust, because the JIT compiler works with knowledge of the size of the array under study whereas the AoT compiler doesn't. Sometimes AoT compilers manage to catch up with JITs by hoisting bounds checks out of the loop (or leveraging length assertions of the kind you've written above), sometimes they don't, and it's a bit hard to predict in advance.

When the compiler is too dumb to move bound checks out of the loop, two general strategies to resolve the problem are to 1/try to rewrite the program in a shape which the compiler can handle well (the classic slice-whole-vector trick) and to 2/give up on guaranteed memory safety and use unsafe accesses which don't have bounds checks by definition (this is what julia's @inbounds does IIRC).

One then usually tries to encapsulate either of these ugly tricks in a general safe abstraction so that they are usable in many contexts, e.g. this is what Rust iterators do for linear iteration patterns.

1 Like

Aha!

Julia is JIT AoT compiled - like @dunnock mentioned, this is not a transformation done by LLVM but the @avx macro.
I was hoping that maybe I could give Rust enough assurances about the array bounds that it would vectorize the loop.

I'll try to rewrite the code so as to remove any bounds checks, and see if that helps!

2 Likes