Review request: stack bleaching

I have a program that spawns threads with small stacks, whose size is set with thread::Builder::stack_size. For now it targets only x86-64 Linux.

The program looks roughly like this:

/// Perform sensitive computations
#[inline(never)]
fn process_cmd(cmd: Cmd) { ... }

for conn in get_connection() {
    thread::Builder::new()
        .stack_size(stack_size)
        .spawn(move || {
            for cmd in get_cmd(&mut conn) {
                process_cmd(cmd);
                bleach_stack(stack_size);
            }
        })?;
}

process_cmd performs sensitive computations (e.g. cryptography), so I would like to erase the stack immediately after the sensitive work is done (see this thread for an example of why this can be important). stack_size is a multiple of 4 KiB and at least 16 KiB (see PTHREAD_STACK_MIN).
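
The two constraints on stack_size could be checked before spawning. A minimal sketch, assuming the 4 KiB page size and the 16 KiB lower bound stated above; PAGE_SIZE, MIN_STACK_SIZE and check_stack_size are made-up names, not part of the original program:

```rust
/// Assumed page size on x86-64 Linux (hypothetical constant for this sketch).
const PAGE_SIZE: usize = 4096;
/// Assumed lower bound, matching the 16 KiB figure from the post.
const MIN_STACK_SIZE: usize = 16 * 1024;

/// Returns the size unchanged if it satisfies both constraints.
fn check_stack_size(size: usize) -> Result<usize, String> {
    if size < MIN_STACK_SIZE {
        return Err(format!("stack size {} is below {}", size, MIN_STACK_SIZE));
    }
    if size % PAGE_SIZE != 0 {
        return Err(format!("stack size {} is not a multiple of {}", size, PAGE_SIZE));
    }
    Ok(size)
}
```

The validated value can then be passed to both thread::Builder::stack_size and bleach_stack, so the two always agree.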

My current implementation of bleach_stack looks like this:

fn bleach_stack(stack_size: u64) {
    unsafe {
        std::arch::asm!(
            // Copy stack pointer to a scratch register and
            // align it to 16 bytes
            "mov {pos}, rsp",
            "and {pos}, -16",

            // Compute the stack end, assuming that current stack usage is
            // less than one page and that one page is used for stuff like TLS.
            "mov {end}, {pos}",
            "and {end}, -4096",
            "sub {end}, {stack_size}",
            "add {end}, 2 * 4096",

            "xorps   xmm0, xmm0",
            "movups  xmmword ptr [rsp - 8], xmm0",
            
            "42:",
            "sub {pos}, 16",
            "movaps  xmmword ptr [{pos}], xmm0",
            "cmp {pos}, {end}",
            "jne 42b",

            stack_size = in(reg) stack_size,
            pos = out(reg) _,
            end = out(reg) _,
            // xmm0 is overwritten by the block, so declare it as clobbered
            out("xmm0") _,
        );
    }
}

Is this implementation correct? Is it acceptable to erase the thread's stack all the way to its end? I do not use stack canaries and, IIUC, set_task_stack_end_magic is not relevant to user-spawned threads. As far as I understand, the code should also be safe with regard to context switching, since the saved context is pushed relative to the RSP register, which I do not modify.

Also, do you know why I have to use add {end}, 2 * 4096 instead of add {end}, 4096? Using the latter results in a stack overflow, and so does add {end}, 2 * 4096 + 16. Is it because the first page of the thread's stack is used for stuff like TLS?
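
To make the address arithmetic easier to follow, here is a plain-Rust mirror of the first asm block. It is only a sketch: bleach_range, bytes_cleared and pages_added are made-up names, and the sample addresses in the note below are arbitrary, not measured values.

```rust
/// Plain-Rust mirror of the address arithmetic in the asm block.
/// `pages_added` is the N in "add {end}, N * 4096" (hypothetical name).
/// Returns (pos, end): the loop zeroes memory from just below `pos` down to `end`.
fn bleach_range(rsp: u64, stack_size: u64, pages_added: u64) -> (u64, u64) {
    // "and {pos}, -16": round the stack pointer down to 16 bytes.
    let pos = rsp & !15;
    // "and {end}, -4096": start of the page that holds the stack pointer.
    let page = pos & !4095;
    // "sub {end}, {stack_size}" followed by "add {end}, N * 4096".
    let end = page - stack_size + pages_added * 4096;
    (pos, end)
}

/// Number of bytes the loop would write, or None if it would never terminate.
fn bytes_cleared(pos: u64, end: u64) -> Option<u64> {
    // The loop exits only on equality, so `pos` has to start strictly above
    // `end` and the distance has to be a multiple of the 16-byte step.
    if pos > end && (pos - end) % 16 == 0 {
        Some(pos - end)
    } else {
        None
    }
}
```

With rsp = 0x7f00_0000_5df8 and a 32 KiB stack, N = 2 gives end = 0x7eff_ffff_f000, while N = 1 puts end one page lower. Because the loop exits only on equality, an end that is not reachable from pos makes it run until it hits unmapped or protected memory.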

Note: I know that moving one XMM value per iteration is not the most efficient approach, but it's not important for the above questions.

I have not looked at the code; I just wanted to drop a link to the sanitizer chapter of The Rust Unstable Book, which describes tools that might help you verify whether the code is doing something potentially harmful.

Your code clears the stack below rsp, but on Linux something like 128 bytes is reserved in that space for a "red zone". Also, your bleach_stack function is very far from safe, because the caller can pass anything they want as stack_size.

You should probably implement this as a guard which bleaches the stack on drop. As it is, you won't clear the stack if the function panics.
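
A minimal sketch of such a guard, assuming unwinding panics. BleachGuard, with_bleach and fake_bleach are made-up names; fake_bleach only counts its calls and stands in for the real bleach_stack so that the example runs anywhere:

```rust
use std::panic::{self, AssertUnwindSafe};
use std::sync::atomic::{AtomicUsize, Ordering};

/// Runs the wrapped cleanup function when dropped, including during unwinding.
struct BleachGuard<F: FnMut()>(F);

impl<F: FnMut()> Drop for BleachGuard<F> {
    fn drop(&mut self) {
        (self.0)();
    }
}

/// Counts how often the stand-in cleanup ran (for demonstration only).
static BLEACH_CALLS: AtomicUsize = AtomicUsize::new(0);

/// Stand-in for the real `bleach_stack`; it only records that it was called.
fn fake_bleach() {
    BLEACH_CALLS.fetch_add(1, Ordering::SeqCst);
}

/// Calls `work` and makes sure `fake_bleach` runs afterwards, panic or not.
fn with_bleach<R>(work: impl FnOnce() -> R) -> R {
    let _guard = BleachGuard(fake_bleach);
    work()
}

/// Returns true if the cleanup still ran although the work panicked.
fn survives_panic() -> bool {
    let before = BLEACH_CALLS.load(Ordering::SeqCst);
    let result = panic::catch_unwind(AssertUnwindSafe(|| {
        with_bleach(|| -> () { panic!("boom") })
    }));
    result.is_err() && BLEACH_CALLS.load(Ordering::SeqCst) == before + 1
}
```

Note that the guard is dropped inside the frame of with_bleach, so a real bleaching routine called from there can only clear what lies below that frame at that point.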

Only leaf functions can put something into the red zone without bumping the stack pointer, and the stack bleaching is done in a leaf function. I guess, to be extra safe, it's better to mark it #[inline(never)].

You are right, but I use catch_unwind instead of a Drop guard. I omitted it from the OP for simplicity.

Unless you use options(nostack) the compiler assumes that you may call a function and as such has to assume that the red zone is clobbered (or more likely not use the red zone at all).
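
For contrast, here is a sketch of an asm block where options(nostack) is appropriate, because it only reads rsp and writes nothing to the stack. The bleaching block writes below rsp, so it must not carry that option. current_sp is a made-up helper, and the non-x86-64 branch is only a fallback so that the snippet builds elsewhere:

```rust
/// Returns an approximation of the current stack pointer.
#[inline(never)]
fn current_sp() -> usize {
    #[cfg(target_arch = "x86_64")]
    {
        let sp: usize;
        // SAFETY: the instruction only copies rsp into a register. It does not
        // touch memory, the stack, or the flags, so all three options hold.
        unsafe {
            std::arch::asm!(
                "mov {sp}, rsp",
                sp = out(reg) sp,
                options(nomem, nostack, preserves_flags),
            );
        }
        sp
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        // Fallback for other architectures: the address of a local variable.
        let marker = 0u8;
        &marker as *const u8 as usize
    }
}
```

Without nostack the compiler also guarantees that the stack pointer is suitably aligned for a function call at the start of the block.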

One issue with the current code is that the compiler would be allowed to move process_cmd(cmd) after bleach_stack(stack_size) if it could prove that process_cmd(cmd) doesn't touch anything that bleach_stack(stack_size) would be allowed to touch. A fix would be to call process_cmd(cmd) from within the asm block of bleach_stack(stack_size), before doing the stack bleaching.

The closure and bleach_stack function together with all the libstd internal functions that run before calling your closure consume stack space too. The exact amount is not stable, but in this case clearly more than 4096 bytes.

I am not sure about that. Running the following code:

fn f() -> u64 {
    let sp: u64;
    unsafe {
        // Copy the current stack pointer into `sp`
        std::arch::asm!(
            "mov {sp}, rsp",
            sp = out(reg) sp,
        );
    }
    sp
}

fn main() -> std::io::Result<()> {
    let sp = std::thread::Builder::new()
        .stack_size(1 << 15)
        .spawn(f)?
        .join()
        .expect("thread panicked");
    // stack grows down
    println!("{}", 4096 - (sp % 4096));
    Ok(())
}

Prints 528. This number is close enough to 0. I doubt that libstd functions consume almost exactly 4 KiB. Either way, it's a surprising amount of memory to consume upfront for setting up a thread.

UPD: Compiling the above code for x86_64-unknown-linux-musl and running it prints 1584. Also, stack bleaching now works with "add {end}, 4096". So it looks like libpthread consumes approximately one page for setting up a thread.

I found a more robust way of finding the end of the thread's stack using pthread functions:

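let mut stackaddr = core::ptr::null_mut();
let mut stacksize: libc::size_t = 0;
let mut attr = mem::MaybeUninit::uninit();
// Query the attributes of the running thread
let ret = unsafe {
    libc::pthread_getattr_np(libc::pthread_self(), attr.as_mut_ptr())
};
if ret != 0 {
    return Err(...);
}
// stackaddr receives the lowest address of the stack, stacksize its length
let ret = unsafe {
    libc::pthread_attr_getstack(
        attr.as_ptr(),
        &mut stackaddr,
        &mut stacksize,
    )
};
// Release the attribute object before any early return
unsafe {
    libc::pthread_attr_destroy(attr.as_mut_ptr());
}
if ret != 0 {
    return Err(...);
}
// The bleaching loop stores with movaps, which needs 16-byte alignment
if stackaddr.align_offset(16) != 0 {
    return Err(...);
}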

After an additional sanity check that ensures the stack pointer is indeed inside stackaddr..stackaddr+stacksize, I can use stackaddr in the bleaching function:

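/// # Safety
/// `stackaddr` must be the 16-byte-aligned lowest address of the current
/// thread's stack, and the stack pointer must lie above it.
pub unsafe fn bleach_stack(stackaddr: *mut c_void) {
    std::arch::asm!(
        // Copy the stack pointer and align it down to 16 bytes
        "mov {pos}, rsp",
        "and {pos}, -16",
        "xorps   xmm0, xmm0",
        "movups  xmmword ptr [rsp - 8], xmm0",
        // Zero 16 bytes per iteration until the stack end is reached
        "42:",
        "sub {pos}, 16",
        "movaps  xmmword ptr [{pos}], xmm0",
        "cmp {pos}, {stackaddr}",
        "jne 42b",
        stackaddr = in(reg) stackaddr,
        pos = out(reg) _,
        // xmm0 is overwritten by the block, so declare it as clobbered
        out("xmm0") _,
    );
}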
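
The sanity check mentioned above could look roughly like this. It is a sketch on plain integers rather than pointers; bleach_len is a made-up name, and the rule that the aligned stack pointer must be strictly above stackaddr follows from the loop, which decrements before it compares:

```rust
/// Returns the number of bytes between `stackaddr` and the 16-byte-aligned
/// stack pointer, or `None` if the inputs fail the sanity checks.
fn bleach_len(sp: usize, stackaddr: usize, stacksize: usize) -> Option<usize> {
    // Reject a range that would wrap around the address space.
    let top = stackaddr.checked_add(stacksize)?;
    // The loop compares for equality in 16-byte steps, so the lower bound
    // has to be 16-byte aligned as well.
    if stackaddr % 16 != 0 {
        return None;
    }
    let pos = sp & !15;
    // The stack pointer has to lie inside the reported range, and its aligned
    // value strictly above the lower bound, or the loop never reaches it.
    if pos <= stackaddr || sp >= top {
        return None;
    }
    Some(pos - stackaddr)
}
```

Returning the length instead of a bool also gives an upper bound on how many bytes the loop will write, which is handy for logging or assertions.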