Saving the 'last executed' line of code when context switching between green threads

#![feature(naked_functions)]
use std::arch::{asm, naked_asm};

/// Size in bytes of each green thread's private stack (2 MiB).
const DEFAULT_STACK_SIZE: usize = 2 * 1024 * 1024;
/// Total number of thread slots, including the base thread in slot 0.
const MAX_THREADS: usize = 4;
/// Address of the active `Runtime`, written by `Runtime::init` so that free
/// functions (`guard`, `yield_thread`) can reach it. Zero until `init` runs.
static mut RUNTIME: usize = 0;

/// Cooperative green-thread scheduler.
///
/// `threads[0]` is the base thread (the caller of `Runtime::new`); the other
/// slots host spawned threads. `current` is the index of the slot that is
/// executing right now.
pub struct Runtime {
    current: usize,
    threads: Vec<Thread>,
}

/// One green-thread slot: its scheduling state, the registers saved when it
/// was last switched out, and the buffer used as its stack.
struct Thread {
    state: State,
    ctx: ThreadContext,
    stack: Vec<u8>,
}

/// Scheduling state of a thread slot.
#[derive(PartialEq, Eq, Debug)]
enum State {
    /// The slot that is executing right now.
    Running,
    /// Suspended; `t_yield` may switch to it.
    Ready,
    /// Free (never used, or its function has finished); `spawn` may use it.
    Available,
}

/// Registers saved for a suspended green thread and restored when it resumes.
///
/// These are the registers the System V x86-64 ABI requires a callee to
/// preserve. Every other register is declared clobbered by the
/// `clobber_abi("C")` on the `call switch` in `t_yield`, so the compiler
/// saves those itself.
///
/// `#[repr(C)]` fixes the field order, and that order MUST match the byte
/// offsets hard-coded in `switch`: rsp 0x00, r15 0x08, r14 0x10, r13 0x18,
/// r12 0x20, rbx 0x28, rbp 0x30. Reordering fields breaks switching.
///
/// All fields start at zero via `Default`. For a freshly spawned thread only
/// `rsp` needs a real value, and `spawn` sets it. A function entered fresh
/// does not depend on the incoming values of these registers; it only has to
/// preserve them.
#[derive(Debug, Default)]
#[repr(C)]
struct ThreadContext {
    rsp: u64, // 0x00
    r15: u64, // 0x08
    r14: u64, // 0x10
    r13: u64, // 0x18
    r12: u64, // 0x20
    rbx: u64, // 0x28
    rbp: u64, // 0x30
}

impl Thread {
    /// Creates a free (`Available`) slot with a zero-filled stack of
    /// `DEFAULT_STACK_SIZE` bytes and an all-zero saved context.
    fn new() -> Self {
        let stack = vec![0u8; DEFAULT_STACK_SIZE];
        let ctx = ThreadContext::default();
        Self {
            state: State::Available,
            ctx,
            stack,
        }
    }
}

impl Runtime {
    /// Creates a runtime with `MAX_THREADS` thread slots.
    ///
    /// Slot 0 is the base thread, i.e. whoever calls `new` (and later `run`).
    /// It starts out `Running`. Slots `1..MAX_THREADS` start `Available` for
    /// `spawn`.
    pub fn new() -> Self {
        // The base thread keeps executing on the stack it already has. It is
        // never `Available` (`t_return` skips slot 0), so `spawn` never
        // installs a context on it and this buffer would never be read.
        // Don't allocate 2 MiB for it.
        let base_thread = Thread {
            stack: Vec::new(),
            ctx: ThreadContext::default(),
            state: State::Running,
        };

        let threads: Vec<Thread> = std::iter::once(base_thread)
            .chain((1..MAX_THREADS).map(|_| Thread::new()))
            .collect();

        Runtime {
            threads,
            current: 0,
        }
    }

    /// Stores the address of `self` in the global `RUNTIME`, so that `guard`
    /// and `yield_thread` can reach this runtime.
    ///
    /// Only the address is recorded. `self` must not be moved or dropped
    /// while green threads may still call back into it.
    pub fn init(&self) {
        let r_ptr: *const Runtime = self;
        // SAFETY: a plain store of an integer into the global. Keeping `self`
        // alive at this address is the caller's responsibility (see above).
        unsafe {
            RUNTIME = r_ptr as usize;
        }
    }

    /// Drives the scheduler from the base thread. It keeps yielding until no
    /// other thread is `Ready`, then exits the process with status 0.
    pub fn run(&mut self) -> ! {
        while self.t_yield() {}
        std::process::exit(0)
    }

    /// Reached (via `guard`) when a spawned thread's function has returned.
    /// It frees the slot and switches to the next ready thread. On the base
    /// thread (slot 0) it does nothing.
    fn t_return(&mut self) {
        if self.current != 0 {
            self.threads[self.current].state = State::Available;
            self.t_yield();
        }
    }

    /// Switches from the current thread to the next `Ready` one, searching
    /// round-robin starting after `current`.
    ///
    /// Returns `false` without switching if no other thread is `Ready`.
    /// Otherwise it returns `true` (the thread list is never empty), but only
    /// after this thread has been switched back to. Execution resumes right
    /// after the `call switch`: that `call` pushed the return address onto
    /// this thread's stack, and `switch` saved `rsp` pointing at it.
    #[inline(never)]
    fn t_yield(&mut self) -> bool {
        let mut pos = self.current;
        while self.threads[pos].state != State::Ready {
            pos += 1;
            if pos == self.threads.len() {
                pos = 0;
            }

            if pos == self.current {
                return false;
            }
        }

        // A thread that just finished (see `t_return`) stays `Available`.
        if self.threads[self.current].state != State::Available {
            self.threads[self.current].state = State::Ready;
        }

        self.threads[pos].state = State::Running;
        let old_pos = self.current;
        self.current = pos;

        // SAFETY:
        // - `old` and `new` point into `self.threads`, which is never resized
        //   after `new`.
        // - `switch` stores the callee-saved registers into `*old`, loads
        //   them from `*new`, and `ret`s on the newly loaded stack.
        // - `clobber_abi("C")` tells the compiler that all caller-saved
        //   registers may be changed across the call.
        unsafe {
            let old: *mut ThreadContext = &mut self.threads[old_pos].ctx;
            let new: *const ThreadContext = &self.threads[pos].ctx;
            asm!("call switch", in("rdi") old, in("rsi") new, clobber_abi("C"));
        }

        !self.threads.is_empty()
    }

    /// Installs `f` on the first `Available` thread slot and marks it
    /// `Ready`. It starts running the next time the scheduler switches to it.
    ///
    /// # Panics
    /// Panics with "no available threads" if every slot is in use.
    pub fn spawn(&mut self, f: fn()) {
        let available = self
            .threads
            .iter_mut()
            .find(|t| t.state == State::Available)
            .expect("no available threads");

        let size = available.stack.len();

        // Build the new thread's initial stack so that the first `switch`
        // into it "returns" into `f`. Addresses grow downward, and `top` is
        // the end of the buffer rounded down to a 16-byte boundary:
        //
        //   top - 16: guard  <- popped by `skip`'s `ret`; frees the slot
        //   top - 24: skip   <- popped by `f`'s `ret` when `f` returns
        //   top - 32: f      <- popped by `switch`'s `ret`; ctx.rsp points here
        //
        // On entry to `f` rsp = top - 24, and on entry to `guard` rsp = top - 8.
        // In both cases rsp + 8 is 16-byte aligned, as the System V x86-64
        // ABI expects at function entry.
        //
        // SAFETY: an `Available` slot owns a DEFAULT_STACK_SIZE buffer, so
        // `top - 32 .. top` lies inside it, and `top` is at most one past its
        // end. Each written address is 8-byte aligned.
        unsafe {
            let top = available.stack.as_mut_ptr().add(size);
            let top = (top as usize & !15) as *mut u8;
            std::ptr::write(top.sub(16) as *mut u64, guard as u64);
            std::ptr::write(top.sub(24) as *mut u64, skip as u64);
            std::ptr::write(top.sub(32) as *mut u64, f as u64);
            available.ctx.rsp = top.sub(32) as u64;
        }

        available.state = State::Ready;
    }
}

/// Landing point for a spawned thread whose function has returned.
///
/// `spawn` places this function's address on the new stack, below `skip`.
/// It hands the finished slot back to the runtime via `t_return`, which
/// marks it `Available` and yields to the next ready thread.
fn guard() {
    // SAFETY: `RUNTIME` holds the address stored by `Runtime::init`; that
    // runtime must still live at that address.
    let runtime: &mut Runtime = unsafe { &mut *(RUNTIME as *mut Runtime) };
    runtime.t_return();
}

/// One-instruction trampoline that `spawn` places between `f` and `guard` on a
/// new thread's stack.
///
/// When `f` returns it lands here, and this `ret` pops the next word (the
/// address of `guard`) and jumps to it. The extra 8-byte slot means both `f`
/// and `guard` are entered with rsp + 8 a multiple of 16.
#[naked]
unsafe extern "C" fn skip() {
    unsafe { naked_asm!("ret") }
}

/// Voluntarily gives up the CPU: asks the global runtime to switch to the
/// next ready green thread.
///
/// Returns once this thread is scheduled again, or immediately if no other
/// thread is ready.
pub fn yield_thread() {
    // SAFETY: `RUNTIME` holds the address stored by `Runtime::init`; that
    // runtime must still live at that address.
    let runtime: &mut Runtime = unsafe { &mut *(RUNTIME as *mut Runtime) };
    runtime.t_yield();
}

/// Context switch between green threads.
///
/// It saves the current thread's callee-saved registers into the
/// `ThreadContext` pointed to by `rdi`, loads the next thread's registers
/// from the `ThreadContext` pointed to by `rsi`, and then `ret`s.
///
/// `no_mangle` keeps the symbol name `switch` so the `call switch` in
/// `t_yield` can reach it. The final `ret` is the moment control moves to the
/// other thread. It pops whatever is on top of the newly loaded stack:
/// - normally, the return address pushed by that thread's own earlier
///   `call switch`;
/// - for a thread that has never run, the entry function `spawn` wrote there.
#[naked]
#[unsafe(no_mangle)]
unsafe extern "C" fn switch() {
    unsafe {
        naked_asm!(
            // Save the outgoing thread. rsp still points at the return
            // address pushed by `call switch`, so saving rsp saves the
            // resume point.
            "mov [rdi + 0x00], rsp",
            "mov [rdi + 0x08], r15",
            "mov [rdi + 0x10], r14",
            "mov [rdi + 0x18], r13",
            "mov [rdi + 0x20], r12",
            "mov [rdi + 0x28], rbx",
            "mov [rdi + 0x30], rbp",
            // Load the incoming thread; from here on we are on its stack.
            "mov rsp, [rsi + 0x00]",
            "mov r15, [rsi + 0x08]",
            "mov r14, [rsi + 0x10]",
            "mov r13, [rsi + 0x18]",
            "mov r12, [rsi + 0x20]",
            "mov rbx, [rsi + 0x28]",
            "mov rbp, [rsi + 0x30]",
            // Pop the incoming thread's resume address and jump to it.
            "ret",
        )
    }
}

fn main() {
    // Entry points for the two green threads. They capture nothing, so they
    // are plain `fn()` values, as `Runtime::spawn` requires.
    fn thread_one() {
        println!("THREAD 1 STARTING");

        let id = 1;
        for step in 0..2 {
            println!("thread: {} counter: {}", id, step);
            yield_thread();
        }
        println!("THREAD 1 FINISHED");
    }

    fn thread_two() {
        println!("THREAD 2 STARTING");

        let id = 2;
        for step in 0..5 {
            println!("thread: {} counter: {}", id, step);
            yield_thread();
        }
        println!("THREAD 2 FINISHED");
    }

    let mut runtime = Runtime::new();
    runtime.init();
    runtime.spawn(thread_one);
    runtime.spawn(thread_two);

    // Never returns: `run` exits the process once no spawned thread is left
    // ready to run.
    runtime.run();
}

I can't wrap my head around 2 things:

Q1: How does the CPU know where the last executed line of the previous thread is, so it can return there instead of executing the thread from start to finish again?

Q2: Do I need to manually initialize any of the registers in ThreadContext besides rsp? I printed them to the console and got a bunch of zeroes. Does the CPU fill them in itself?

Never mind lines of source code; you need to know the last executed machine instruction. On x86 the instruction being executed is pointed at by the IP or RIP register.

When you switch context you need to save the RIP somewhere. I think typically one pushes all the current thread's registers onto the thread's stack, including the RIP; then you can restore the stack and do a return, which will pop the RIP and continue execution from there.

You might be interested in the way tsoding did this in C. He live codes it on YouTube here: https://www.youtube.com/watch?v=uFET2vifHh4 and his earlier video on coroutines here: https://www.youtube.com/watch?v=sYSP_elDdZw&t=576s

Does it mean that the rip register is initialized automatically and placed on the stack? Does that happen because of the `call switch` assembly instruction? Does it just remember the very last instruction, or is there some instruction that remembers the last instruction before that?

It's a long time since I did this kind of thing so I'm a bit rusty.

But in a cooperative scheduler a thread that wants to give up control will call some kind of yield() function. We called it suspend() back in the day.

So inside your yield() function the instruction pointer (RIP) will already be on the stack, having been placed there by the CALL instruction that got you to yield().

So then yield() has to switch the stack pointer over to the stack of the next thread to be run and perform a return (RET), which will pop the saved instruction pointer off the stack and jump back to wherever that thread was running.

Typically one does not initialise IP. It gets set to some start address on processor reset and is then changed as calls, returns and jumps are made. I have a feeling one cannot even do a MOV into the IP register.

They are green threads. Do they use a call instruction? On the Rust Discord they told me that I just call yield_thread(), it executes whatever it switches to, and then it returns like a normal function and continues its flow.
But that is just another source of confusion. Don't I change the flow of the code when I switch stack pointers? It means that I jump to execute another instruction and never return, because there is no mechanism for "unwinding" that. So I'm just in a loop changing stacks. In my example: main thread -> thread 1 -> thread 2 -> main thread ...

So the explanation they gave me seems wrong, because I never return from the yield_thread function. Or do I?

Where did you get the code from?

All threads are at the same point when the switch takes place in t_yield(). The point of difference is when t_yield() returns, and this is where the return location is read from the stack that rsp points to.

The code is from Asynchronous Programming in Rust book.

What do you mean by

Studying the backtraces might help you (although not pretty.)

Found it. Looks like Windows gets special treatment.

Oh that playground code looks very useful. Gonna check it out.

I'm running Linux, so this example should work correctly.

Thank you for help

As far as I can tell, the only way for a thread to suspend itself and allow another thread to be scheduled is to call yield(). That means that at the point the actual context switch happens, all threads will be at the same line in the source code, somewhere in the yield() function. When that yield() function returns, the RET happens in a new thread context with a new stack and new register values (whatever a thread context contains). And so the return from yield() may well get you back to a different line in the source code: whatever the new thread was running when it called yield().

Let's say you have three threads:

fn main() {
    spawn(a);
    spawn(b);
    for i in 0.. {
        println!("main {i}");
        yield();
    }
}

fn a() {
    for i in 0.. {
        println!("a {i}");
        yield();
    } 
}

fn b() {
    for i in 0.. {
        println!("b {i}");
        yield();
    } 
}

Assuming a primitive scheduler that just switches threads in a loop, this will execute like this:

Main prints 0 and calls yield
A starts executing, prints 0 and calls yield
B starts executing, prints 0 and calls yield

Main returns from yield, prints 1, yields
A returns from yield, prints 1, yields
B returns from yield, prints 1, yields

How this works is that when yield is called, it updates the current stack and instruction pointers for this task inside the scheduler and switches to the next task. This switch is pretty much just loading the saved stack and instruction pointers of the other task into rsp and rip and then... just returning from yield.

So you're saying it's not my job to remember the last place it yielded from, because that is stored on the thread's stack when switch is called?

I have not got my head into the details, but as far as I can tell code executes in the context of the currently running thread until it hits the ret at the end of the switch() function. The moment that ret instruction runs is where the jump from one thread's code to the next is made.

At that point the ret instruction will pop a return address off the stack and jump to it, thus jumping you to the code of the next thread.

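That return address is hopefully put in place by the mov instructions above the ret, which save the registers of the current thread into wherever rdi is pointing and load the registers of the next thread from wherever rsi is pointing. More precisely, the movs don't write the return address themselves. Loading rsp from the next thread's ThreadContext switches to that thread's stack, whose top already holds the address pushed by that thread's own earlier `call switch` (or, for a freshly spawned thread, the function address written there by spawn()). rdi and rsi point at the current and next threads' ThreadContext structures at that point; they are passed in by the asm! call in t_yield().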

I did a bit of research with ChatGPT plus articles that it gave me. From what they said, control-flow instructions like jump, call, etc. take the address of the next instruction from the rip register and push it onto the current stack.
In my case, if rsp was changed and ret is called after it, ret will continue execution on whatever stack I switched rsp to. So since call pushed rip onto my old stack, when I return to that stack again, rsp will point to the saved rip (as it is now the top of the old stack), and ret will pop it and jump to the corresponding instruction.

Is that correct?
If so, my first thought about some magic behind call was right (not in this post, but when I first encountered this problem).

ChatGPT lies all the time...

Jump/branch instructions, conditional jump/branch instructions do not save any return address (next instruction) on the stack. Once you have jumped there is no record, on the stack or anywhere else, of where you jumped from.

Call instructions will save the address of the next instruction on the stack (push) and then do a jump to the new address.

Return instructions will read the saved return address from the stack (pop) and jump to it. Thus returning control to the point after the call.

In this way in normal code every call to a function is matched by a return from that function linked together by that return address saved on the stack.

Now the "magic" part of this thread switching, which is a bit hard to get ones head around, is that it never uses call. All the call instructions in your code are irrelevant to understanding the thread switch.

All the thread switching is done by means of the ret instruction at the end of the switch() function.

Basically a thread saves all the registers it needs onto the stack, using mov in your code but I think push could be used as well. It saves the address of the next thread onto the stack. Then it does a ret. That ret pops the address of the next thread off the stack and jumps to it.

BINGO the thread switch has happened on executing the ret. No call involved.

As a mental exercise, imagine your switch function had a jmp at the end, after the ret, that jumped back to the first instruction of switch(), such that switch() never returns. Now your entire thread is just a loop around the instructions inside switch(). The system would constantly swap threads at the ret, all of them running around that loop, with no call instructions involved in thread switching at all!

4 Likes

Oh yes, I finally understand what's going on. I didn't rely fully on ChatGPT; instead I asked for articles about those instructions. So I might have misread the jump part. They were in one paragraph, so I thought all of them involve saving rip.

By the way, doesn't `call switch` save rip on the current stack? As you said:

So this is what that inline assembly does, no?

Is there any way to make the spawn take a Fn() object instead of a function pointer and make the runtime immediately run without explicitly calling the run method?

I'm not sure, but you can use Fn traits there just by changing it to Fn(), FnMut() or FnOnce(), depending on what you need. But I guess you have to wrap them in a Box, which is not what you want for spawn.

About the runtime: at some point you always have to explicitly start a runtime, at least to my knowledge. Tokio, for example, uses the block_on method for this.

Re:

The ThreadContext struct in your example has #[derive(Debug, Default)] and is initialised with ctx: ThreadContext::default(), when creating a new Thread. The default values of u64 are zero. So nothing to do there.

Yes, you can, but it's starting to push the boundaries of this very simplified (and wildly unsafe) example.

See the example readme for a bit more information.