Using tokio::sync::OnceCell with tokio::test

I want to preface this by saying that it is not a new topic; it was discussed in an earlier thread. I understand the theory discussed there, but I don't know how to translate it to actual code, as I'm fairly new to tokio runtimes.

Brief description and MWE

I am writing the database driver side of a program and testing it. The driver holds a reference to a Props object, and the tests require that the db be cleared before they start. To ensure this, I have a function that stores the props and driver in static objects, builds them, clears the db and returns the driver. These tests fail sometimes, which I believe is caused by a race condition involving the tokio runtime.

Here's an MWE:

src/lib.rs:

use sqlx::sqlite::{SqlitePool, SqlitePoolOptions};

#[derive(Debug)]
pub struct Genre {
    pub id: i64,
    pub name: String,
}

impl Genre {
    pub fn new(name: String) -> Self {
        Self { id: 10, name }
    }
}

pub struct DBProps {
    pool: SqlitePool,
}

impl DBProps {
    pub async fn new(url: &str) -> sqlx::Result<Self> {
        Ok(Self {
            pool: SqlitePoolOptions::new().connect(url).await?,
        })
    }
}

pub struct GenresDB<'a> {
    props: &'a DBProps,
}

impl<'a> GenresDB<'a> {
    pub async fn new(props: &'a DBProps) -> Self {
        Self { props }
    }

    pub async fn delete_all(&self) -> sqlx::Result<u64> {
        Ok(sqlx::query!("DELETE FROM Genres")
            .execute(&self.props.pool)
            .await?
            .rows_affected())
    }

    pub async fn insert(&self, genre: Genre) -> sqlx::Result<()> {
        sqlx::query!(
            "INSERT INTO Genres (id, name) VALUES (?, ?)",
            genre.id,
            genre.name
        )
        .execute(&self.props.pool)
        .await?;

        Ok(())
    }

    pub async fn select_all(&self) -> sqlx::Result<Vec<Genre>> {
        sqlx::query_as!(Genre, "SELECT id, name FROM Genres")
            .fetch_all(&self.props.pool)
            .await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use assertables::*;
    use tokio::sync::OnceCell;

    async fn db() -> &'static GenresDB<'static> {
        static PROPS: OnceCell<DBProps> = OnceCell::const_new();
        static DB: OnceCell<GenresDB> = OnceCell::const_new();

        let props = PROPS
            .get_or_init(async || {
                assert_ok!(dotenvy::dotenv());
                let database_url = std::env::vars()
                    .find(|(key, _)| key == "DATABASE_URL")
                    .unwrap()
                    .1;
                DBProps::new(&database_url).await.unwrap()
            })
            .await;

        let db = DB.get_or_init(async || GenresDB::new(props).await).await;

        assert_ok!(db.delete_all().await);
        db
    }

    #[tokio::test]
    async fn when_db_is_empty_if_user_selects_all_genres_then_they_get_nothing() {
        assert!(db().await.select_all().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn when_db_has_one_genre_if_user_selects_all_then_they_get_it() {
        let db = db().await;
        db.insert(Genre::new(String::from("fantasy")))
            .await
            .unwrap();

        let result = db.select_all().await.unwrap();
        assert_eq!(result.len(), 1);
        assert_some!(result.first());
    }
}

Cargo.toml:

[package]
name = "mwe"
version = "0.1.0"
edition = "2024"

[dependencies]
dotenvy = "0.15.7"
sqlx = { version = "0.8", features = [ "runtime-tokio", "sqlite" ] }

[dev-dependencies]
assertables = "9.6"
tokio = { version = "1.45", features = ["macros", "rt", "sync"] }

.env:

DATABASE_URL="sqlite://$PWD/db.sqlite3"

db.sql:

CREATE TABLE Genres (
    id INTEGER PRIMARY KEY,
    name VARCHAR(127) UNIQUE NOT NULL);

What I know

As I understand from the thread I mentioned at the top, even though the OnceCell's lifetime is static, the value it holds is tied to the runtime that built it. Since every test creates its own runtime, there's a race condition between the destruction of one test's runtime and the retrieval of GenresDB by the next test. In the full program, I get the following message:

An error occured during the attempt of performing I/O: An error occured during the attempt of performing I/O: A Tokio 1.x context

I'm not getting that message in the MWE, and the failures are less consistent here. I suspect this is because I removed something that made the full program run a bit slower.

What I don't know

The thread I linked at the top explains how to solve this issue (can't link due to new user limit):

If the OnceCell is defined in your test, then you could put the OnceCell in a thread-local since tests use a single-threaded runtime meaning that there's a one-to-one mapping from thread to runtime. However, if it's defined in the library, then that might not work.

Other options would be:

  1. Put a runtime in a (non-async) OnceCell and have each test use #[test] + a block_on on the shared runtime. (I.e. you're not using #[tokio::test])
  2. Remove the global from your library and have it be an argument. Then create a separate value in each test.

However, I don't know how this would translate to actual code. I've tried to use thread_local! without success and investigated how to manually use runtimes without success either. I think the thread local option is my best bet, but I'm not sure as I'm fairly new to tokio.

What am I missing here? How can I change the tests to work consistently? Thank you!

Could it be that they meant a std::sync::OnceLock (since tokio::sync::OnceCell is actually a lock)? If your tests run in parallel (which they do by default), a std::cell::OnceCell in a thread local would give each thread its own runtime, which, if I am reading correctly, is not intended.
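
To make the difference concrete, here is a minimal std-only sketch with no tokio involved. The names `PER_THREAD`, `SHARED`, `init_counts` and the two counters are made up for illustration; the stored `u32` stands in for whatever would really be cached (a runtime, a pool, ...). It only shows how often each kind of cell runs its initializer when several threads touch it.

```rust
use std::cell::OnceCell;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::OnceLock;
use std::thread;

// Count how many times each initializer actually runs.
static PER_THREAD_INITS: AtomicUsize = AtomicUsize::new(0);
static SHARED_INITS: AtomicUsize = AtomicUsize::new(0);

thread_local! {
    // One cell per thread: every thread has to initialize its own copy.
    static PER_THREAD: OnceCell<u32> = OnceCell::new();
}

// One cell for the whole process: only the first caller initializes it.
static SHARED: OnceLock<u32> = OnceLock::new();

fn touch_per_thread() -> u32 {
    PER_THREAD.with(|cell| {
        *cell.get_or_init(|| {
            PER_THREAD_INITS.fetch_add(1, Ordering::SeqCst);
            7
        })
    })
}

fn touch_shared() -> u32 {
    *SHARED.get_or_init(|| {
        SHARED_INITS.fetch_add(1, Ordering::SeqCst);
        7
    })
}

/// Touches both cells from `threads` fresh threads and returns
/// (per-thread initializer runs, shared initializer runs).
fn init_counts(threads: usize) -> (usize, usize) {
    let handles: Vec<_> = (0..threads)
        .map(|_| {
            thread::spawn(|| {
                // The second call on the same thread reuses the cached value.
                touch_per_thread();
                touch_per_thread();
                touch_shared();
            })
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
    (
        PER_THREAD_INITS.load(Ordering::SeqCst),
        SHARED_INITS.load(Ordering::SeqCst),
    )
}
```

On a first call, `init_counts(4)` returns `(4, 1)`: each thread builds its own thread-local value, while the `OnceLock` is built once and shared.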

I didn't know about tests running in parallel! I see that I can use cargo test -- --test-threads=1. I also see that I can use serial_test to run them sequentially in code rather than from the shell. I'll investigate a bit more there.
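
As a rough illustration of what running the tests one at a time buys, here is a std-only sketch. Everything in it is hypothetical: `TABLE` is an in-memory stand-in for the Genres table, and `SERIAL`, `serial_guard`, `clear_insert_count` and `run_in_parallel` are names invented for the example. Each simulated test holds a global guard while it clears the table, inserts one row and counts the rows, so no other test can interleave with it.

```rust
use std::sync::{Mutex, MutexGuard, PoisonError};
use std::thread;

// Hypothetical in-memory stand-in for the shared Genres table.
static TABLE: Mutex<Vec<String>> = Mutex::new(Vec::new());

// Each simulated test holds this guard for its whole body,
// so only one of them touches the table at a time.
static SERIAL: Mutex<()> = Mutex::new(());

fn serial_guard() -> MutexGuard<'static, ()> {
    // A test that panics while holding the lock poisons it. The guarded
    // data is just `()`, so it is fine to keep going after that.
    SERIAL.lock().unwrap_or_else(PoisonError::into_inner)
}

/// Mimics one test: clear the table, insert one row, return the row count.
fn clear_insert_count(name: &str) -> usize {
    let _guard = serial_guard();
    TABLE.lock().unwrap().clear();
    // Give other threads a chance to run between the steps.
    thread::yield_now();
    TABLE.lock().unwrap().push(name.to_owned());
    thread::yield_now();
    TABLE.lock().unwrap().len()
}

/// Runs `n` simulated tests on separate threads, like the default test harness.
fn run_in_parallel(n: usize) -> Vec<usize> {
    let handles: Vec<_> = (0..n)
        .map(|i| thread::spawn(move || clear_insert_count(&format!("genre-{i}"))))
        .collect();
    handles.into_iter().map(|h| h.join().unwrap()).collect()
}
```

With the guard in place every simulated test sees exactly one row. Without the `_guard` line the counts can vary from run to run. For async tests an async-aware lock or serial_test is probably a better fit than a std Mutex; this only shows the idea.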

Using std::sync::OnceLock I get the following error:

error[E0271]: expected `{async closure@lib.rs:74:26}` to return `DBProps`, but it returns `{async closure body@src/lib.rs:74:35: 81:14}`
   --> src/lib.rs:74:26
    |
74  |               .get_or_init(async || {
    |  ______________-----------_^
    | |              |
    | |              required by a bound introduced by this call
75  | |                 assert_ok!(dotenvy::dotenv());
76  | |                 let database_url = std::env::vars()
77  | |                     .find(|(key, _)| key == "DATABASE_URL")
...   |
80  | |                 DBProps::new(&database_url).await.unwrap()
81  | |             });
    | |_____________^ expected `DBProps`, found `async` closure body
    |
    = note:            expected struct `DBProps`
            found `async` closure body `{async closure body@src/lib.rs:74:35: 81:14}`
note: required by a bound in `OnceLock::<T>::get_or_init`
   --> /home/groctel/.rustup/toolchains/stable-x86_64-unknown-linux-gnu/lib/rustlib/src/rust/library/std/src/sync/once_lock.rs:308:24
    |
306 |     pub fn get_or_init<F>(&self, f: F) -> &T
    |            ----------- required by a bound in this associated function
307 |     where
308 |         F: FnOnce() -> T,
    |                        ^ required by this bound in `OnceLock::<T>::get_or_init`

For reference, I would be changing db() to look like this:

    async fn db() -> &'static GenresDB<'static> {
        static PROPS: OnceLock<DBProps> = OnceLock::new();
        static DB: OnceLock<GenresDB> = OnceLock::new();

        let props = PROPS
            .get_or_init(async || {
                assert_ok!(dotenvy::dotenv());
                let database_url = std::env::vars()
                    .find(|(key, _)| key == "DATABASE_URL")
                    .unwrap()
                    .1;
                DBProps::new(&database_url).await.unwrap()
            });

        let db = DB.get_or_init(async || GenresDB::new(props).await);

        assert_ok!(db.delete_all().await);
        db
    }

This error is why I came to use tokio::sync::OnceCell in the first place. Am I doing something wrong here? From what I researched, it looks like I cannot use async functions with OnceLock.
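
For reference, here is a std-only sketch of what the signature in that error demands. `OnceLock::get_or_init` takes a synchronous `FnOnce() -> T`, so an async closure hands it a future instead of a `DBProps`; the future has to be driven to completion inside the closure. Everything here is a stand-in so that it compiles without tokio or sqlx: `Props` and `connect` replace `DBProps` and `DBProps::new`, the URL is made up, and the hand-written `block_on` is only a placeholder for a real runtime's `block_on`.

```rust
use std::future::Future;
use std::pin::pin;
use std::sync::{Arc, OnceLock};
use std::task::{Context, Poll, Wake, Waker};
use std::thread::{self, Thread};

// Stand-in for `DBProps`; the real one would hold a `SqlitePool`.
struct Props {
    url: String,
}

// Stand-in for the async constructor `DBProps::new`.
async fn connect(url: &str) -> Props {
    Props { url: url.to_owned() }
}

// Wakes the blocked thread when the future can make progress.
struct ThreadWaker(Thread);

impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) {
        self.0.unpark();
    }
}

/// Minimal executor: polls `fut` on the current thread until it is ready.
/// Placeholder for a real runtime's `block_on`.
fn block_on<F: Future>(fut: F) -> F::Output {
    let mut fut = pin!(fut);
    let waker = Waker::from(Arc::new(ThreadWaker(thread::current())));
    let mut cx = Context::from_waker(&waker);
    loop {
        match fut.as_mut().poll(&mut cx) {
            Poll::Ready(value) => return value,
            Poll::Pending => thread::park(),
        }
    }
}

fn props() -> &'static Props {
    static PROPS: OnceLock<Props> = OnceLock::new();
    // The closure itself is synchronous (`FnOnce() -> Props`);
    // the async work is finished inside it before it returns.
    PROPS.get_or_init(|| block_on(connect("sqlite://example.db")))
}
```

In the real tests the placeholder would presumably become `block_on` on a tokio `Runtime` kept in a second `OnceLock`, which is option 1 from the quoted answer. The tests would then be plain `#[test]` functions, because tokio's `block_on` panics when called from inside a running runtime such as the one `#[tokio::test]` creates.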