#![feature(naked_functions)]
use std::arch::{asm, naked_asm};
const DEFAULT_STACK_SIZE: usize = 1024 * 1024 * 2;
const MAX_THREADS: usize = 4;
static mut RUNTIME: usize = 0;
pub struct Runtime {
threads: Vec<Thread>,
current: usize,
}
struct Thread {
stack: Vec<u8>,
ctx: ThreadContext,
state: State,
}
#[derive(Debug, PartialEq, Eq)]
enum State {
Available,
Running,
Ready,
}
#[derive(Debug, Default)]
#[repr(C)]
struct ThreadContext {
rsp: u64,
r15: u64,
r14: u64,
r13: u64,
r12: u64,
rbx: u64,
rbp: u64,
}
impl Thread {
fn new() -> Self {
Thread {
stack: vec![0u8; DEFAULT_STACK_SIZE],
ctx: ThreadContext::default(),
state: State::Available,
}
}
}
impl Runtime {
pub fn new() -> Self {
let base_thread = Thread {
stack: vec![0u8; DEFAULT_STACK_SIZE],
ctx: ThreadContext::default(),
state: State::Running,
};
let mut threads = vec![base_thread];
let mut available_threads: Vec<Thread> = (1..MAX_THREADS).map(|_| Thread::new()).collect();
threads.append(&mut available_threads);
Runtime {
threads,
current: 0,
}
}
pub fn init(&self) {
unsafe {
let r_ptr: *const Runtime = self;
RUNTIME = r_ptr as usize;
}
}
pub fn run(&mut self) -> ! {
while self.t_yield() {}
std::process::exit(0)
}
fn t_return(&mut self) {
if self.current != 0 {
self.threads[self.current].state = State::Available;
self.t_yield();
}
}
#[inline(never)]
fn t_yield(&mut self) -> bool {
let mut pos = self.current;
while self.threads[pos].state != State::Ready {
pos += 1;
if pos == self.threads.len() {
pos = 0;
}
if pos == self.current {
return false;
}
}
if self.threads[self.current].state != State::Available {
self.threads[self.current].state = State::Ready;
}
self.threads[pos].state = State::Running;
let old_pos = self.current;
self.current = pos;
unsafe {
let old: *mut ThreadContext = &mut self.threads[old_pos].ctx;
let new: *const ThreadContext = &self.threads[pos].ctx;
asm!("call switch", in("rdi") old, in("rsi") new, clobber_abi("C"));
}
if self.current == 1 {
println!("{:?}", self.threads[self.current].stack)
}
self.threads.len() > 0
}
pub fn spawn(&mut self, f: fn()) {
let available = self
.threads
.iter_mut()
.find(|t| t.state == State::Available)
.expect("no available threads");
let size = available.stack.len();
unsafe {
let s_ptr = available.stack.as_mut_ptr().offset(size as isize);
let s_ptr = (s_ptr as usize & !15) as *mut u8;
std::ptr::write(s_ptr.offset(-16) as *mut u64, guard as u64);
std::ptr::write(s_ptr.offset(-24) as *mut u64, skip as u64);
std::ptr::write(s_ptr.offset(-32) as *mut u64, f as u64);
available.ctx.rsp = s_ptr.offset(-32) as u64;
}
available.state = State::Ready;
}
}
fn guard() {
unsafe {
let rt_ptr = RUNTIME as *mut Runtime;
(*rt_ptr).t_return();
}
}
#[naked]
unsafe extern "C" fn skip() {
unsafe { naked_asm!("ret") }
}
pub fn yield_thread() {
unsafe {
let rt_ptr = RUNTIME as *mut Runtime;
(*rt_ptr).t_yield();
}
}
#[naked]
#[unsafe(no_mangle)]
unsafe extern "C" fn switch() {
unsafe {
naked_asm!(
"mov [rdi + 0x00], rsp",
"mov [rdi + 0x08], r15",
"mov [rdi + 0x10], r14",
"mov [rdi + 0x18], r13",
"mov [rdi + 0x20], r12",
"mov [rdi + 0x28], rbx",
"mov [rdi + 0x30], rbp",
"mov rsp, [rsi + 0x00]",
"mov r15, [rsi + 0x08]",
"mov r14, [rsi + 0x10]",
"mov r13, [rsi + 0x18]",
"mov r12, [rsi + 0x20]",
"mov rbx, [rsi + 0x28]",
"mov rbp, [rsi + 0x30]",
"ret",
)
}
}
fn main() {
let mut rt = Runtime::new();
rt.init();
rt.spawn(|| {
println!("THREAD 1 STARTING");
let id = 1;
for i in 0..2 {
println!("thread: {} counter: {}", id, i);
yield_thread();
}
println!("THREAD 1 FINISHED");
});
rt.spawn(|| {
println!("THREAD 2 STARTING");
let id = 2;
for i in 0..5 {
println!("thread: {} counter: {}", id, i);
yield_thread();
}
println!("THREAD 2 FINISHED");
});
rt.run();
}
I can't wrap my head around 2 things:
Q1: how does CPU knows where is the last executed line in previous thread so it can return there instead of executing from start to finish again?
Q2: do i need manually initialize any of registers in ThreadContext
beside rsp
? I printed them in console and got a bunch of zeroes. Do CPU fills them itself?