Use jit model to implement noise ruduction in real-time

use std::thread;
use std::time::Duration;
use std::sync::{Arc, atomic::{AtomicBool, Ordering}};
use std::mem::MaybeUninit;
use ringbuf::{Producer, Consumer, HeapRb, SharedRb};
use tch::{Tensor, kind, CModule, IValue};
use anyhow::{Result, bail};
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
pub type RbProd = Producer<f32, Arc<SharedRb<f32, Vec<MaybeUninit<f32>>>>>;
pub type RbCons = Consumer<f32, Arc<SharedRb<f32, Vec<MaybeUninit<f32>>>>>;

fn main() {
    let model = CModule::load(
        "stream_model.jit"
    ).unwrap();
    
    // Set up the ring buffers for audio input and output
    let frame_size = 480;
    let in_rb = HeapRb::<f32>::new(frame_size*100);
    let out_rb = HeapRb::<f32>::new(frame_size*100);
    let (in_prod, in_cons) = in_rb.split();
    let (out_prod, out_cons) = out_rb.split();

    // Create atomic control flags for stopping
    let should_stop = Arc::new(AtomicBool::new(false));
    let should_stop_amp = Arc::new(AtomicBool::new(false));
    let should_stop_out = Arc::new(AtomicBool::new(false));

    // Create threads for audio input, amplification, and output
    let input_thread = thread::spawn(move || audio_input(in_prod, should_stop.clone()));
    let amplification_thread = thread::spawn(move || audio_worker(&model, in_cons, out_prod, should_stop_amp.clone()));
    let output_thread = thread::spawn(move || audio_output(out_cons, should_stop_out.clone()));

    // Start threads
    input_thread.join().unwrap();
    amplification_thread.join().unwrap();
    output_thread.join().unwrap();
}

fn audio_input(mut rb_prod: RbProd, should_stop: Arc<AtomicBool>) {
    let host = cpal::default_host();
    let input_device = host.default_input_device().expect("no input device available");
    let should_stop_ref = &should_stop.clone();
    let config = input_device.default_input_config().unwrap();
    let input_stream = input_device.build_input_stream(
        &config.into(),
        move |data: &[f32], _: &cpal::InputCallbackInfo| {
            if should_stop.load(Ordering::Relaxed) {
                return;
            }
            let samples: Vec<f32> = data.iter().map(|&x| x as f32).collect();
            rb_prod.push_slice(&samples); //samples.len() = 960
        },
        |err| eprintln!("Error in input stream: {}", err),
        None,
    ).unwrap();

    input_stream.play().unwrap();

    // Wait for the stop signal      
    while !should_stop_ref.load(Ordering::Relaxed) {
        thread::sleep(Duration::from_secs_f32(0.01));
    }
}

fn audio_worker(model: &CModule, mut rb_cons: RbCons, mut rb_prod: RbProd, should_stop: Arc<AtomicBool>) {
    let mut states = IValue::Tensor(Tensor::zeros(&[45304], kind::FLOAT_CPU));

    // Sleep until initialization is complete
    while !should_stop.load(Ordering::Relaxed) && !rb_cons.is_full() {
        thread::sleep(Duration::from_millis(100));
    }

    while !should_stop.load(Ordering::Relaxed) {
        let samples = rb_cons.pop_iter().collect::<Vec<f32>>();
        let mut output: Vec<Vec<f32>> = Vec::new();
        if samples.len() == 960{
            let input: Vec<Vec<f32>> = samples.chunks(480).map(|chunk| chunk.to_vec()).collect();
            for chunk in input.iter(){
                
                let input_chunk = tch::IValue::from(Tensor::try_from(chunk.to_vec()).unwrap());
                let (new_audio, new_state) = run_infer(&model, input_chunk, states).unwrap();
                states = IValue::Tensor(new_state); 
                output.push(new_audio);
            }
            let denoised = output.concat();
            
            rb_prod.push_slice(&denoised);
        }
        else{
            continue;
        }
    }
}

