Mocking a Trait in Another Module (in the Same Workspace)

Hi all!

The problem: package A has a public async trait (defined with async_trait) that has an async method. Package B, the axum router, consumes that trait, which is injected as an extension.

How can I build a test double (mock, spy, or stub) that implements that trait using a tool such as mockall, or something similar?

I tried hand-crafting the test double, but to track its state (calls, parameters, etc.) I had to change the trait's receiver from &self to &mut self, which I'd rather not do. If it's the only option, though, I guess it will have to do.
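To see why `&mut self` fits poorly here: an injected extension is typically shared behind an `Arc` (axum clones extensions into each request), and an `Arc` only hands out `&mut T`, via `Arc::get_mut`, while it is the sole handle. A minimal synchronous sketch, where `Recorder`, `SpyRecorder`, and `try_record` are hypothetical names, shows the `&mut self` call succeeding with one handle and failing as soon as a second one exists.

```rust
use std::sync::Arc;

// Hypothetical trait whose method takes `&mut self`, like the attempted workaround.
pub trait Recorder {
    fn record(&mut self, value: i32);
}

#[derive(Default)]
pub struct SpyRecorder {
    pub values: Vec<i32>,
}

impl Recorder for SpyRecorder {
    fn record(&mut self, value: i32) {
        self.values.push(value);
    }
}

/// Tries to call the `&mut self` method through a shared trait object.
/// `Arc::get_mut` returns `Some` only while `handle` is the only `Arc`
/// pointing at the value; returns whether the call happened.
pub fn try_record(handle: &mut Arc<dyn Recorder + Send + Sync>, value: i32) -> bool {
    match Arc::get_mut(handle) {
        Some(recorder) => {
            recorder.record(value);
            true
        }
        None => false,
    }
}
```

Since the router and the test both need a handle to the double, the second case is the normal one, which is why interior mutability behind `&self` tends to be the more practical route.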

I don't know about using a mock crate to sugar over the trait implementation, but you can fix the state-tracking problem in your hand-written impl with some good old-fashioned interior mutability:

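use async_trait::async_trait;
use tokio::sync::Mutex;

// Stand-in for the async trait exported by package A.
#[async_trait]
pub trait AsyncTrait {
    async fn do_something(&self);
}

struct State(i32);

// The Mutex provides interior mutability, so the receiver can stay `&self`.
struct Mock {
    state: Mutex<State>,
}

#[async_trait]
impl AsyncTrait for Mock {
    async fn do_something(&self) {
        let mut state = self.state.lock().await;

        // Feel free to mutate `state` all you need!
        state.0 = 42;
    }
}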

Edit: An earlier revision of this example wrapped the mutex in an Arc, but shared ownership may not be necessary for this use case.
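Whether the `Arc` is needed depends on who owns the mock. In this minimal synchronous sketch (the `Action` trait, `Mock`, and `run` are hypothetical simplifications of the code above), the test owns the mock and lends it to the code under test by reference, so the `Mutex` field alone is enough. Once the mock has to be moved into shared ownership, such as an `Arc<dyn Trait>` extension, the test needs some shared handle to read the state afterwards.

```rust
use std::sync::Mutex;

// Synchronous, hypothetical stand-ins for the trait and mock above.
pub trait Action {
    fn do_something(&self);
}

pub struct State(pub i32);

pub struct Mock {
    // No Arc: the Mutex alone gives `&self` methods mutable access.
    pub state: Mutex<State>,
}

impl Action for Mock {
    fn do_something(&self) {
        self.state.lock().unwrap().0 = 42;
    }
}

// Code under test that only borrows the trait object.
pub fn run(action: &dyn Action) {
    action.do_something();
}
```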


It was a great help, thank you! It solved the problem, and everything works now.

Given this trait:

use std::sync::Arc;

use async_trait::async_trait;

pub type DynTheAction = Arc<dyn TheAction + Send + Sync>;

#[async_trait]
pub trait TheAction {
    async fn serve_resolved(&self, input: ActionInput) -> Result<ActionOutput, Error>;
}
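For context on why this trait works as `Arc<dyn TheAction + Send + Sync>`: `#[async_trait]` rewrites each `async fn` into an ordinary method returning a pinned, boxed `Send` future, which keeps the trait usable as a trait object. The sketch below writes that shape out by hand so it runs on the standard library alone. It approximates the macro's output rather than reproducing it exactly, and `ActionInput`, `ActionOutput`, `Error`, `EchoStub`, and the tiny `block_on` executor are hypothetical stand-ins.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// Hypothetical stand-ins for the real ActionInput / ActionOutput / Error.
pub struct ActionInput(pub String);
#[derive(Debug, Default, PartialEq)]
pub struct ActionOutput(pub String);
#[derive(Debug)]
pub struct Error;

// Roughly the shape #[async_trait] gives `async fn serve_resolved`:
// a plain method returning a pinned, boxed, Send future.
pub trait TheAction {
    fn serve_resolved<'a>(
        &'a self,
        input: ActionInput,
    ) -> Pin<Box<dyn Future<Output = Result<ActionOutput, Error>> + Send + 'a>>;
}

pub type DynTheAction = Arc<dyn TheAction + Send + Sync>;

// A stub that echoes the input back as the output.
pub struct EchoStub;

impl TheAction for EchoStub {
    fn serve_resolved<'a>(
        &'a self,
        input: ActionInput,
    ) -> Pin<Box<dyn Future<Output = Result<ActionOutput, Error>> + Send + 'a>> {
        Box::pin(async move { Ok::<_, Error>(ActionOutput(input.0)) })
    }
}

// Minimal busy-polling executor so the sketch runs without tokio.
// Only suitable for futures that complete without waiting on I/O.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}

pub fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    let mut fut = Box::pin(fut);
    loop {
        if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
            return out;
        }
    }
}
```

In a real test the future would be awaited on tokio's runtime instead of `block_on`.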

This is the manually crafted test double:

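use std::sync::Arc;

use async_trait::async_trait;
use tokio::sync::Mutex;

/// Everything the spy has observed: every input it was called with.
#[derive(Default, Clone)]
pub struct SpyTheActionState {
    pub input: Vec<ActionInput>,
}

/// Spy for TheAction. The state sits behind Arc<Mutex<..>> so the test can
/// keep a handle to it after the spy is moved into the request context.
pub struct SpyTheAction {
    pub state: Arc<Mutex<SpyTheActionState>>,
}

#[async_trait]
impl TheAction for SpyTheAction {
    async fn serve_resolved(&self, input: ActionInput) -> Result<ActionOutput, Error> {
        // Record the call, then return a canned response.
        let mut v = self.state.lock().await;
        v.input.push(input);

        Ok(ActionOutput::default())
    }
}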

And this is the test:

use std::sync::Arc;

use axum::{
    body::Body,
    http::{self, Request},
};
use tokio::sync::Mutex;
use tower::ServiceExt; // for `oneshot`

#[tokio::test]
async fn handler_should_call_action() {
    // Shared handle to the spy's state, kept by the test for the final assertion.
    let mock_state = Arc::new(Mutex::new(SpyTheActionState::default()));
    let mock = SpyTheAction {
        state: Arc::clone(&mock_state),
    };
    let req_ctx = Arc::new(ReqCtx {
        settings: Default::default(),
        action: Arc::new(mock),
    });

    let router = app_router(req_ctx);
    let target_url = format!("/action/{}", "USER_ID");
    let incoming_payload = IncomingPayload::default();

    let req = Request::builder()
        .method(http::Method::POST)
        .uri(target_url)
        .header(http::header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
        .body(Body::from(serde_json::to_vec(&incoming_payload).unwrap()))
        .unwrap();

    // `oneshot` waits for the router to be ready, then sends the request.
    let _response = router.oneshot(req).await.unwrap();

    // The handler should have invoked the action exactly once.
    let state_of_mock = mock_state.lock().await;
    assert_eq!(1, state_of_mock.input.len());
}
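An alternative to sharing a separate `Arc<Mutex<SpyTheActionState>>` is to keep an `Arc` to the spy itself and give the context a clone of it: the clone coerces to the trait-object type, while the test's handle keeps the concrete type and can read the recorded inputs directly. This synchronous sketch assumes simplified, hypothetical versions of `TheAction`, `ReqCtx`, and the handler (`handle`) in place of the real async trait and axum router.

```rust
use std::sync::{Arc, Mutex};

// Simplified synchronous stand-ins (hypothetical) for TheAction and ReqCtx.
pub trait TheAction {
    fn serve_resolved(&self, input: String) -> Result<(), String>;
}

pub struct ReqCtx {
    pub action: Arc<dyn TheAction + Send + Sync>,
}

#[derive(Default)]
pub struct SpyTheAction {
    pub inputs: Mutex<Vec<String>>,
}

impl TheAction for SpyTheAction {
    fn serve_resolved(&self, input: String) -> Result<(), String> {
        // Record the call through `&self` thanks to the Mutex.
        self.inputs.lock().unwrap().push(input);
        Ok(())
    }
}

// Hypothetical stand-in for the handler: it only sees the trait object.
pub fn handle(ctx: &ReqCtx, user_id: &str) -> Result<(), String> {
    ctx.action.serve_resolved(user_id.to_string())
}
```

In the test, `ReqCtx { action: spy.clone() }` coerces `Arc<SpyTheAction>` into the trait object, and `spy.inputs` stays readable afterwards, so no second `Arc` around the state is needed.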