I am currently implementing the EZSP and the underlying ASHv2 protocols in Rust, with the intention of using them on an embedded smart home gateway to control Zigbee devices.
In this review, I present the ASHv2 host. It is the external interface of the ASHv2 library and is used by the EZSP library, which is not part of this review.
The host operates on a serial port, which is passed to the constructor `spawn()`.
ASHv2 has a mode of operation in which the NCP can send packets to the host at any time. I therefore need a separate listener thread that constantly listens on the serial port. It forwards those events to an optional event handler callback, which may be passed to `spawn()`, whenever no request handler is currently registered or the registered one does not want to handle the respective data.
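The routing rule described above can be sketched in isolation. The active handler gets the first chance. If there is none, or it rejects the payload, the payload goes to the optional callback channel. Otherwise it is dropped. `HandleResult`, `Handler`, `Routed` and `route` below are simplified, hypothetical stand-ins and not the crate's actual types. The callback is modelled as a `std::sync::mpsc::Sender<Vec<u8>>` instead of `Sender<FrameBuffer>`.

```rust
use std::sync::mpsc::Sender;

/// Hypothetical, simplified stand-in for the crate's `HandleResult`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HandleResult {
    Completed,
    Continue,
    Reject,
}

/// Hypothetical, simplified stand-in for the crate's `Handler` trait.
pub trait Handler {
    fn handle(&self, payload: &[u8]) -> HandleResult;
}

/// Where a received payload ended up.
#[derive(Debug, PartialEq, Eq)]
pub enum Routed {
    Handler(HandleResult),
    Callback,
    Dropped,
}

/// The active handler sees the payload first. If there is no handler, or it
/// rejects the payload, the payload is sent to the callback channel. If that
/// is missing as well (or its receiver is gone), the payload is dropped.
pub fn route(
    payload: &[u8],
    handler: Option<&dyn Handler>,
    callback: Option<&Sender<Vec<u8>>>,
) -> Routed {
    if let Some(handler) = handler {
        match handler.handle(payload) {
            HandleResult::Reject => {} // fall through to the callback
            result => return Routed::Handler(result),
        }
    }
    match callback {
        Some(sender) if sender.send(payload.to_vec()).is_ok() => Routed::Callback,
        _ => Routed::Dropped,
    }
}
```

Unsolicited frames that the current handler rejects still reach the application through the callback channel. Only when neither consumer exists is data lost.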
Other than that, the host exposes an async interface via `communicate()`. It sends raw bytes to the transmitter and handles the corresponding responses. A response can consist of an arbitrary number of data packets, so the command handler sent along with the payload must process the received data according to what the command expects. These handlers are implemented, e.g., in the EZSP library for the respective command responses.
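`communicate()` awaits a response object that is completed from another thread, namely the listener. The general technique can be shown with the standard library alone: a shared `Mutex` holds the result and the most recent `Waker`. `ResponseFuture` and `block_on` are hypothetical names. The actual `Response` trait and its implementations in the crate are not shown in the question and may work differently. `block_on` is only a minimal executor for the example. In practice the future would be awaited on an async runtime.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Wake, Waker};
use std::thread::{self, Thread};

/// State shared between the awaiting task and the completing thread.
struct State<T> {
    result: Option<T>,
    waker: Option<Waker>,
}

/// Hypothetical response future: all clones share one `State`, so a worker
/// thread can complete its clone while the caller awaits another one.
pub struct ResponseFuture<T>(Arc<Mutex<State<T>>>);

impl<T> Clone for ResponseFuture<T> {
    fn clone(&self) -> Self {
        Self(Arc::clone(&self.0))
    }
}

impl<T> ResponseFuture<T> {
    pub fn new() -> Self {
        Self(Arc::new(Mutex::new(State { result: None, waker: None })))
    }

    /// Called by the worker thread: store the result, then wake the task.
    pub fn complete(&self, value: T) {
        let waker = {
            let mut state = self.0.lock().unwrap();
            state.result = Some(value);
            state.waker.take()
        };
        if let Some(waker) = waker {
            waker.wake();
        }
    }
}

impl<T> Future for ResponseFuture<T> {
    type Output = T;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<T> {
        let mut state = self.0.lock().unwrap();
        match state.result.take() {
            Some(value) => Poll::Ready(value),
            None => {
                // Store the latest waker so that `complete()` can wake this task.
                state.waker = Some(cx.waker().clone());
                Poll::Pending
            }
        }
    }
}

/// Minimal single-future executor for the example: parks until woken.
struct ThreadWaker(Thread);

impl Wake for ThreadWaker {
    fn wake(self: Arc<Self>) {
        self.0.unpark();
    }
}

pub fn block_on<F: Future>(future: F) -> F::Output {
    let mut future = Box::pin(future);
    let waker = Waker::from(Arc::new(ThreadWaker(thread::current())));
    let mut cx = Context::from_waker(&waker);
    loop {
        match future.as_mut().poll(&mut cx) {
            Poll::Ready(output) => return output,
            Poll::Pending => thread::park(),
        }
    }
}
```

The `Clone` bound on `T` in `communicate()` serves the same purpose as `clone()` here. One handle travels with the `Command` to the I/O threads, and the other one is awaited by the caller.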
For further details, please refer to the protocol descriptions linked above. They were the basis for the code below.
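One protocol detail that the code relies on without showing it is the `.mask()` call, which the listener applies to the payload of received data frames. Assuming `Mask` implements ASH data-field randomization as I understand it from the protocol description, it works as follows. The payload is XORed with a pseudo-random sequence from an 8-bit LFSR that starts at `0x42`. Each next byte is the previous byte shifted right by one, additionally XORed with `0xB8` if the previous byte's least significant bit was set. The functions below are hypothetical and not the crate's `Mask` trait.

```rust
/// The ASH pseudo-random sequence (as assumed here): seed 0x42; next byte is
/// `prev >> 1`, XORed with 0xB8 when the LSB of `prev` was 1.
pub fn pseudo_random_sequence() -> impl Iterator<Item = u8> {
    std::iter::successors(Some(0x42u8), |&rand| {
        Some(if rand & 0x01 == 0 { rand >> 1 } else { (rand >> 1) ^ 0xB8 })
    })
}

/// XORs each payload byte with the sequence. Since XOR is its own inverse,
/// the same function both masks and unmasks.
pub fn mask(payload: &[u8]) -> Vec<u8> {
    payload
        .iter()
        .zip(pseudo_random_sequence())
        .map(|(byte, rand)| byte ^ rand)
        .collect()
}
```

Masking an all-zero payload exposes the sequence itself (`0x42, 0x21, 0xA8, 0x54, ...`), and masking twice returns the original bytes.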
src/host.rs
```rust
use std::sync::atomic::Ordering::SeqCst;
use std::sync::atomic::{AtomicBool, AtomicU8};
use std::sync::mpsc::{channel, Sender};
use std::sync::Arc;
use std::thread::{spawn, JoinHandle};

use log::error;
use serialport::TTYPort;

use listener::Listener;
use transmitter::Transmitter;

use crate::packet::FrameBuffer;
use crate::protocol::{Command, Response};
use crate::util::NonPoisonedRwLock;
use crate::Error;

mod listener;
mod transmitter;

/// A host controller to communicate with an NCP via the `ASHv2` protocol.
#[derive(Debug)]
pub struct Host {
    running: Arc<AtomicBool>,
    command: Sender<Command>,
    listener_thread: Option<JoinHandle<()>>,
    transmitter_thread: Option<JoinHandle<()>>,
}

impl Host {
    /// Creates and starts the host.
    ///
    /// # Errors
    /// Returns an [`Error`] if the host could not be started.
    pub fn spawn(
        serial_port: TTYPort,
        callback: Option<Sender<FrameBuffer>>,
    ) -> Result<Self, Error> {
        let running = Arc::new(AtomicBool::new(true));
        let (command_sender, command_receiver) = channel();
        let connected = Arc::new(AtomicBool::new(false));
        let handler = Arc::new(NonPoisonedRwLock::new(None));
        let ack_number = Arc::new(AtomicU8::new(0));
        let (listener, ack_receiver, nak_receiver) = Listener::new(
            serial_port.try_clone_native()?,
            running.clone(),
            connected.clone(),
            handler.clone(),
            ack_number.clone(),
            callback,
        );
        let transmitter = Transmitter::new(
            serial_port,
            running.clone(),
            connected,
            command_receiver,
            handler,
            ack_number,
            ack_receiver,
            nak_receiver,
        );
        Ok(Self {
            command: command_sender,
            running,
            listener_thread: Some(spawn(move || listener.run())),
            transmitter_thread: Some(spawn(move || transmitter.run())),
        })
    }

    /// Communicate with the NCP, returning [`T::Result`](Response::Result).
    ///
    /// # Errors
    /// Returns [`T::Error`](Response::Error) if the transaction fails.
    pub async fn communicate<T>(&self, payload: &[u8]) -> Result<T::Result, T::Error>
    where
        T: Clone + Default + Response + Sync + Send + 'static,
    {
        let response = T::default();
        self.command
            .send(Command::new(Arc::from(payload), Arc::new(response.clone())))
            .map_err(|_| Error::Terminated)?;
        response.await
    }
}

impl Drop for Host {
    fn drop(&mut self) {
        self.running.store(false, SeqCst);

        if let Some(thread) = self.listener_thread.take() {
            thread.join().unwrap_or_else(|_| {
                error!("Failed to join listener thread.");
            });
        }

        if let Some(thread) = self.transmitter_thread.take() {
            thread.join().unwrap_or_else(|_| {
                error!("Failed to join transmitter thread.");
            });
        }
    }
}
```
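The lifecycle that `Host` uses can be reduced to a small sketch: a shared `running` flag, a worker thread, and clearing the flag plus joining in `Drop`. `MiniHost` and its `u8` commands are hypothetical. One property of this pattern is worth keeping in mind: the join in `Drop` only returns if every loop iteration finishes in bounded time. `Receiver::recv()` blocks until a message arrives or all `Sender`s are dropped, and a struct's fields, including its `Sender`, are only dropped after `Drop::drop` has returned. For that reason the sketch waits with `recv_timeout` instead.

```rust
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::mpsc::{channel, RecvTimeoutError, Sender};
use std::sync::Arc;
use std::thread::{self, JoinHandle};
use std::time::Duration;

/// Hypothetical, stripped-down `Host`: a worker thread consumes commands
/// while `running` is set; `Drop` clears the flag and joins the thread.
pub struct MiniHost {
    running: Arc<AtomicBool>,
    commands: Sender<u8>,
    processed: Arc<AtomicUsize>,
    worker: Option<JoinHandle<()>>,
}

impl MiniHost {
    pub fn spawn() -> Self {
        let running = Arc::new(AtomicBool::new(true));
        let processed = Arc::new(AtomicUsize::new(0));
        let (commands, receiver) = channel::<u8>();
        let (flag, counter) = (Arc::clone(&running), Arc::clone(&processed));
        let worker = thread::spawn(move || {
            while flag.load(Ordering::SeqCst) {
                // Bounded wait: plain `recv()` would block until a command
                // arrives or every `Sender` is dropped, and `commands` is
                // only dropped after `Drop::drop` has returned.
                match receiver.recv_timeout(Duration::from_millis(20)) {
                    Ok(_command) => {
                        counter.fetch_add(1, Ordering::SeqCst);
                    }
                    Err(RecvTimeoutError::Timeout) => {}
                    Err(RecvTimeoutError::Disconnected) => break,
                }
            }
        });
        Self { running, commands, processed, worker: Some(worker) }
    }

    pub fn send(&self, command: u8) -> bool {
        self.commands.send(command).is_ok()
    }

    pub fn processed(&self) -> usize {
        self.processed.load(Ordering::SeqCst)
    }
}

impl Drop for MiniHost {
    fn drop(&mut self) {
        self.running.store(false, Ordering::SeqCst);
        if let Some(worker) = self.worker.take() {
            if worker.join().is_err() {
                eprintln!("Failed to join worker thread.");
            }
        }
    }
}
```

Dropping a `MiniHost` returns within roughly one timeout period, even if no command is ever sent.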
src/host/transmitter.rs
```rust
use std::collections::HashMap;
use std::fmt::Debug;
use std::iter::Copied;
use std::slice::Iter;
use std::sync::atomic::Ordering::SeqCst;
use std::sync::atomic::{AtomicBool, AtomicU8};
use std::sync::mpsc::Receiver;
use std::sync::Arc;
use std::thread::sleep;
use std::time::{Duration, SystemTime};

use itertools::Chunks;
use log::{debug, error, info, trace};
use serialport::TTYPort;

use crate::ash_write::AshWrite;
use crate::error::frame;
use crate::packet::{Data, FrameBuffer, Rst, MAX_PAYLOAD_SIZE};
use crate::protocol::{AshChunks, Command, Event, Handler};
use crate::util::{next_three_bit_number, NonPoisonedRwLock};
use crate::Error;

const MAX_STARTUP_ATTEMPTS: u8 = 5;
const MAX_TIMEOUTS: usize = 4;
const T_REMOTE_NOTRDY: Duration = Duration::from_millis(1000);
const T_RSTACK_MAX: Duration = Duration::from_millis(3200);
const T_RX_ACK_INIT: Duration = Duration::from_millis(1600);
const T_RX_ACK_MAX: Duration = Duration::from_millis(3200);
const T_RX_ACK_MIN: Duration = Duration::from_millis(400);

#[derive(Debug)]
pub struct Transmitter {
    // Shared state
    serial_port: TTYPort,
    running: Arc<AtomicBool>,
    connected: Arc<AtomicBool>,
    command: Receiver<Command>,
    handler: Arc<NonPoisonedRwLock<Option<Arc<dyn Handler>>>>,
    ack_number: Arc<AtomicU8>,
    ack_receiver: Receiver<u8>,
    nak_receiver: Receiver<u8>,
    // Local state
    buffer: FrameBuffer,
    sent: heapless::Vec<(SystemTime, Data), MAX_TIMEOUTS>,
    retransmit: heapless::Deque<Data, MAX_TIMEOUTS>,
    retransmits: HashMap<u8, usize>,
    frame_number: u8,
    t_rx_ack: Duration,
}

impl Transmitter {
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        serial_port: TTYPort,
        running: Arc<AtomicBool>,
        connected: Arc<AtomicBool>,
        command: Receiver<Command>,
        handler: Arc<NonPoisonedRwLock<Option<Arc<dyn Handler>>>>,
        ack_number: Arc<AtomicU8>,
        ack_receiver: Receiver<u8>,
        nak_receiver: Receiver<u8>,
    ) -> Self {
        Self {
            serial_port,
            running,
            connected,
            command,
            handler,
            ack_number,
            ack_receiver,
            nak_receiver,
            buffer: FrameBuffer::new(),
            sent: heapless::Vec::new(),
            retransmit: heapless::Deque::new(),
            retransmits: HashMap::new(),
            frame_number: 0,
            t_rx_ack: T_RX_ACK_INIT,
        }
    }

    pub fn run(mut self) {
        while self.running.load(SeqCst) {
            if let Err(error) = self.main() {
                error!("{error}");
                self.running.store(false, SeqCst);
                break;
            }
        }

        debug!("Terminating.");
    }

    fn main(&mut self) -> Result<(), Error> {
        if self.connected.load(SeqCst) {
            if self.handler.read().is_some() {
                trace!("Waiting for current transaction to complete.");
                Ok(())
            } else {
                trace!("Processing next command.");
                self.process_next_command()
            }
        } else {
            self.initialize()
        }
    }

    fn process_next_command(&mut self) -> Result<(), Error> {
        match self.command.recv() {
            Ok(command) => self.process_command(command),
            Err(error) => {
                error!("Error receiving command: {error}");
                Ok(())
            }
        }
    }

    fn process_command(&mut self, command: Command) -> Result<(), Error> {
        trace!(
            "Processing command {:#04X?} with handler {:#?}",
            &command.payload,
            &command.handler
        );
        self.handler.write().replace(command.handler);
        self.transmit_data(&command.payload)
    }

    fn transmit_data(&mut self, payload: &[u8]) -> Result<(), Error> {
        if let Err(error) = payload
            .iter()
            .copied()
            .ash_chunks()
            .and_then(|chunks| self.transmit_chunks(chunks.into_iter()))
        {
            error!("{error}");
            self.abort_current_transaction(error);
            info!("Re-initializing connection.");
            self.initialize()
        } else {
            debug!("Transmission completed.");
            self.set_transmission_completed();
            Ok(())
        }
    }

    fn transmit_chunks(&mut self, mut chunks: Chunks<Copied<Iter<u8>>>) -> Result<(), Error> {
        let mut transmits;

        loop {
            if !self.connected.load(SeqCst) {
                error!("Connection lost during transaction.");
                return Err(Error::Aborted);
            }

            if !self.running.load(SeqCst) {
                error!("Terminated during active transaction.");
                return Err(Error::Terminated);
            }

            self.handle_naks_and_acks();
            transmits = 0;
            transmits += self.retransmit()?;
            transmits += self.push_chunks(&mut chunks)?;

            if transmits == 0 && self.is_transaction_complete() {
                return Ok(());
            }
        }
    }

    fn retransmit(&mut self) -> Result<usize, Error> {
        let mut retransmits: usize = 0;

        while self.sent.len() < MAX_TIMEOUTS {
            if let Some(mut data) = self.retransmit.pop_front() {
                let cnt = self.retransmits.entry(data.frame_num()).or_default();
                *cnt += 1;

                if *cnt > MAX_TIMEOUTS {
                    error!("Max retransmits exceeded for frame #{}", data.frame_num());
                    return Err(Error::MaxRetransmitsExceeded);
                }

                retransmits += 1;
                debug!("Retransmitting: {data}");
                trace!("{data:#04X?}");
                data.set_is_retransmission(true);

                if let Err(error) = self.send_data(data) {
                    error!("Failed to retransmit: {error}");
                    return Err(error);
                }
            } else {
                break;
            }
        }

        Ok(retransmits)
    }

    fn push_chunks(&mut self, chunks: &mut Chunks<Copied<Iter<u8>>>) -> Result<usize, Error> {
        let mut transmits: usize = 0;

        while self.sent.len() < MAX_TIMEOUTS {
            if let Some(chunk) = chunks.next() {
                transmits += 1;
                self.buffer.clear();
                self.buffer.extend(chunk);

                if let Err(error) = self.send_chunk() {
                    error!("Error during transmission of chunk: {error}");
                    return Err(error);
                }
            } else {
                break;
            }
        }

        Ok(transmits)
    }

    fn send_chunk(&mut self) -> Result<(), Error> {
        debug!("Sending chunk.");
        trace!("Buffer: {:#04X?}", &*self.buffer);
        let data = Data::create(
            self.next_frame_number(),
            self.ack_number.load(SeqCst),
            self.buffer.as_slice().try_into().map_err(|()| {
                Error::Frame(frame::Error::PayloadTooLarge {
                    max: MAX_PAYLOAD_SIZE,
                    size: self.buffer.len(),
                })
            })?,
        );
        self.send_data(data)
    }

    fn send_data(&mut self, data: Data) -> Result<(), Error> {
        debug!("Sending data: {data}");
        trace!("{data:#04X?}");

        if self.connected.load(SeqCst) {
            self.serial_port.write_frame(&data)?;
            self.sent
                .push((SystemTime::now(), data))
                .expect("Send queue should always accept data.");
            Ok(())
        } else {
            error!("Attempted to transmit while not connected.");
            Err(Error::Aborted)
        }
    }

    fn handle_naks_and_acks(&mut self) {
        self.handle_naks();
        self.check_ack_timeouts();
        self.handle_acks();
    }

    fn handle_naks(&mut self) {
        #[allow(clippy::needless_collect)] // Polonius issue.
        for ack_num in self
            .nak_receiver
            .try_iter()
            .collect::<heapless::Vec<u8, MAX_TIMEOUTS>>()
        {
            self.handle_nak(ack_num);
        }
    }

    fn handle_nak(&mut self, nak_num: u8) {
        if let Some((_, data)) = self
            .sent
            .iter()
            .position(|(_, data)| data.frame_num() == nak_num)
            .map(|index| self.sent.remove(index))
        {
            self.retransmit
                .push_back(data)
                .expect("Retransmit queue should always accept data.");
        }
    }

    fn handle_acks(&mut self) {
        #[allow(clippy::needless_collect)] // Polonius issue.
        for ack_num in self
            .ack_receiver
            .try_iter()
            .collect::<heapless::Vec<u8, MAX_TIMEOUTS>>()
        {
            self.handle_ack(ack_num);
        }
    }

    fn handle_ack(&mut self, ack_num: u8) {
        trace!("Handling ACK: {ack_num}");

        if let Some((timestamp, data)) = self
            .sent
            .iter()
            .position(|(_, data)| next_three_bit_number(data.frame_num()) == ack_num)
            .map(|index| self.sent.remove(index))
        {
            trace!("ACKed packet #{}", data.frame_num());

            if let Ok(duration) = SystemTime::now().duration_since(timestamp) {
                self.update_t_rx_ack(Some(duration));
            }
        }
    }

    fn check_ack_timeouts(&mut self) {
        let now = SystemTime::now();

        while let Some((_, data)) = self
            .sent
            .iter()
            .position(|(timestamp, _)| {
                now.duration_since(*timestamp)
                    .map_or(false, |duration| duration > self.t_rx_ack)
            })
            .map(|index| self.sent.remove(index))
        {
            self.retransmit
                .push_back(data)
                .expect("Retransmit queue should always accept data.");
            self.update_t_rx_ack(None);
        }
    }

    fn update_t_rx_ack(&mut self, last_ack_duration: Option<Duration>) {
        self.t_rx_ack = if let Some(duration) = last_ack_duration {
            self.t_rx_ack * 7 / 8 + duration / 2
        } else {
            self.t_rx_ack * 2
        }
        .clamp(T_RX_ACK_MIN, T_RX_ACK_MAX);
    }

    fn next_frame_number(&mut self) -> u8 {
        let frame_number = self.frame_number;
        self.frame_number = next_three_bit_number(frame_number);
        frame_number
    }

    fn initialize(&mut self) -> Result<(), Error> {
        let mut sent_rst_timestamp: SystemTime;

        for attempt in 1..=MAX_STARTUP_ATTEMPTS {
            debug!("Establishing ASH connection. Attempt #{attempt}");
            self.reset();
            sent_rst_timestamp = SystemTime::now();
            debug!("Waiting for NCP to start up.");

            while !self.connected.load(SeqCst) {
                trace!("Waiting for NCP to become ready.");
                sleep(T_REMOTE_NOTRDY);

                match SystemTime::now().duration_since(sent_rst_timestamp) {
                    Ok(duration) => {
                        trace!("Time passed: {duration:?}");

                        if duration > T_RSTACK_MAX {
                            break;
                        }
                    }
                    Err(error) => {
                        error!("System time jumped: {error}");
                        sent_rst_timestamp = SystemTime::now();
                    }
                }
            }

            if self.connected.load(SeqCst) {
                debug!("ASH connection established.");
                return Ok(());
            }
        }

        error!("Failed to establish ASH connection.");
        Err(Error::InitializationFailed)
    }

    fn reset(&mut self) {
        debug!("Resetting connection.");
        self.connected.store(false, SeqCst);
        trace!("Sending RST.");
        self.serial_port
            .write_frame(&Rst::default())
            .unwrap_or_else(|error| error!("Failed to send RST: {error}"));
        self.reset_state();
    }

    fn reset_state(&mut self) {
        debug!("Resetting state.");
        self.abort_current_transaction(Error::Aborted);
        self.buffer.clear();
        self.sent.clear();
        self.retransmits.clear();
        self.retransmit.clear();
        self.frame_number = 0;
        self.ack_number.store(0, SeqCst);
        self.t_rx_ack = T_RX_ACK_INIT;
    }

    fn abort_current_transaction(&self, error: Error) {
        if let Some(handler) = self.handler.write().take() {
            handler.abort(error);
            handler.wake();
        }
    }

    fn set_transmission_completed(&mut self) {
        if let Some(handler) = self.handler.read().clone() {
            debug!("Finalizing data command.");
            handler.handle(Event::TransmissionCompleted);
        }
    }

    fn is_transaction_complete(&self) -> bool {
        self.sent.is_empty() && self.retransmit.is_empty()
    }
}
```
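`Transmitter::update_t_rx_ack` adapts the ACK timeout. On every ACK, the new value is 7/8 of the old value plus half of the measured ACK delay. On every timeout, the value is doubled. In both cases it is clamped to `[T_RX_ACK_MIN, T_RX_ACK_MAX]`. The free function `next_t_rx_ack` below is a hypothetical, I/O-free extraction of that rule. It makes the arithmetic easy to check in isolation, and it could be unit-tested this way without a serial port.

```rust
use std::time::Duration;

const T_RX_ACK_INIT: Duration = Duration::from_millis(1600);
const T_RX_ACK_MIN: Duration = Duration::from_millis(400);
const T_RX_ACK_MAX: Duration = Duration::from_millis(3200);

/// Next ACK timeout, given the current one and either the measured ACK
/// delay (`Some`) or a timeout (`None`). Same rule as `update_t_rx_ack`.
pub fn next_t_rx_ack(current: Duration, measured: Option<Duration>) -> Duration {
    let next = match measured {
        // Exponential smoothing towards the observed delay.
        Some(delay) => current * 7 / 8 + delay / 2,
        // Back off on timeout.
        None => current * 2,
    };
    next.clamp(T_RX_ACK_MIN, T_RX_ACK_MAX)
}
```

With a constant ACK delay `d`, the fixed point of `t = 7/8 t + d/2` is `t = 4d`. The timeout therefore settles at four times the observed delay, as long as that lies within the clamp range.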
src/host/listener.rs
```rust
use std::fmt::Debug;
use std::io::ErrorKind;
use std::sync::atomic::Ordering::SeqCst;
use std::sync::atomic::{AtomicBool, AtomicU8};
use std::sync::mpsc::{channel, Receiver, Sender};
use std::sync::Arc;

use log::{debug, error, trace, warn};
use serialport::TTYPort;

use crate::ash_read::AshRead;
use crate::ash_write::AshWrite;
use crate::frame::Frame;
use crate::packet::{Ack, Data, Error, FrameBuffer, Nak, Packet, RstAck};
use crate::protocol::{Event, HandleResult, Handler, Mask};
use crate::util::{next_three_bit_number, NonPoisonedRwLock};

#[derive(Debug)]
pub struct Listener {
    // Shared state
    serial_port: TTYPort,
    running: Arc<AtomicBool>,
    connected: Arc<AtomicBool>,
    handler: Arc<NonPoisonedRwLock<Option<Arc<dyn Handler>>>>,
    ack_number: Arc<AtomicU8>,
    callback: Option<Sender<FrameBuffer>>,
    ack_sender: Sender<u8>,
    nak_sender: Sender<u8>,
    // Local state
    buffer: FrameBuffer,
    is_rejecting: bool,
    last_received_frame_number: Option<u8>,
}

impl Listener {
    pub fn new(
        serial_port: TTYPort,
        running: Arc<AtomicBool>,
        connected: Arc<AtomicBool>,
        handler: Arc<NonPoisonedRwLock<Option<Arc<dyn Handler>>>>,
        ack_number: Arc<AtomicU8>,
        callback: Option<Sender<FrameBuffer>>,
    ) -> (Self, Receiver<u8>, Receiver<u8>) {
        let (ack_sender, ack_receiver) = channel();
        let (nak_sender, nak_receiver) = channel();
        let listener = Self {
            serial_port,
            running,
            connected,
            handler,
            ack_number,
            callback,
            ack_sender,
            nak_sender,
            buffer: FrameBuffer::new(),
            is_rejecting: false,
            last_received_frame_number: None,
        };
        (listener, ack_receiver, nak_receiver)
    }

    pub fn run(mut self) {
        while self.running.load(SeqCst) {
            match self.read_frame() {
                Ok(packet) => {
                    if let Some(ref frame) = packet {
                        self.handle_frame(frame);
                    }
                }
                Err(error) => error!("{error}"),
            }
        }

        debug!("Terminating.");
    }

    fn handle_frame(&mut self, frame: &Packet) {
        debug!("Received: {frame}");
        trace!("{frame:#04X?}");

        if self.connected.load(SeqCst) {
            match frame {
                Packet::Ack(ref ack) => self.handle_ack(ack),
                Packet::Data(ref data) => self.handle_data(data),
                Packet::Error(ref error) => self.handle_error(error),
                Packet::Nak(ref nak) => self.handle_nak(nak),
                Packet::RstAck(ref rst_ack) => self.handle_rst_ack(rst_ack),
                Packet::Rst(_) => warn!("Received unexpected RST from NCP."),
            }
        } else if let Packet::RstAck(ref rst_ack) = frame {
            self.handle_rst_ack(rst_ack);
        } else {
            warn!("Not connected. Dropping frame: {frame}");
        }
    }

    fn handle_ack(&mut self, ack: &Ack) {
        if !ack.is_crc_valid() {
            warn!("Received ACK with invalid CRC.");
        }

        self.ack_sender
            .send(ack.ack_num())
            .unwrap_or_else(|error| error!("Failed to forward ACK: {error}"));
    }

    fn handle_data(&mut self, data: &Data) {
        debug!("Received frame: {data:#04X?}");
        trace!(
            "Unmasked payload: {:#04X?}",
            data.payload()
                .iter()
                .copied()
                .mask()
                .collect::<FrameBuffer>()
        );

        if !data.is_crc_valid() {
            warn!("Received data frame with invalid CRC.");
            self.reject();
        } else if data.frame_num() == self.ack_number() {
            self.ack_received_data(data.frame_num());
            self.is_rejecting = false;
            self.last_received_frame_number = Some(data.frame_num());
            self.ack_number.store(self.ack_number(), SeqCst);
            debug!("Sending ACK to transmitter: {}", data.ack_num());
            self.ack_sender
                .send(data.ack_num())
                .unwrap_or_else(|error| {
                    error!("Failed to forward ACK: {error}");
                });
            self.forward_data(data);
        } else if data.is_retransmission() {
            self.ack_number.store(self.ack_number(), SeqCst);
            debug!("Sending ACK to transmitter: {}", data.ack_num());
            self.ack_sender
                .send(data.ack_num())
                .unwrap_or_else(|error| {
                    error!("Failed to forward ACK: {error}");
                });
            self.forward_data(data);
        } else {
            debug!("Received out-of-sequence data frame: {data}");

            if !self.is_rejecting {
                self.reject();
            }
        }
    }

    fn ack_received_data(&mut self, frame_num: u8) {
        self.serial_port
            .write_frame(&Ack::from_ack_num(next_three_bit_number(frame_num)))
            .unwrap_or_else(|error| error!("Failed to send ACK: {error}"));
    }

    fn forward_data(&mut self, data: &Data) {
        debug!("Forwarding data: {data}");
        let payload: FrameBuffer = data.payload().iter().copied().mask().collect();

        // Take the handler in its own statement so that the write guard is
        // released before the match; otherwise the `Continue` branch would try
        // to re-acquire the write lock while the temporary guard is still held.
        let handler = self.handler.write().take();

        if let Some(handler) = handler {
            debug!("Forwarding data to current handler.");

            match handler.handle(Event::DataReceived(Ok(&payload))) {
                HandleResult::Completed => {
                    debug!("Command responded with COMPLETED.");
                    handler.wake();
                }
                HandleResult::Continue => {
                    debug!("Command responded with CONTINUE.");
                    self.handler.write().replace(handler);
                }
                HandleResult::Failed => {
                    warn!("Command responded with FAILED.");
                    handler.wake();
                }
                HandleResult::Reject => {
                    debug!("Command responded with REJECT.");
                    self.callback.as_ref().map_or_else(|| {
                        error!("Current response handler rejected received data and there is no callback handler registered. Dropping packet.");
                    }, |callback| {
                        debug!("Forwarding rejected data to callback.");
                        callback.send(payload).unwrap_or_else(|error| {
                            error!("Failed to send data to callback channel: {error}");
                        });
                    });
                }
            }
        } else if let Some(callback) = &self.callback {
            debug!("Forwarding data to callback.");
            callback.send(payload).unwrap_or_else(|error| {
                error!("Failed to send data to callback channel: {error}");
            });
        } else {
            error!("There is neither an active response handler nor a callback handler registered. Dropping packet.");
        }
    }

    fn handle_error(&mut self, error: &Error) {
        trace!("Received ERROR: {error:#04X?}");

        if !error.is_ash_v2() {
            error!("{error} is not ASHv2: {}", error.version());
        }

        self.connected.store(false, SeqCst);
        error.code().map_or_else(
            || {
                error!("NCP sent error without valid code.");
            },
            |code| {
                warn!("NCP sent error condition: {code}");
            },
        );
    }

    fn handle_nak(&mut self, nak: &Nak) {
        if !nak.is_crc_valid() {
            warn!("Received NAK with invalid CRC.");
        }

        debug!("Forwarding NAK to transmitter.");
        self.nak_sender
            .send(nak.ack_num())
            .unwrap_or_else(|error| error!("Failed to forward NAK: {error}"));
    }

    fn handle_rst_ack(&mut self, rst_ack: &RstAck) {
        if !rst_ack.is_ash_v2() {
            error!("{rst_ack} is not ASHv2: {}", rst_ack.version());
        }

        rst_ack.code().map_or_else(
            || {
                warn!("NCP acknowledged reset with invalid error code.");
            },
            |code| {
                debug!("NCP acknowledged reset due to: {code}");
            },
        );
        self.reset_state();
        self.connected.store(true, SeqCst);

        if let Some(handler) = self.handler.write().take() {
            trace!("Aborting current command.");
            handler.abort(crate::Error::Aborted);
            handler.wake();
        }
    }

    fn reset_state(&mut self) {
        trace!("Resetting state variables.");
        self.buffer.clear();
        self.is_rejecting = false;
        self.last_received_frame_number = None;
    }

    fn reject(&mut self) {
        trace!("Entering rejection state.");
        self.is_rejecting = true;
        self.send_nak();
    }

    fn send_nak(&mut self) {
        debug!("Sending NAK: {}", self.ack_number());
        self.serial_port
            .write_frame(&Nak::from_ack_num(self.ack_number()))
            .unwrap_or_else(|error| error!("Could not send NAK: {error}"));
    }

    fn read_frame(&mut self) -> Result<Option<Packet>, crate::Error> {
        self.serial_port
            .read_packet_buffered(&mut self.buffer)
            .map(Some)
            .or_else(|error| {
                if let crate::Error::Io(io_error) = &error {
                    if io_error.kind() == ErrorKind::TimedOut {
                        return Ok(None);
                    }
                }

                Err(error)
            })
    }

    fn ack_number(&self) -> u8 {
        self.last_received_frame_number
            .map_or(0, next_three_bit_number)
    }
}
```
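The receive-side sequencing in `Listener::handle_data` can be reduced to a small state machine for illustration. The CRC check and the retransmission branch are left out. The expected frame number is the three-bit successor of the last accepted frame. An in-sequence frame is acknowledged with that successor. The first out-of-sequence frame triggers a NAK, and further out-of-sequence frames are ignored until the expected frame arrives. `RxSequence` and `Reply` are hypothetical. `next_three_bit_number` lives in `crate::util`, which is not shown in the question, so the wrap-around modulo 8 below is an assumed equivalent.

```rust
/// Assumed equivalent of `crate::util::next_three_bit_number`:
/// frame and ACK numbers are three bits wide and wrap from 7 to 0.
pub fn next_three_bit_number(number: u8) -> u8 {
    number.wrapping_add(1) & 0b111
}

/// What the receiver sends back for an incoming data frame.
#[derive(Debug, PartialEq, Eq)]
pub enum Reply {
    /// In-sequence frame: ACK carrying the next expected frame number.
    Ack(u8),
    /// First out-of-sequence frame: NAK carrying the expected frame number.
    Nak(u8),
    /// Further out-of-sequence frames while already rejecting.
    Ignore,
}

/// Hypothetical, simplified model of the listener's sequencing state.
#[derive(Debug, Default)]
pub struct RxSequence {
    last_received: Option<u8>,
    rejecting: bool,
}

impl RxSequence {
    /// The frame number expected next (0 before anything was received).
    pub fn ack_number(&self) -> u8 {
        self.last_received.map_or(0, next_three_bit_number)
    }

    pub fn on_data(&mut self, frame_num: u8) -> Reply {
        if frame_num == self.ack_number() {
            self.last_received = Some(frame_num);
            self.rejecting = false;
            Reply::Ack(self.ack_number())
        } else if self.rejecting {
            Reply::Ignore
        } else {
            self.rejecting = true;
            Reply::Nak(self.ack_number())
        }
    }
}
```

For example, receiving frames 0, 2, 3, 1 in that order yields `Ack(1)`, `Nak(1)`, `Ignore` and `Ack(2)`.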
I'd appreciate feedback on the API design exposed by `Host`, as well as on the implementation of the transmitter and listener threads.