fn run_infer(model:&CModule, audio_chunk: IValue, states: IValue) -> Result<(Vec<f32>, Tensor)>{

    //run the model
    let output = &model.forward_is(&[audio_chunk, states]).unwrap();

    //split the output to the model real output
    let (enh_chunk, new_state, _lsnr) = match output {
        IValue::Tuple(ivalues) => match &ivalues[..] {
            [IValue::Tensor(t1), IValue::Tensor(t2), IValue::Tensor(t3)] => (t1.shallow_clone(), t2.shallow_clone(), t3.shallow_clone()),
            _ =>  bail!("unexpected output {:?}", ivalues),
        },
        _ => bail!("unexpected output {:?}", output),
    };

    //transfer tensor to Vec<f32>
    let audio = Vec::<f32>::try_from(enh_chunk).unwrap();
    Ok((audio, new_state))
}

fn audio_output(mut rb_cons: RbCons, should_stop: Arc<AtomicBool>) {
    let host = cpal::default_host();
    let output_device = host.default_output_device().expect("no output device available");
    let should_stop_ref = &should_stop.clone();

    // Sleep until initialization is complete
    while !should_stop.load(Ordering::Relaxed) && !rb_cons.is_full() {
        thread::sleep(Duration::from_secs_f32(0.01));
    }

    let config = output_device.default_output_config().unwrap();
    let output_stream = output_device.build_output_stream(
        &config.into(),
        move |data: &mut [f32], _: &cpal::OutputCallbackInfo| {
            if should_stop.load(Ordering::Relaxed) {
                return;
            }
            if data.len()==960{
                let samples = rb_cons.pop_iter().collect::<Vec<f32>>();
                if samples.len() == 960{
                    if data.len() != samples.len(){
                        let samples: Vec<f32> = data.iter().map(|&x| x as f32 - x as f32).collect();
                        data.copy_from_slice(&samples);
                    }
                    else{
                        data.copy_from_slice(&samples);
                    }
                }
            }
        },
        |err| eprintln!("Error in output stream: {}", err),
        None,
    ).unwrap();

    output_stream.play().unwrap();

    // Wait for the stop signal
    while !should_stop_ref.load(Ordering::Relaxed) {
        thread::sleep(Duration::from_secs_f32(0.01));
    }
}

Here is my code, I'm trying to use a JIT model for real-time noise reduction on my PC. The function records speech with noise and simultaneously outputs speech with minimal noise. However, the results are unsatisfactory, as the output appears highly discontinuous.

The initial challenge is how to set the buffer size to 480?. Despite attempting various methods, I couldn't successfully modify the default buffer size of 960. As a workaround, I implemented a for loop to align with the model's input requirements. One input for my noise reduction model consists of 10ms speech with a sample rate of 48000, while the other input is the model's state, which needs to iterate.

The second issue is about ringbuffer, how to synchronizing the producer and the consumer? I suspect that the unsatisfactory outcomes are due to the lack of synchronization. In the audio_output section, to ensure the speaker receives output, I had to establish conditions to match the lengths of 'data' and 'rb_cons.' However, I acknowledge that this approach might not be optimal.

On a related note, when I changed the audio_worker to a simpler function, such as amplifying the audio by two times, the results improved. Nevertheless, occasional minor discontinuities still persisted in the output. The sample code is provided below.

fn audio_worker(mut rb_cons: RbCons, mut rb_prod: RbProd, should_stop: Arc<AtomicBool>) {
    // Sleep until initialization is complete
    while !should_stop.load(Ordering::Relaxed) && !rb_cons.is_full() {
        thread::sleep(Duration::from_millis(100));
    }
    while !should_stop.load(Ordering::Relaxed) {

        let samples = rb_cons.pop_iter().collect::<Vec<f32>>();
        let amplified_samples: Vec<f32> = samples.iter().map(|&x| x * 2.0).collect();
        if amplified_samples.len() == 960{
            rb_prod.push_slice(&amplified_samples);
        }
        else{
            continue;
        }
    }
}

The cargo.toml:

[package]
name = "audio_stream"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
cpal = "0.15"
anyhow = "1.0.75"
ringbuf = "0.3"
rodio = "0.17.3"
tch = "0.14.0"
log = { version = "0.4", features = ["std"] }

These problems have been bothering me for few weeks, hope someone can help me please.