Async architectural question - tokio, read consumed object state

Hi guys, I'm trying to implement a module to drive the roller blinds in my home, using relays connected to the server over Modbus.

To manage opening and closing the rollers, I had to create some modules that protect against, e.g., turning on the "up" and "down" relays of a roller at the same time.

Here is a relay trait:

pub mod relay_modbus;

use async_trait::async_trait;

#[cfg(test)]
use mockall::automock;

#[cfg_attr(test, automock)]
#[async_trait] // Add async_trait attribute
pub trait Relay : Send {
    async fn relay_on(&mut self) -> Result<(), std::io::Error>;
    async fn relay_off(&mut self) -> Result<(), std::io::Error>;
    async fn relay_flip(&mut self) -> Result<(), std::io::Error>;
    async fn relay_status(&self) -> Result<bool, std::io::Error>;
}

And the Blinder module:

use tokio::sync::mpsc;
use std::sync::Arc;
use tokio::sync::Mutex;

use crate::relay::Relay;

#[derive(Clone, PartialEq, Debug)]
pub enum BlinderState {
    Opening,
    Closing,
    Stopped,
}

pub struct Blinder {
    relay_up: Arc<Mutex<Box<dyn Relay>>>,
    relay_down: Arc<Mutex<Box<dyn Relay>>>,
    opening_time: u64,
    closing_time: u64,
    stop_rolling_rx: mpsc::Receiver<()>,
    curr_state: Arc<Mutex<BlinderState>>,
}

pub struct BlinderBuilder {
    relay_up: Option<Arc<Mutex<Box<dyn Relay>>>>,
    relay_down: Option<Arc<Mutex<Box<dyn Relay>>>>,
    opening_time: u64,
    closing_time: u64,
    stop_rolling_rx: Option<mpsc::Receiver<()>>, // This should be used for stop rolling.
}

impl BlinderBuilder {
    pub fn new() -> Self {
        BlinderBuilder {
            relay_up: None,
            relay_down: None,
            opening_time: 0,
            closing_time: 0,
            stop_rolling_rx: None,
        }
    }

    pub fn relay_up(&mut self, relay_up: Arc<Mutex<Box<dyn Relay>>>) -> &mut Self {
        self.relay_up = Some(relay_up);
        self
    }

    pub fn relay_down(&mut self, relay_down: Arc<Mutex<Box<dyn Relay>>>) -> &mut Self {
        self.relay_down = Some(relay_down);
        self
    }

    pub fn opening_time(&mut self, opening_time: u64) -> &mut Self {
        self.opening_time = opening_time;
        self
    }

    pub fn closing_time(&mut self, closing_time: u64) -> &mut Self {
        self.closing_time = closing_time;
        self
    }

    pub fn stop_rolling_rx(&mut self, stop_rolling_rx: mpsc::Receiver<()>) -> &mut Self {
        self.stop_rolling_rx = Some(stop_rolling_rx);
        self
    }

    pub fn build(&mut self) -> Blinder {
        Blinder {
            relay_up: self.relay_up.clone().expect("Relay up is required"),
            relay_down: self.relay_down.clone().expect("Relay down is required"),
            opening_time: self.opening_time,
            closing_time: self.closing_time,
            stop_rolling_rx: self.stop_rolling_rx.take().expect("Stop rolling rx is required"),
            curr_state: Arc::new(Mutex::new(BlinderState::Stopped)),
        }
    }
}

impl Blinder {
    pub async fn open(&mut self) -> Result<(), std::io::Error> {
        
        let mut relay_down = self.relay_down.lock().await;
        let relay_down_status = relay_down.relay_status().await?;
        if relay_down_status{
            relay_down.relay_off().await?;
        }

        let mut relay_up = self.relay_up.lock().await;
        relay_up.relay_on().await?;

        {
            let mut curr_state = self.curr_state.lock().await;
            *curr_state = BlinderState::Opening;
        }

        tokio::select! {
            // This sleep is a safety net: if someone forgets to call stop (or a bug prevents it),
            // the opening relay is still switched off once the usual opening time has elapsed.
            _ = tokio::time::sleep(tokio::time::Duration::from_secs(self.opening_time)) => {
                relay_up.relay_off().await?;

                let mut curr_state = self.curr_state.lock().await;
                *curr_state = BlinderState::Stopped;
            }
            _ = self.stop_rolling_rx.recv() => {
                relay_up.relay_off().await?;
                
                let mut curr_state = self.curr_state.lock().await;
                *curr_state = BlinderState::Stopped;
            }
        }

        Ok(())
    }

    pub async fn close(&mut self) -> Result<(), std::io::Error> {
        
        let mut relay_up = self.relay_up.lock().await;
        let relay_up_status = relay_up.relay_status().await?;
        if relay_up_status {
            relay_up.relay_off().await?;
        }

        let mut relay_down = self.relay_down.lock().await;
        relay_down.relay_on().await?;

        {
            let mut curr_state = self.curr_state.lock().await;
            *curr_state = BlinderState::Closing;
        }

        tokio::select! {
            _ = tokio::time::sleep(tokio::time::Duration::from_secs(self.closing_time)) => {
                relay_down.relay_off().await?;

                let mut curr_state = self.curr_state.lock().await;
                *curr_state = BlinderState::Stopped;
            }
            _ = self.stop_rolling_rx.recv() => {
                relay_down.relay_off().await?;

                let mut curr_state = self.curr_state.lock().await;
                *curr_state = BlinderState::Stopped;
            }
        }

        Ok(())
    }

    pub async fn stop(&mut self) -> Result<(), std::io::Error> {
        let mut relay_up = self.relay_up.lock().await;
        let mut relay_down = self.relay_down.lock().await;

        relay_up.relay_off().await?;
        relay_down.relay_off().await?;

        let mut curr_state = self.curr_state.lock().await;
        *curr_state = BlinderState::Stopped;

        Ok(())
    }

    pub async fn get_curr_state(&self) -> BlinderState {
        let curr_state = self.curr_state.lock().await;
        curr_state.clone()
    }

}

Then, e.g. during opening or closing, I want to be able to check the current state. For this I created a get_curr_state function, and I wrote a test to show how I want to use it.

Here is the test definition:

    #[test]
    fn test_curr_state_function_opening() {
        let rt = Runtime::new().unwrap();

        rt.block_on(async {
            let (_tx, _rx) = mpsc::channel(1);

            let mut relay_up_mock = MockRelay::new();
            let mut relay_down_mock = MockRelay::new();

            relay_down_mock.expect_relay_status()
                .times(1)
                .returning(|| Ok(false));

            relay_up_mock.expect_relay_on()
                .times(1)
                .returning(|| Ok(()));

            relay_up_mock.expect_relay_off()
                .never()
                .returning(|| Ok(()));

            relay_down_mock.expect_relay_on()
                .never()
                .returning(|| Ok(()));

            relay_down_mock.expect_relay_off()
                .never()
                .returning(|| Ok(()));

            let mut blinder = Blinder {
                relay_up: Arc::new(Mutex::new(Box::new(relay_up_mock) as Box<dyn Relay>)),
                relay_down: Arc::new(Mutex::new(Box::new(relay_down_mock) as Box<dyn Relay>)),
                opening_time: 5,
                closing_time: 1,
                stop_rolling_rx: _rx,
                curr_state: Arc::new(Mutex::new(BlinderState::Stopped)),
            };

            let state_clone = Arc::clone(&blinder.curr_state);

            {
                let curr_state = state_clone.lock().await;
                assert_eq!(*curr_state, BlinderState::Stopped);
            }
            
            tokio::spawn(async move {
                blinder.open().await.unwrap();
            });

            // Wait 2 seconds (less than opening_time = 5 s), so open() should still be running
            tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;

            let curr_state = state_clone.lock().await;
            assert_eq!(*curr_state, BlinderState::Opening);
        });
    }

I'm able to check the state by accessing curr_state directly (let state_clone = Arc::clone(&blinder.curr_state);), but I want to use blinder.get_curr_state(). I can't, because blinder is moved into the spawned task here:

tokio::spawn(async move {
    blinder.open().await.unwrap();
});

I probably made a mistake in the architecture. So the question is:

How can I change it so that I can get the current state while the blinder is opening or closing?

Maybe I should not block in the open or close function, but then how do I implement the "safety mechanism" that turns the relays off after a timeout (so that the blinder's opening or closing motor is not left powered)?

There's probably a lot I don't understand yet; it's my first Rust app (I have an embedded C background).

Thanks!

The very basic generic answer would be to make it possible to get the state through a different object than the Blinder itself. For example, if you make a second struct which has a clone of the Arc<Mutex<BlinderState>>, and which has a method that locks the mutex and copies the state out, then you can read the state any time.

struct ReadState {
    curr_state: Arc<Mutex<BlinderState>>,
}

impl ReadState {
    pub async fn get(&self) -> BlinderState {
        self.curr_state.lock().await.clone()
    }
}

impl Blinder {
    pub fn state(&self) -> ReadState {
        ReadState { curr_state: self.curr_state.clone() }
    }
}

However, I would instead recommend that you consider redesigning your code more significantly. For example, it is not cancellation-safe: if open() or close() is called, and the task is cancelled midway through the operation (i.e. the future returned by open() or close() is dropped), then it will never command relay_off().
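To make the cancellation problem concrete: anything that must happen on every exit path can be tied to a value's Drop. The sketch below is a common RAII-guard mitigation, not what this reply proposes, and it assumes a hypothetical synchronous SyncRelay (a plain atomic flag standing in for the hardware). With the real async Modbus relay, Drop cannot .await, so a guard would have to do a blocking write or hand the off command to another task, which is part of why owning the relays in a single task is the simpler fix.

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

// Hypothetical synchronous relay; clones share the same on/off flag.
// A real Modbus relay write would be async and could fail.
#[derive(Clone, Default)]
struct SyncRelay {
    on: Arc<AtomicBool>,
}

impl SyncRelay {
    fn set(&self, on: bool) {
        self.on.store(on, Ordering::SeqCst);
    }

    fn is_on(&self) -> bool {
        self.on.load(Ordering::SeqCst)
    }
}

/// Switches the relay on when created and off again when dropped, so the
/// motor is unpowered on every exit path: normal return, `?`/early return,
/// panic unwinding, or the enclosing future being dropped mid-await.
struct RelayOnGuard {
    relay: SyncRelay,
}

impl RelayOnGuard {
    fn new(relay: SyncRelay) -> Self {
        relay.set(true);
        RelayOnGuard { relay }
    }
}

impl Drop for RelayOnGuard {
    fn drop(&mut self) {
        self.relay.set(false);
    }
}

// Simulates an operation that fails part-way through the movement.
fn fails_midway(relay: SyncRelay) -> Result<(), String> {
    let _guard = RelayOnGuard::new(relay);
    Err("modbus write failed".to_string())
}
```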

Instead, I think you should take an actor / state-machine approach: there is a single spawned task (spawned by BlinderBuilder::build(), never exposed to the caller) which owns the relays and receives commands (Open, Close, Stop) over a channel. This also has the advantage that it needs far fewer mutexes, which means you can be assured that the code can never be deadlocked, or even blocked, by other code holding one of the mutexes.
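A minimal sketch of that actor shape, assuming std::thread and std::sync::mpsc as dependency-free stand-ins for tokio::spawn and tokio::sync::mpsc (with tokio the loop would be `while let Some(cmd) = rx.recv().await`). FakeRelay and spawn_blinder_actor are hypothetical names; the point is that the relays are moved into the task, so no Arc<Mutex<...>> is needed and the up/down interlock lives in exactly one place. Timeouts are left out here.

```rust
use std::sync::mpsc;
use std::thread;

// Commands the rest of the program may send; only the actor touches the relays.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Command {
    Open,
    Close,
    Stop,
}

// Hypothetical stand-in for one relay channel: records state and switch count.
#[derive(Debug, Default)]
struct FakeRelay {
    on: bool,
    switch_count: u32,
}

impl FakeRelay {
    fn set(&mut self, on: bool) {
        if self.on != on {
            self.on = on;
            self.switch_count += 1;
        }
    }
}

// Spawns the actor. It owns both relays by value and hands them back once
// every Sender has been dropped (convenient for tests).
fn spawn_blinder_actor(
    mut up: FakeRelay,
    mut down: FakeRelay,
) -> (mpsc::Sender<Command>, thread::JoinHandle<(FakeRelay, FakeRelay)>) {
    let (tx, rx) = mpsc::channel::<Command>();
    let handle = thread::spawn(move || {
        // recv() returns Err when all senders are gone, which ends the loop.
        while let Ok(cmd) = rx.recv() {
            match cmd {
                Command::Open => {
                    // Interlock: the opposite relay always goes off first.
                    down.set(false);
                    up.set(true);
                }
                Command::Close => {
                    up.set(false);
                    down.set(true);
                }
                Command::Stop => {
                    up.set(false);
                    down.set(false);
                }
            }
            // The invariant the actor guarantees after every command.
            assert!(!(up.on && down.on), "both relays on");
        }
        // Fail safe on shutdown: leave the motor unpowered.
        up.set(false);
        down.set(false);
        (up, down)
    });
    (tx, handle)
}
```

Returning the relays when the last Sender is dropped is only a testing convenience; a long-running service might simply switch everything off and exit.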

Essentially, take the approach you used for stop_rolling_rx and use it for all three commands. Then write a loop that select!s between receiving a new command and the timeout occurring, and spawn a task for that loop. (Don't forget that if you receive two Opens in a row, the timeout shouldn't be extended; you may want to keep track of a number that's the "current estimated position".)
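The "current estimated position" bookkeeping can be kept separate from the async plumbing. A sketch under these assumptions: the position is a hypothetical f64 from 0.0 (closed) to 1.0 (open), travel is linear in time, and opening_time/closing_time are the full-travel times, as in the original builder. The remaining run time is derived from the position, so a repeated Open only gets the time that is left.

```rust
use std::time::Duration;

#[derive(Clone, Copy, PartialEq, Debug)]
enum BlinderState {
    Opening,
    Closing,
    Stopped,
}

/// Hypothetical position model: 0.0 = fully closed, 1.0 = fully open.
struct PositionEstimate {
    position: f64,
    opening_time: Duration, // full travel time, closed -> open
    closing_time: Duration, // full travel time, open -> closed
}

impl PositionEstimate {
    /// How long the motor may still run in `state` before reaching the end stop.
    fn remaining(&self, state: BlinderState) -> Duration {
        match state {
            BlinderState::Opening => self.opening_time.mul_f64(1.0 - self.position),
            BlinderState::Closing => self.closing_time.mul_f64(self.position),
            BlinderState::Stopped => Duration::ZERO,
        }
    }

    /// Updates the estimate after the motor has run for `elapsed` in `state`.
    fn advance(&mut self, state: BlinderState, elapsed: Duration) {
        let secs = elapsed.as_secs_f64();
        let delta = match state {
            BlinderState::Opening => secs / self.opening_time.as_secs_f64(),
            BlinderState::Closing => -secs / self.closing_time.as_secs_f64(),
            BlinderState::Stopped => 0.0,
        };
        // Clamp so the estimate never leaves the physical range.
        self.position = (self.position + delta).clamp(0.0, 1.0);
    }
}
```

Inside the actor loop, one would record an Instant when the motor starts, call advance() with the elapsed time whenever a command or timeout arrives, and set the next deadline to now plus remaining() for the new state.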

For the state checking, the loop can write to a mutex or a watch channel which is kept accessible along with the command sender.
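Putting the command sender and the shared state into one cloneable handle might look like the sketch below. It again uses std threads as a stand-in: recv_timeout plays the role of select!-ing between a command and the timeout, and an Arc<Mutex<BlinderState>> plays the role of the state channel (with tokio, tokio::sync::watch::channel plus Receiver::borrow(), and tokio::time::sleep_until inside select!, would be the natural equivalents). BlinderHandle, Command and spawn_blinder are hypothetical names, relay switching is reduced to comments, and a single run_time is used for both directions instead of a position estimate.

```rust
use std::sync::mpsc::{self, RecvTimeoutError};
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::{Duration, Instant};

#[derive(Clone, Copy, PartialEq, Debug)]
pub enum BlinderState { Opening, Closing, Stopped }

#[derive(Clone, Copy, PartialEq, Debug)]
pub enum Command { Open, Close, Stop }

/// Cheap, cloneable handle: send commands and read the state at any time.
#[derive(Clone)]
pub struct BlinderHandle {
    commands: mpsc::Sender<Command>,
    state: Arc<Mutex<BlinderState>>,
}

impl BlinderHandle {
    pub fn send(&self, cmd: Command) {
        // Sending only fails if the actor thread has already exited.
        let _ = self.commands.send(cmd);
    }

    pub fn state(&self) -> BlinderState {
        *self.state.lock().unwrap()
    }
}

/// Spawns the actor; `run_time` is the safety timeout for one movement.
pub fn spawn_blinder(run_time: Duration) -> BlinderHandle {
    let (tx, rx) = mpsc::channel::<Command>();
    let state = Arc::new(Mutex::new(BlinderState::Stopped));
    let actor_state = Arc::clone(&state);
    thread::spawn(move || {
        let mut current = BlinderState::Stopped;
        let mut deadline: Option<Instant> = None;
        loop {
            // Moving: wait for a command, but no longer than the deadline.
            // Stopped: wait for a command indefinitely.
            let next = match deadline {
                Some(d) => rx.recv_timeout(d.saturating_duration_since(Instant::now())),
                None => rx.recv().map_err(|_| RecvTimeoutError::Disconnected),
            };
            let requested = match next {
                Ok(Command::Open) => BlinderState::Opening,
                Ok(Command::Close) => BlinderState::Closing,
                Ok(Command::Stop) | Err(RecvTimeoutError::Timeout) => BlinderState::Stopped,
                Err(RecvTimeoutError::Disconnected) => break,
            };
            // Only a change of state starts a new timer, so a repeated Open
            // does not extend the current movement.
            if requested != current {
                // Here the real actor would switch the relays (opposite relay off first).
                deadline = if requested == BlinderState::Stopped {
                    None
                } else {
                    Some(Instant::now() + run_time)
                };
                current = requested;
                *actor_state.lock().unwrap() = current;
            }
        }
        // Every handle dropped: switch the relays off and exit.
        *actor_state.lock().unwrap() = BlinderState::Stopped;
    });
    BlinderHandle { commands: tx, state }
}
```

Because the handle is Clone, the test from the question could keep one copy for calling state() while another copy sends commands; nothing the caller still needs is moved into a task.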


Thank you, Kevin, for the comprehensive explanation. I will definitely redesign the code in the way you've described, with a separate task for each blind. I will also redesign the Modbus relays module (which is not described here, and contains 16 channels) in such a way that the relay objects can be taken (ownership transferred) during blind creation. This ensures that the relay is not used by another blind.
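The ownership idea can be enforced by the type system: if the board hands out each channel by value at most once, a relay moved into one blind cannot be used by another. A sketch with hypothetical RelayBoard and RelayChannel types (a real channel would presumably also hold a handle to the shared Modbus connection and its coil address):

```rust
/// Hypothetical handle to one channel of a Modbus relay board.
#[derive(Debug)]
struct RelayChannel {
    index: u16,
}

/// Hypothetical relay board: each channel can be taken out exactly once.
struct RelayBoard {
    channels: Vec<Option<RelayChannel>>,
}

impl RelayBoard {
    fn new(count: u16) -> Self {
        RelayBoard {
            channels: (0..count).map(|index| Some(RelayChannel { index })).collect(),
        }
    }

    /// Moves the channel out of the board; None if out of range or already taken.
    fn take(&mut self, index: u16) -> Option<RelayChannel> {
        self.channels.get_mut(index as usize)?.take()
    }
}

/// The blind owns its relays by value, so no Arc<Mutex<...>> is needed.
struct Blinder {
    relay_up: RelayChannel,
    relay_down: RelayChannel,
}
```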