How do I make a raw pointer Sync (FFI)?

rust-bert uses C++ FFI. I am quite sure that I am not going to mutate the global model; it is read-only and only used for predictions.

use color_eyre::eyre;
use rust_bert::pipelines::summarization::SummarizationModel;
use once_cell::sync::Lazy;

struct Model {
    summarizer: SummarizationModel,
}

impl Model {
    pub fn predict(&self, batch: Vec<String>) -> Vec<String> {
        self.summarizer.summarize(&batch)
    }

    fn load() -> Self {
        Self {
            summarizer: SummarizationModel::new(Default::default()).unwrap(),
        }
    }
}

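// This static is what requires `Model: Sync`; it is the line the compiler rejects.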
static GLOBAL_MODEL: Lazy<Model> = Lazy::new(|| {
    Model::load()
});

#[tokio::main]
async fn main() -> eyre::Result<()> {
    Ok(())
}

I get this error:

`*mut torch_sys::C_tensor` cannot be shared between threads safely
within `Model`, the trait `Sync` is not implemented for `*mut torch_sys::C_tensor`
required because of the requirements on the impl of `Sync` for `once_cell::imp::OnceCell<Model>`
required because of the requirements on the impl of `Sync` for `once_cell::sync::Lazy<Model>`
shared static variables must have a type that implements `Sync`

Raw pointers are by definition not Sync: they can't be shared between threads safely because there is no way to prevent data races through them. The type system also protects you from working around this by mistake, since coherence rules stop you from implementing Sync (even unsafely) for a foreign type.
If you want to use a raw pointer from multiple threads, stick it into a Mutex or RwLock (depending on your situation).
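As a rough sketch of the Mutex route, using the Model type from the question. Note that Mutex<T> is only Sync when T: Send, and raw pointers are !Send as well, so this still needs one unsafe assertion that the model can be moved across threads:

use once_cell::sync::Lazy;
use std::sync::Mutex;

// SAFETY: asserts that Model can be *moved* to another thread. This
// claim has to be checked against libtorch, not against the compiler.
unsafe impl Send for Model {}

// Mutex<Model> is Sync because Model is (asserted to be) Send, so the
// static now satisfies the Sync bound.
static GLOBAL_MODEL: Lazy<Mutex<Model>> = Lazy::new(|| Mutex::new(Model::load()));

fn predict(batch: Vec<String>) -> Vec<String> {
    // The lock serializes all access: only one thread can touch the
    // model at a time, so no data race is possible through it.
    GLOBAL_MODEL.lock().unwrap().predict(batch)
}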
If you are certain that only one thread can access it at any time, you can implement a NullLock: a type that lies to the compiler, claiming to provide synchronization (by implementing the Sync trait) while actually doing nothing. This is a dangerous solution, but it works. That said, the only place it belongs is in no-std environments where you either don't have atomic instructions (some unusual architecture) or can't use them (e.g. AArch64 before the MMU is set up). In all other cases, just use a Mutex; if there is no contention, the overhead is really low.
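For illustration, a minimal sketch of such a NullLock (the name and API here are illustrative, not a real library):

use core::cell::UnsafeCell;

// A "lock" that provides no synchronization at all: it just asserts
// Sync to the compiler. Only sound when something outside the type
// system guarantees exclusive access (e.g. a single core with
// interrupts disabled).
pub struct NullLock<T> {
    data: UnsafeCell<T>,
}

unsafe impl<T> Sync for NullLock<T> {}

impl<T> NullLock<T> {
    pub const fn new(data: T) -> Self {
        Self { data: UnsafeCell::new(data) }
    }

    // SAFETY: the caller must guarantee that no other access to the
    // data is possible for the lifetime of the returned reference.
    pub unsafe fn get(&self) -> &mut T {
        &mut *self.data.get()
    }
}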


It seems OK for PyTorch C++ types to be read in parallel. You could ask the tch crate (which the rust_bert crate uses) to add a Sync impl to its types.

Something to keep in mind is that while you don't directly mutate the variable, that doesn't mean the C++ code won't mutate internal state as part of doing inference. For example, it's not uncommon to lazily allocate buffers or set up GPU resources on first run.

That said, I'm kinda curious why you are storing the model in a global variable instead of instantiating it in main() and passing it down to functions when they need access. Then you can side-step the Sync issues entirely.
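Roughly like this, assuming the Model type from the question (single-threaded here just to show the shape):

use color_eyre::eyre;

fn run_request(model: &Model, batch: Vec<String>) -> Vec<String> {
    // The model is only ever borrowed; there is no static, so no
    // Sync bound comes into play.
    model.predict(batch)
}

fn main() -> eyre::Result<()> {
    // The model lives in main and is passed down by reference.
    let model = Model::load();
    let output = run_request(&model, vec!["some text".to_string()]);
    println!("{output:#?}");
    Ok(())
}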


I plan on creating a web service for it. For now I spawn a dedicated thread per request (batching to be implemented later, like rust-dl-webserver/main.rs at master · epwalsh/rust-dl-webserver · GitHub):

use color_eyre::eyre::{self, Context};
use rust_bert::pipelines::summarization::SummarizationModel;

struct Model {
    summarizer: SummarizationModel,
}

impl Model {
    pub fn predict(&self, batch: Vec<String>) -> Vec<String> {
        self.summarizer.summarize(&batch)
    }

    fn load() -> Self {
        Self {
            summarizer: SummarizationModel::new(Default::default()).unwrap(),
        }
    }
}

async fn predict_batch(input: Vec<String>) -> eyre::Result<Vec<String>> {
    let (tx, rx) = tokio::sync::oneshot::channel();

    // CPU-bound work: run inference on a dedicated OS thread so the
    // async executor is not blocked, then send the result back over
    // the oneshot channel.
    std::thread::spawn(move || {
        // Note: this loads the model on every request, which is
        // exactly what the global variable was meant to avoid.
        let model = Model::load();
        let output = model.predict(input);
        tx.send(output).unwrap();
    });
    Ok(rx.await.wrap_err("no message here!")?)
}

#[tokio::main]
async fn main() -> eyre::Result<()> {
    let txt = "The quake struck at 1:08am off Japan's southern coast 
at a depth of 45 kilometers (28 miles), but there were no signs 
that it would trigger a tsunami, according to the Japan Meteorological Agency. 
The quake's magnitude was revised up from a preliminary measurement of 6.4.
In a statement released overnight, the agency warned 
there's a heightened risk of falling rocks and landslides in the areas that witnessed strong shaking. 
It also urged residents to be prepared for the possibility 
that an earthquake with a maximum seismic intensity of 5 or 
higher on Japan's scale will occur in the next few days.
Some areas near the epicenter of the quake lost power overnight, 
but the nearby Sendai and Genkai nuclear power stations are operating normally, 
according to Kyushu Electric Power. 
Train operator JR Kyushu suspended some lines in the Oita and Miyazaki areas Saturday.";

    let work = vec![txt.to_string(); 1];
    let output = predict_batch(work).await?;
    println!("{output:#?}");
    Ok(())
}

You write:

unsafe impl Sync for Model {}

and the same for Send:

unsafe impl Send for Model {}

In this situation ensuring safety is your responsibility. If the model is not in fact safe to access from multiple threads at once, it could become corrupted or crash. Consult the libtorch docs to check how much thread-safety they offer; even APIs that seem to be read-only could internally have mutable caches, lazy initialization, etc.
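Put together with the global from the question, it would look roughly like this; the SAFETY comments are claims you would have to verify against libtorch, not facts the compiler checks:

use once_cell::sync::Lazy;

// SAFETY: asserts that Model can be moved between threads and that
// concurrent predict() calls are safe. Both claims depend entirely on
// libtorch's thread-safety guarantees; verify them before relying on this.
unsafe impl Send for Model {}
unsafe impl Sync for Model {}

static GLOBAL_MODEL: Lazy<Model> = Lazy::new(Model::load);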

