Suboptimal codegen using carrying_add for arbitrary precision

I saw that u8::carrying_add has been added, so I looked at the codegen when manually implementing a u128 add. The result seems very bad. Am I using it wrong?

A godbolt link comparing some different attempts is here.

None of them manage to replicate the assembly of the built-in u128 addition, but the naive version comes a lot closer than the u8::carrying_add version.

copy of code from godbolt link
#[unsafe(no_mangle)]
fn a(a: u8, b: u8, carry: bool) -> (u8,bool) {
    u8::carrying_add(a,b,carry)
}
#[unsafe(no_mangle)]
fn b(a: u8, b: u8, carry: bool) -> (u8,bool) {
    adc(a,b,carry)
}

#[unsafe(no_mangle)]
pub fn myadd(num1: u128, num2: u128) -> u128 {
    num1 + num2
}

// add with carry
fn adc(num1: u8, num2: u8, carry: bool) -> (u8, bool) {
    let carry = num1 as u16 + num2 as u16 + carry as u16 > u8::MAX as u16;
    let sum = num1.wrapping_add(num2);
    (sum, carry)
}
#[unsafe(no_mangle)]
pub fn myadd1(num: [u8;16], num2: [u8;16]) -> [u8;16] {
    let mut carry = false;
    let mut sum = [0u8;16];
    for i in 0..num.len() {
        let (sum_v, carry_v) = adc(num[i], num2[i], carry);
        sum[i] = sum_v;
        carry = carry_v;
    }
    sum
}
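// Force the alignment of T up to that of A (here, give the [u8; 16] the alignment of u128).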
#[repr(C)]
struct AlignAs<A,T>(T,[A;0]);
#[unsafe(no_mangle)]
pub fn myadd2(num: AlignAs<u128, [u8;16]>, num2: AlignAs<u128, [u8;16]>) -> AlignAs<u128, [u8;16]> {
    let mut carry = false;
    let mut sum = AlignAs([0u8;16],[]);
    for i in 0..num.0.len() {
        let (sum_v, carry_v) = adc(num.0[i], num2.0[i], carry);
        sum.0[i] = sum_v;
        carry = carry_v;
    }
    sum
}

#[unsafe(no_mangle)]
pub fn myadd3(num: [u8;16], num2: [u8;16]) -> [u8;16] {
    let mut carry = false;
    let mut sum = [0u8;16];
    for i in 0..num.len() {
        let (sum_v, carry_v) = u8::carrying_add(num[i], num2[i], carry);
        sum[i] = sum_v;
        carry = carry_v;
    }
    sum
}

copy of asm
a:
        mov     eax, esi
        add     al, dil
        setb    cl
        add     al, dl
        setb    dl
        or      dl, cl
        ret

b:
        movzx   ecx, dil
        movzx   eax, sil
        add     edx, eax
        add     edx, ecx
        cmp     edx, 256
        setae   dl
        add     al, cl
        ret

myadd:
        mov     rax, rdi
        add     rax, rdx
        adc     rsi, rcx
        mov     rdx, rsi
        ret

myadd1:
        mov     rax, rdi
        movdqu  xmm0, xmmword ptr [rsi]
        movdqu  xmm1, xmmword ptr [rdx]
        paddb   xmm1, xmm0
        movdqu  xmmword ptr [rdi], xmm1
        ret

myadd2:
        mov     rax, rdi
        movdqa  xmm0, xmmword ptr [rdx]
        paddb   xmm0, xmmword ptr [rsi]
        movdqa  xmmword ptr [rdi], xmm0
        ret

myadd3:
        push    rbp
        push    r14
        push    rbx
        mov     rax, rdi
        movq    xmm0, qword ptr [rsi]
        movq    xmm1, qword ptr [rdx]
        paddb   xmm1, xmm0
        movdqa  xmmword ptr [rsp - 16], xmm1
        movzx   ecx, byte ptr [rsp - 16]
        mov     byte ptr [rdi], cl
        pmaxub  xmm0, xmm1
        pcmpeqb xmm0, xmm1
        pmovmskb        r10d, xmm0
        not     r10d
        mov     ecx, r10d
        shr     cl, 7
        mov     edi, r10d
        and     dil, 64
        shr     dil, 6
        mov     r8d, r10d
        and     r8b, 32
        shr     r8b, 5
        mov     r9d, r10d
        and     r9b, 16
        shr     r9b, 4
        mov     r11d, r10d
        and     r11b, 8
        shr     r11b, 3
        mov     ebx, r10d
        and     bl, 4
        shr     bl, 2
        mov     ebp, r10d
        and     bpl, 2
        shr     bpl
        and     r10b, 1
        add     r10b, byte ptr [rsp - 15]
        setb    r14b
        or      r14b, bpl
        mov     byte ptr [rax + 1], r10b
        add     r14b, byte ptr [rsp - 14]
        setb    r10b
        or      r10b, bl
        mov     byte ptr [rax + 2], r14b
        add     r10b, byte ptr [rsp - 13]
        setb    bl
        or      bl, r11b
        mov     byte ptr [rax + 3], r10b
        add     bl, byte ptr [rsp - 12]
        setb    r10b
        or      r10b, r9b
        mov     byte ptr [rax + 4], bl
        add     r10b, byte ptr [rsp - 11]
        setb    r9b
        or      r9b, r8b
        mov     byte ptr [rax + 5], r10b
        add     r9b, byte ptr [rsp - 10]
        setb    r8b
        or      r8b, dil
        mov     byte ptr [rax + 6], r9b
        add     r8b, byte ptr [rsp - 9]
        setb    dil
        or      dil, cl
        mov     byte ptr [rax + 7], r8b
        movzx   ecx, byte ptr [rdx + 8]
        add     dil, -1
        adc     cl, byte ptr [rsi + 8]
        mov     byte ptr [rax + 8], cl
        movzx   ecx, byte ptr [rdx + 9]
        adc     cl, byte ptr [rsi + 9]
        movzx   edi, byte ptr [rdx + 10]
        adc     dil, byte ptr [rsi + 10]
        movzx   r8d, byte ptr [rdx + 11]
        adc     r8b, byte ptr [rsi + 11]
        mov     byte ptr [rax + 9], cl
        mov     byte ptr [rax + 10], dil
        mov     byte ptr [rax + 11], r8b
        movzx   ecx, byte ptr [rdx + 12]
        adc     cl, byte ptr [rsi + 12]
        mov     byte ptr [rax + 12], cl
        movzx   ecx, byte ptr [rdx + 13]
        adc     cl, byte ptr [rsi + 13]
        movzx   edi, byte ptr [rdx + 14]
        adc     dil, byte ptr [rsi + 14]
        movzx   edx, byte ptr [rdx + 15]
        adc     dl, byte ptr [rsi + 15]
        mov     byte ptr [rax + 13], cl
        mov     byte ptr [rax + 14], dil
        mov     byte ptr [rax + 15], dl
        pop     rbx
        pop     r14
        pop     rbp
        ret

Your adc function isn't correct. It needs to add the incoming carry into the sum. The myadd2 assembly is a hint: paddb doesn't propagate carries between byte lanes.
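
To make the failure mode concrete, here is a quick check against the adc exactly as posted (the 0x01FF + 0x0001 values are just an arbitrary case where the carry matters):

// The adc exactly as posted above, for reference.
fn adc(num1: u8, num2: u8, carry: bool) -> (u8, bool) {
    let carry = num1 as u16 + num2 as u16 + carry as u16 > u8::MAX as u16;
    let sum = num1.wrapping_add(num2);
    (sum, carry)
}

fn main() {
    // 0x01FF + 0x0001 should be 0x0200: low byte 0x00, high byte 0x02.
    let (lo, carry) = adc(0xFF, 0x01, false); // (0x00, true) -- so far so good
    let (hi, _) = adc(0x01, 0x00, carry);     // (0x01, false): the high byte should be 0x02
    assert_eq!((lo, hi), (0x00, 0x01));       // the incoming carry never reached the sum
}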

Why are you adding byte-by-byte and not (quad)word-by-word?

You're right. I forgot to do that. The naive version is now also terrible.

new adc and assembly
fn adc(num1: u8, num2: u8, carry: bool) -> (u8, bool) {
    let carry = num1 as u16 + num2 as u16 + carry as u16 > u8::MAX as u16;
    let sum = num1.wrapping_add(num2).wrapping_add(carry as u8);
    (sum, carry)
}
a:
        mov     eax, esi
        add     al, dil
        setb    cl
        add     al, dl
        setb    dl
        or      dl, cl
        ret

b:
        movzx   ecx, dil
        movzx   eax, sil
        add     edx, eax
        add     edx, ecx
        cmp     edx, 256
        setae   dl
        add     al, cl
        add     al, dl
        ret

myadd:
        mov     rax, rdi
        add     rax, rdx
        adc     rsi, rcx
        mov     rdx, rsi
        ret

myadd1:
        push    rbp
        push    r14
        push    rbx
        movq    xmm0, qword ptr [rsi]
        movdqa  xmmword ptr [rsp - 48], xmm0
        movzx   eax, byte ptr [rsp - 48]
        movq    xmm1, qword ptr [rdx]
        movdqa  xmmword ptr [rsp - 64], xmm1
        movzx   ecx, byte ptr [rsp - 64]
        add     ecx, eax
        xor     eax, eax
        cmp     ecx, 256
        setae   al
        movzx   ecx, byte ptr [rsp - 47]
        movzx   r8d, byte ptr [rsp - 63]
        add     r8d, ecx
        add     r8d, eax
        xor     ecx, ecx
        cmp     r8d, 256
        setae   cl
        movzx   r8d, byte ptr [rsp - 46]
        movzx   r9d, byte ptr [rsp - 62]
        add     r9d, r8d
        add     r9d, ecx
        xor     r8d, r8d
        cmp     r9d, 256
        setae   r8b
        movzx   r9d, byte ptr [rsp - 45]
        movzx   r10d, byte ptr [rsp - 61]
        add     r10d, r9d
        add     r10d, r8d
        xor     r9d, r9d
        cmp     r10d, 256
        setae   r9b
        movzx   r10d, byte ptr [rsp - 44]
        movzx   r11d, byte ptr [rsp - 60]
        add     r11d, r10d
        add     r11d, r9d
        xor     ebx, ebx
        cmp     r11d, 256
        setae   bl
        movzx   r10d, byte ptr [rsp - 43]
        movzx   ebp, byte ptr [rsp - 59]
        add     ebp, r10d
        add     ebp, ebx
        xor     r11d, r11d
        cmp     ebp, 256
        setae   r11b
        movzx   r10d, byte ptr [rsp - 42]
        movzx   r14d, byte ptr [rsp - 58]
        add     r14d, r10d
        add     r14d, r11d
        xor     ebp, ebp
        cmp     r14d, 256
        setae   bpl
        movzx   r10d, byte ptr [rsp - 41]
        movzx   r14d, byte ptr [rsp - 57]
        add     r14d, r10d
        add     r14d, ebp
        xor     r10d, r10d
        cmp     r14d, 256
        setae   r10b
        shl     r11d, 8
        or      r11d, ebx
        shl     ecx, 8
        or      ecx, eax
        shl     r8d, 16
        or      r8d, ecx
        shl     r9d, 24
        or      r9d, r8d
        movd    xmm2, r9d
        pinsrw  xmm2, r11d, 2
        mov     rax, rdi
        mov     ecx, r10d
        shl     ecx, 8
        or      ecx, ebp
        pinsrw  xmm2, ecx, 3
        paddb   xmm1, xmm0
        paddb   xmm1, xmm2
        movq    qword ptr [rdi], xmm1
        movq    xmm0, qword ptr [rsi + 8]
        movdqa  xmmword ptr [rsp - 16], xmm0
        movzx   ecx, byte ptr [rsp - 16]
        movq    xmm1, qword ptr [rdx + 8]
        movdqa  xmmword ptr [rsp - 32], xmm1
        movzx   edx, byte ptr [rsp - 32]
        add     edx, ecx
        add     edx, r10d
        xor     ecx, ecx
        cmp     edx, 256
        setae   cl
        movzx   edx, byte ptr [rsp - 15]
        movzx   esi, byte ptr [rsp - 31]
        add     esi, edx
        add     esi, ecx
        xor     edx, edx
        cmp     esi, 256
        setae   dl
        movzx   esi, byte ptr [rsp - 14]
        movzx   edi, byte ptr [rsp - 30]
        add     edi, esi
        add     edi, edx
        xor     esi, esi
        cmp     edi, 256
        setae   sil
        movzx   edi, byte ptr [rsp - 13]
        movzx   r8d, byte ptr [rsp - 29]
        add     r8d, edi
        add     r8d, esi
        xor     edi, edi
        cmp     r8d, 256
        setae   dil
        movzx   r8d, byte ptr [rsp - 12]
        movzx   r10d, byte ptr [rsp - 28]
        add     r10d, r8d
        add     r10d, edi
        xor     r9d, r9d
        cmp     r10d, 256
        setae   r9b
        movzx   r8d, byte ptr [rsp - 11]
        movzx   r10d, byte ptr [rsp - 27]
        add     r10d, r8d
        add     r10d, r9d
        xor     r8d, r8d
        cmp     r10d, 256
        setae   r8b
        movzx   r10d, byte ptr [rsp - 10]
        movzx   r11d, byte ptr [rsp - 26]
        add     r11d, r10d
        add     r11d, r8d
        xor     r10d, r10d
        cmp     r11d, 256
        setae   r10b
        movzx   r11d, byte ptr [rsp - 9]
        movzx   ebx, byte ptr [rsp - 25]
        add     ebx, r11d
        add     ebx, r10d
        xor     r11d, r11d
        cmp     ebx, 256
        setae   r11b
        shl     r8d, 8
        or      r8d, r9d
        shl     edx, 8
        or      edx, ecx
        shl     esi, 16
        or      esi, edx
        shl     edi, 24
        or      edi, esi
        movd    xmm2, edi
        pinsrw  xmm2, r8d, 2
        shl     r11d, 8
        or      r11d, r10d
        pinsrw  xmm2, r11d, 3
        paddb   xmm1, xmm0
        paddb   xmm1, xmm2
        movq    qword ptr [rax + 8], xmm1
        pop     rbx
        pop     r14
        pop     rbp
        ret

myadd2:
        push    rbp
        push    r14
        push    rbx
        movq    xmm0, qword ptr [rsi]
        movdqa  xmmword ptr [rsp - 48], xmm0
        movzx   eax, byte ptr [rsp - 48]
        movq    xmm1, qword ptr [rdx]
        movdqa  xmmword ptr [rsp - 64], xmm1
        movzx   ecx, byte ptr [rsp - 64]
        add     ecx, eax
        xor     eax, eax
        cmp     ecx, 256
        setae   al
        movzx   ecx, byte ptr [rsp - 47]
        movzx   r8d, byte ptr [rsp - 63]
        add     r8d, ecx
        add     r8d, eax
        xor     ecx, ecx
        cmp     r8d, 256
        setae   cl
        movzx   r8d, byte ptr [rsp - 46]
        movzx   r9d, byte ptr [rsp - 62]
        add     r9d, r8d
        add     r9d, ecx
        xor     r8d, r8d
        cmp     r9d, 256
        setae   r8b
        movzx   r9d, byte ptr [rsp - 45]
        movzx   r10d, byte ptr [rsp - 61]
        add     r10d, r9d
        add     r10d, r8d
        xor     r9d, r9d
        cmp     r10d, 256
        setae   r9b
        movzx   r10d, byte ptr [rsp - 44]
        movzx   r11d, byte ptr [rsp - 60]
        add     r11d, r10d
        add     r11d, r9d
        xor     ebx, ebx
        cmp     r11d, 256
        setae   bl
        movzx   r10d, byte ptr [rsp - 43]
        movzx   ebp, byte ptr [rsp - 59]
        add     ebp, r10d
        add     ebp, ebx
        xor     r11d, r11d
        cmp     ebp, 256
        setae   r11b
        movzx   r10d, byte ptr [rsp - 42]
        movzx   r14d, byte ptr [rsp - 58]
        add     r14d, r10d
        add     r14d, r11d
        xor     ebp, ebp
        cmp     r14d, 256
        setae   bpl
        movzx   r10d, byte ptr [rsp - 41]
        movzx   r14d, byte ptr [rsp - 57]
        add     r14d, r10d
        add     r14d, ebp
        xor     r10d, r10d
        cmp     r14d, 256
        setae   r10b
        shl     r11d, 8
        or      r11d, ebx
        shl     ecx, 8
        or      ecx, eax
        shl     r8d, 16
        or      r8d, ecx
        shl     r9d, 24
        or      r9d, r8d
        movd    xmm2, r9d
        pinsrw  xmm2, r11d, 2
        mov     rax, rdi
        mov     ecx, r10d
        shl     ecx, 8
        or      ecx, ebp
        pinsrw  xmm2, ecx, 3
        paddb   xmm1, xmm0
        paddb   xmm1, xmm2
        movq    qword ptr [rdi], xmm1
        movq    xmm0, qword ptr [rsi + 8]
        movdqa  xmmword ptr [rsp - 16], xmm0
        movzx   ecx, byte ptr [rsp - 16]
        movq    xmm1, qword ptr [rdx + 8]
        movdqa  xmmword ptr [rsp - 32], xmm1
        movzx   edx, byte ptr [rsp - 32]
        add     edx, ecx
        add     edx, r10d
        xor     ecx, ecx
        cmp     edx, 256
        setae   cl
        movzx   edx, byte ptr [rsp - 15]
        movzx   esi, byte ptr [rsp - 31]
        add     esi, edx
        add     esi, ecx
        xor     edx, edx
        cmp     esi, 256
        setae   dl
        movzx   esi, byte ptr [rsp - 14]
        movzx   edi, byte ptr [rsp - 30]
        add     edi, esi
        add     edi, edx
        xor     esi, esi
        cmp     edi, 256
        setae   sil
        movzx   edi, byte ptr [rsp - 13]
        movzx   r8d, byte ptr [rsp - 29]
        add     r8d, edi
        add     r8d, esi
        xor     edi, edi
        cmp     r8d, 256
        setae   dil
        movzx   r8d, byte ptr [rsp - 12]
        movzx   r10d, byte ptr [rsp - 28]
        add     r10d, r8d
        add     r10d, edi
        xor     r9d, r9d
        cmp     r10d, 256
        setae   r9b
        movzx   r8d, byte ptr [rsp - 11]
        movzx   r10d, byte ptr [rsp - 27]
        add     r10d, r8d
        add     r10d, r9d
        xor     r8d, r8d
        cmp     r10d, 256
        setae   r8b
        movzx   r10d, byte ptr [rsp - 10]
        movzx   r11d, byte ptr [rsp - 26]
        add     r11d, r10d
        add     r11d, r8d
        xor     r10d, r10d
        cmp     r11d, 256
        setae   r10b
        movzx   r11d, byte ptr [rsp - 9]
        movzx   ebx, byte ptr [rsp - 25]
        add     ebx, r11d
        add     ebx, r10d
        xor     r11d, r11d
        cmp     ebx, 256
        setae   r11b
        shl     r8d, 8
        or      r8d, r9d
        shl     edx, 8
        or      edx, ecx
        shl     esi, 16
        or      esi, edx
        shl     edi, 24
        or      edi, esi
        movd    xmm2, edi
        pinsrw  xmm2, r8d, 2
        shl     r11d, 8
        or      r11d, r10d
        pinsrw  xmm2, r11d, 3
        paddb   xmm1, xmm0
        paddb   xmm1, xmm2
        movq    qword ptr [rax + 8], xmm1
        pop     rbx
        pop     r14
        pop     rbp
        ret

myadd3:
        push    rbp
        push    r14
        push    rbx
        mov     rax, rdi
        movq    xmm0, qword ptr [rsi]
        movq    xmm1, qword ptr [rdx]
        paddb   xmm1, xmm0
        movdqa  xmmword ptr [rsp - 16], xmm1
        movzx   ecx, byte ptr [rsp - 16]
        mov     byte ptr [rdi], cl
        pmaxub  xmm0, xmm1
        pcmpeqb xmm0, xmm1
        pmovmskb        r10d, xmm0
        not     r10d
        mov     ecx, r10d
        shr     cl, 7
        mov     edi, r10d
        and     dil, 64
        shr     dil, 6
        mov     r8d, r10d
        and     r8b, 32
        shr     r8b, 5
        mov     r9d, r10d
        and     r9b, 16
        shr     r9b, 4
        mov     r11d, r10d
        and     r11b, 8
        shr     r11b, 3
        mov     ebx, r10d
        and     bl, 4
        shr     bl, 2
        mov     ebp, r10d
        and     bpl, 2
        shr     bpl
        and     r10b, 1
        add     r10b, byte ptr [rsp - 15]
        setb    r14b
        or      r14b, bpl
        mov     byte ptr [rax + 1], r10b
        add     r14b, byte ptr [rsp - 14]
        setb    r10b
        or      r10b, bl
        mov     byte ptr [rax + 2], r14b
        add     r10b, byte ptr [rsp - 13]
        setb    bl
        or      bl, r11b
        mov     byte ptr [rax + 3], r10b
        add     bl, byte ptr [rsp - 12]
        setb    r10b
        or      r10b, r9b
        mov     byte ptr [rax + 4], bl
        add     r10b, byte ptr [rsp - 11]
        setb    r9b
        or      r9b, r8b
        mov     byte ptr [rax + 5], r10b
        add     r9b, byte ptr [rsp - 10]
        setb    r8b
        or      r8b, dil
        mov     byte ptr [rax + 6], r9b
        add     r8b, byte ptr [rsp - 9]
        setb    dil
        or      dil, cl
        mov     byte ptr [rax + 7], r8b
        movzx   ecx, byte ptr [rdx + 8]
        add     dil, -1
        adc     cl, byte ptr [rsi + 8]
        mov     byte ptr [rax + 8], cl
        movzx   ecx, byte ptr [rdx + 9]
        adc     cl, byte ptr [rsi + 9]
        movzx   edi, byte ptr [rdx + 10]
        adc     dil, byte ptr [rsi + 10]
        movzx   r8d, byte ptr [rdx + 11]
        adc     r8b, byte ptr [rsi + 11]
        mov     byte ptr [rax + 9], cl
        mov     byte ptr [rax + 10], dil
        mov     byte ptr [rax + 11], r8b
        movzx   ecx, byte ptr [rdx + 12]
        adc     cl, byte ptr [rsi + 12]
        mov     byte ptr [rax + 12], cl
        movzx   ecx, byte ptr [rdx + 13]
        adc     cl, byte ptr [rsi + 13]
        movzx   edi, byte ptr [rdx + 14]
        adc     dil, byte ptr [rsi + 14]
        movzx   edx, byte ptr [rdx + 15]
        adc     dl, byte ptr [rsi + 15]
        mov     byte ptr [rax + 13], cl
        mov     byte ptr [rax + 14], dil
        mov     byte ptr [rax + 15], dl
        pop     rbx
        pop     r14
        pop     rbp
        ret

Why is the addition on an array of u8 so much worse than on u128? Is this just something LLVM doesn't know how to optimize, and I shouldn't have expected better?

The reason I originally had [u8; N] instead of [u64; N] is that the implementation I was writing this for often has N % 8 != 0 and N <= 8. I hoped that u8::carrying_add would optimize well, so I wouldn't have to dispatch to multiple implementations depending on N, but it doesn't look like it does.

You should use u64 limbs and just round N up to the nearest multiple of 8. There's no point trying to take advantage of weird-sized numbers.
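
Something along these lines, for example (a rough sketch: the N = 13, the little-endian limb order, and the widen helper are just illustrative, and it uses the same u64::carrying_add as your snippets):

// Sketch: store the numbers compactly as [u8; N], but widen them to u64
// limbs (rounded up to a whole number of limbs) just for the addition.
const N: usize = 13;               // example "weird" size
const LIMBS: usize = (N + 7) / 8;  // round N up to whole u64 limbs

fn add(a: &[u8; N], b: &[u8; N]) -> [u8; N] {
    // Widen [u8; N] into little-endian u64 limbs, zero-padding the top bytes.
    let widen = |x: &[u8; N]| -> [u64; LIMBS] {
        let mut padded = [0u8; LIMBS * 8];
        padded[..N].copy_from_slice(x);
        core::array::from_fn(|i| {
            u64::from_le_bytes(padded[i * 8..][..8].try_into().unwrap())
        })
    };

    let (a, b) = (widen(a), widen(b));
    let mut out = [0u8; LIMBS * 8];
    let mut carry = false;
    for i in 0..LIMBS {
        let (sum, c) = u64::carrying_add(a[i], b[i], carry);
        out[i * 8..][..8].copy_from_slice(&sum.to_le_bytes());
        carry = c;
    }

    // Truncate back down to the stored width (any final carry is dropped,
    // just like in the [u8; N] loop).
    let mut result = [0u8; N];
    result.copy_from_slice(&out[..N]);
    result
}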

The source of this data definitely benefits from the compactness of not padding to the nearest multiple of 8, since a lot of these values are being stored. I suppose you're saying I should pad before adding and truncate afterwards.

Using u64s looks much better, though still slightly worse than the built-in u128 add. Is it unreasonable to expect this to come out the same?

u64
#[unsafe(no_mangle)]
pub fn myadd5(num: AlignAs<u128,[u64;2]>, num2: AlignAs<u128,[u64;2]>) -> AlignAs<u128,[u64;2]> {
    let mut carry = false;
    let mut sum = AlignAs([0u64;2],[]);
    for i in 0..num.0.len() {
        let (sum_v, carry_v) = u64::carrying_add(num.0[i], num2.0[i], carry);
        sum.0[i] = sum_v;
        carry = carry_v;
    }
    sum
}
myadd5:
        mov     rcx, qword ptr [rdx]
        add     rcx, qword ptr [rsi]
        mov     rdx, qword ptr [rdx + 8]
        adc     rdx, qword ptr [rsi + 8]
        mov     rax, rdi
        mov     qword ptr [rdi], rcx
        mov     qword ptr [rdi + 8], rdx
        ret

vs

myadd:
        mov     rax, rdi
        add     rax, rdx
        adc     rsi, rcx
        mov     rdx, rsi
        ret

The difference between your functions at this point just has to do with the ABI: in one case the arguments are passed in memory, in the other in registers.

Why is the ABI different here?

Note that rustc has a codegen test to ensure that you get adc, so you might want to compare whatever it is that you're doing to what that test is doing:

There doesn't seem to be any real reason, other than the fact that the development process never seems to be able to move forward and pin anything down.

#[unsafe(no_mangle)]
pub extern "C" fn myadd5(num: AlignAs<u128,[u64;2]>, num2: AlignAs<u128,[u64;2]>) -> AlignAs<u128,[u64;2]> {
    let mut carry = false;
    let mut sum = AlignAs([0u64;2],[]);
    for i in 0..num.0.len() {
        let (sum_v, carry_v) = u64::carrying_add(num.0[i], num2.0[i], carry);
        sum.0[i] = sum_v;
        carry = carry_v;
    }
    sum
}

Becomes:

myadd5:
        mov     rax, rdi
        add     rax, rdx
        adc     rsi, rcx
        mov     rdx, rsi
        ret

The original code is different because it used arrays of u8, while that test uses u64. I guess there isn't an optimization that kicks in when the [u8; N] array happens to be a nice length.
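
For what it's worth, here is a sketch of doing that widening by hand for the [u8; 16] case (the myadd4 name is made up, and it assumes the bytes are stored least-significant first, matching the loops above); the arithmetic is then the same two-limb u64 shape that produced add/adc in myadd5:

#[unsafe(no_mangle)]
pub fn myadd4(num: [u8; 16], num2: [u8; 16]) -> [u8; 16] {
    // Reinterpret each array as two little-endian u64 limbs.
    let lo = |x: &[u8; 16]| u64::from_le_bytes(x[..8].try_into().unwrap());
    let hi = |x: &[u8; 16]| u64::from_le_bytes(x[8..].try_into().unwrap());

    // Add limb by limb, propagating the carry, then convert back to bytes.
    let (sum_lo, carry) = u64::carrying_add(lo(&num), lo(&num2), false);
    let (sum_hi, _) = u64::carrying_add(hi(&num), hi(&num2), carry);

    let mut out = [0u8; 16];
    out[..8].copy_from_slice(&sum_lo.to_le_bytes());
    out[8..].copy_from_slice(&sum_hi.to_le_bytes());
    out
}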