KAIO v0.2.0: Write GPU kernels in Rust, tensor-core matmul at 92.5% of cuBLAS sgemm

Hey all, I just published v0.2.0 of KAIO, a Rust-native GPU kernel framework. Wanted to share it here since it's aimed squarely at a gap in the Rust ecosystem.

The short version: you write #[gpu_kernel] fn my_kernel(...) in Rust, the proc macro lowers it to PTX at build time, and the runtime dispatches via the NVIDIA driver. No CUDA toolkit needed to build, no Python, works on Windows and Linux.

The tensor-core matmul hits 92.5% of cuBLAS sgemm at 4096² on an RTX 4090 (fp16 inputs, fp32 accumulation, cp.async double-buffered). The sync variant lands at 82.3%.

The motivating use case is custom ops. If you're using candle or burn and need a fused activation, a novel attention variant, or a quantization kernel that the framework doesn't support, your options today are CUDA C++ and FFI bindings. Python developers reach for triton.jit. KAIO is an attempt at the Rust equivalent.

Here's what a kernel looks like. This is the gated SiLU activation from LLaMA/Mistral/Qwen, normally hand-written in CUDA:

#[gpu_kernel(block_size = 256)]
fn fused_silu_gate(x: &[f32], gate: &[f32], out: &mut [f32], n: u32) {
    let idx = thread_idx_x() + block_idx_x() * block_dim_x();
    if idx < n {
        let xi = x[idx];
        let sig = 1.0f32 / (1.0f32 + exp(-xi));
        out[idx] = xi * sig * gate[idx];
    }
}
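For checking a kernel like this against known-good values, a plain-Rust CPU reference of the same formula is handy. The following is a minimal sketch, not part of KAIO; the helper name `fused_silu_gate_cpu` is hypothetical, and it uses only the standard library.

```rust
/// CPU reference for the gated SiLU: out[i] = x[i] * sigmoid(x[i]) * gate[i].
/// `fused_silu_gate_cpu` is a hypothetical helper name, not a KAIO API.
fn fused_silu_gate_cpu(x: &[f32], gate: &[f32]) -> Vec<f32> {
    assert_eq!(x.len(), gate.len(), "x and gate must have the same length");
    x.iter()
        .zip(gate)
        .map(|(&xi, &gi)| {
            // Same arithmetic as the kernel body: sigmoid, then two multiplies.
            let sig = 1.0f32 / (1.0f32 + (-xi).exp());
            xi * sig * gi
        })
        .collect()
}
```

Comparing GPU output against this with a small tolerance is usually more realistic than expecting bit-identical results, since the device `exp` may round differently from the host one.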

Limitations I want to be upfront about: NVIDIA-only (SM 7.0+), inference-focused (no autograd), pre-1.0 API. The kernel DSL is a Rust subset: no closures, traits, or generics inside kernel bodies. Small-shape matmul trails cuBLAS significantly. The tensor-core comparison is fp16→fp32 vs cuBLAS sgemm f32→f32, a project-local baseline, not a precision-identity claim.

One thing that surprised me during development: the 92.5% didn't come from vectorized loads. The jump came from bank-conflict padding on the shared B-tile, plus hoisting (group_id, thread_id_in_group) out of the fragment loaders. The async path benefited disproportionately (+7.4pp) because cp.async was already saturating global bandwidth, so the real bottleneck was shared-memory contention at fragment-read time.
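To illustrate why padding helps in general, here is a small plain-Rust model of shared-memory banking. It assumes the usual 32 banks of 4-byte words and a simple access pattern where thread t reads word 0 of row t. This is a generic sketch, not KAIO's actual B-tile layout or fragment-read pattern, and the function names are made up for the example.

```rust
/// Assumed hardware model: 32 shared-memory banks, each one 32-bit word wide.
const BANKS: usize = 32;

/// Bank hit by word `col` of row `row` in a row-major tile
/// whose rows are `stride_words` words apart.
fn bank(row: usize, col: usize, stride_words: usize) -> usize {
    (row * stride_words + col) % BANKS
}

/// Worst-case number of threads in one warp that land on the same bank
/// when thread t reads (row = t, col = 0).
fn max_conflict(stride_words: usize) -> usize {
    let mut hits = [0usize; BANKS];
    for t in 0..32 {
        hits[bank(t, 0, stride_words)] += 1;
    }
    *hits.iter().max().unwrap()
}
```

Under this model, `max_conflict(32)` is 32 (every thread hits bank 0, so the reads serialize), while padding each row by one word gives `max_conflict(33) == 1`, i.e. no conflicts.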


Would love to hear what custom kernels people would want to write first, and whether the DSL subset feels too limiting. Also happy to go deep on any of the PTX generation or tiling work if anyone's curious.

Is there support for weird quantization like 6-bit floats, or does that have to be written manually or handled on the CPU side?

Not today. KAIO's type system is f32/f16/bf16, no sub-byte types yet. 6-bit specifically is tricky because it doesn't align to byte boundaries (4 values pack into 3 bytes), so the unpack logic needs shift-and-mask operations that straddle bytes.
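To make the byte-straddling concrete, here is a plain-Rust sketch of packing four unsigned 6-bit values into three bytes and unpacking them again. The bit order (first value in the lowest bits) is an assumption chosen for the example; real formats such as Q6_K use their own layouts, and the function names are hypothetical.

```rust
/// Pack four 6-bit values (each 0..=63) into three bytes.
/// Assumed layout: value 0 occupies bits 0..6, value 1 bits 6..12, and so on.
fn pack4x6(v: [u8; 4]) -> [u8; 3] {
    let bits = ((v[0] as u32) & 0x3F)
        | (((v[1] as u32) & 0x3F) << 6)
        | (((v[2] as u32) & 0x3F) << 12)
        | (((v[3] as u32) & 0x3F) << 18);
    [bits as u8, (bits >> 8) as u8, (bits >> 16) as u8]
}

/// Unpack the same layout byte by byte, showing where values straddle bytes.
fn unpack4x6(b: [u8; 3]) -> [u8; 4] {
    [
        b[0] & 0x3F,                        // low 6 bits of byte 0
        (b[0] >> 6) | ((b[1] & 0x0F) << 2), // 2 bits from byte 0, 4 from byte 1
        (b[1] >> 4) | ((b[2] & 0x03) << 4), // 4 bits from byte 1, 2 from byte 2
        b[2] >> 2,                          // high 6 bits of byte 2
    ]
}
```

Two of the four values are assembled from two different bytes, which is the extra shift-and-mask work that byte-aligned formats like INT8 avoid.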

Sub-byte quantization is on the roadmap. INT8 is the Phase 7 priority, then 4-bit, since that's where most quantized LLM inference lives right now. 6-bit (Q6_K-style) would come after. The architecture for all of these is the same: a dequant stage that unpacks to fp16 during the cooperative tile load, then feeds into the existing mma.sync tensor-core pipeline. The TC compute path stays the same; it just gets a dequant step in front.

What format are you targeting? Q4_K, Q6_K, AWQ, GPTQ, something else? This is the first feature request, so it goes into my planning backlog immediately. Knowing the concrete use case helps me prioritize which format lands first.

That really depends on the exact model and how its quality degrades, so it's easiest to say "I don't know in advance". 4/5/6-bit quantization, probably?

Honest answer: I'm not planning to ship Q4_K/Q5_K/Q6_K as pre-baked ops. The list of quant formats never ends, and anything I bake in now is stale by the next paper. My direction is to ship INT8 dequantize-matmul as the reference template (next milestone) and make sure the DSL supports writing 4/5/6-bit variants yourself.

The primitives for that (bitwise ops, signed/unsigned shifts, compound bitwise assign) shipped in v0.2.1, and v0.2.2 went to crates.io yesterday. So a 6-bit unpack kernel in KAIO's Rust DSL is already possible right now: shift-and-mask over a packed u8/u32 stream, ~30 lines of idiomatic Rust. Once the INT8 showcase lands you'd have a template to copy for whichever bit-width your model actually needs.

"I don't know in advance" is the shape I'm targeting: you pick the bit-width per model, you write the dequant yourself, and you don't wait on me to add your format.

@ProgramCrafter

Quick follow-up with actual Rust code, since "~30 lines of Rust" is easy to say and harder to prove. Symmetric INT8 dequant in KAIO's DSL:

#[gpu_kernel(block_size = 256)]
fn dequant_i8(packed: &[u32], out: &mut [f32], scale: f32, n_words: u32) {
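    // Global thread index, so launches with more than one block cover all words.
    let tid = thread_idx_x() + block_idx_x() * block_dim_x();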

    if tid < n_words {
        let word = packed[tid];

        // Extract each byte, sign-extend to i32, cast to f32, scale.
        let b0 = (((word & 0xFF) as i32) << 24) >> 24;
        let b1 = ((((word >> 8) & 0xFF) as i32) << 24) >> 24;
        let b2 = ((((word >> 16) & 0xFF) as i32) << 24) >> 24;
        let b3 = ((((word >> 24) & 0xFF) as i32) << 24) >> 24;

        let base = tid * 4;
        out[base] = (b0 as f32) * scale;
        out[base + 1] = (b1 as f32) * scale;
        out[base + 2] = (b2 as f32) * scale;
        out[base + 3] = (b3 as f32) * scale;
    }
}

Four signed int8 weights are packed per u32 and sign-extended via the << 24 >> 24 trick (>> on i32 is an arithmetic shift, which preserves the sign). The output is bit-exact against a CPU reference. For 6-bit, the same pattern applies with mask 0x3F and different shift counts (<< 26 >> 26 for the sign extension).
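For anyone who wants to check the trick on the host, here is a plain-Rust sketch of a CPU reference for the INT8 case plus a 6-bit variant. It is not the reference used for the bit-exact check above. All function names are hypothetical, and the 6-bit layout (four signed values in the low 24 bits of each word, top 8 bits unused) is an assumption picked to keep the example short, not a real format.

```rust
/// Sign-extend the low `bits` bits of `v` (valid for bits in 1..=32):
/// shift the field up to the top, then arithmetic-shift it back down.
fn sign_extend(v: u32, bits: u32) -> i32 {
    let shift = 32 - bits;
    ((v << shift) as i32) >> shift
}

/// CPU reference for INT8: four signed bytes per u32, lowest byte first.
fn dequant_i8_cpu(packed: &[u32], scale: f32) -> Vec<f32> {
    packed
        .iter()
        .flat_map(|&w| {
            (0..4u32).map(move |i| sign_extend((w >> (8 * i)) & 0xFF, 8) as f32 * scale)
        })
        .collect()
}

/// 6-bit variant under the assumed layout: four signed 6-bit values
/// in bits 0..24 of each word, lowest value first.
fn dequant_i6_cpu(packed: &[u32], scale: f32) -> Vec<f32> {
    packed
        .iter()
        .flat_map(|&w| {
            (0..4u32).map(move |i| sign_extend((w >> (6 * i)) & 0x3F, 6) as f32 * scale)
        })
        .collect()
}
```

For example, the word 0x807FFF01 holds the bytes 1, -1, 127, -128 from low to high, so `dequant_i8_cpu(&[0x807FFF01], 0.5)` gives `[0.5, -0.5, 63.5, -64.0]`. A densely packed 6-bit stream would additionally need to handle values that straddle word boundaries, which this sketch sidesteps.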