How to Use a Macro to Generate an Array of Incrementing Parts for a Shift Register Driver in Rust?

Rust beginner here! I'm working on a shift register driver, and I need some help modifying the split method to return ([Parts<N>; N], ShiftRegister<SPI, N>).

Specifically, I want each instance in the array of Parts to have its byte field increment from 0 to N - 1.

Can a macro be used to generate these Parts items automatically?

Here's the code snippet I'm working with:

use std::sync::Arc;
use std::sync::Mutex;

use embedded_hal as hal;
use hal::blocking::spi::Write;

pub enum DriverError {
    UpdateFailure,
}

pub struct IC74hc594<SPI, const N: usize> {
    spi: SPI,
    state: Arc<Mutex<[u8; N]>>,
}

impl<SPI, const N: usize> IC74hc594<SPI, N>
where
    SPI: Write<u8>,
{
    pub fn new(spi: SPI) -> Self {
        Self {
            spi,
            state: Arc::new(Mutex::new([0; N])),
        }
    }

    pub fn split(self) -> (Parts<N>, ShiftRegister<SPI, N>) {
        (
            Parts {
                Q0: RegisterPin {
                    state: Arc::clone(&self.state),
                    byte: 0,
                    bit: 0,
                },
                Q1: RegisterPin {
                    state: Arc::clone(&self.state),
                    byte: 0,
                    bit: 1,
                },
                Q2: RegisterPin {
                    state: Arc::clone(&self.state),
                    byte: 0,
                    bit: 2,
                },
                Q3: RegisterPin {
                    state: Arc::clone(&self.state),
                    byte: 0,
                    bit: 3,
                },
                Q4: RegisterPin {
                    state: Arc::clone(&self.state),
                    byte: 0,
                    bit: 4,
                },
                Q5: RegisterPin {
                    state: Arc::clone(&self.state),
                    byte: 0,
                    bit: 5,
                },
                Q6: RegisterPin {
                    state: Arc::clone(&self.state),
                    byte: 0,
                    bit: 6,
                },
                Q7: RegisterPin {
                    state: Arc::clone(&self.state),
                    byte: 0,
                    bit: 7,
                },
            },
            ShiftRegister {
                spi: self.spi,
                state: self.state,
            },
        )
    }
}

#[allow(non_snake_case)]
pub struct Parts<const N: usize> {
    Q0: RegisterPin<N>,
    Q1: RegisterPin<N>,
    Q2: RegisterPin<N>,
    Q3: RegisterPin<N>,
    Q4: RegisterPin<N>,
    Q5: RegisterPin<N>,
    Q6: RegisterPin<N>,
    Q7: RegisterPin<N>,
}

pub struct ShiftRegister<SPI, const N: usize> {
    spi: SPI,
    state: Arc<Mutex<[u8; N]>>,
}

impl<SPI, const N: usize> ShiftRegister<SPI, N>
where
    SPI: Write<u8>,
{
    pub fn update(&mut self) -> Result<(), DriverError> {
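        // `self.state` and `self.spi` are disjoint fields, so the lock guard
        // and the mutable SPI borrow can coexist without cloning the `Arc`.
        let state = self.state.lock().unwrap();

        self.spi
            .write(&state[..])
            .map_err(|_| DriverError::UpdateFailure)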
    }
}

pub struct RegisterPin<const N: usize> {
    state: Arc<Mutex<[u8; N]>>,
    byte: usize,
    bit: usize,
}

impl<const N: usize> RegisterPin<N> {
    pub fn set(&self) {
        let mut state = self.state.lock().unwrap();
        // `byte` is already the index into the state array (0..N),
        // so it is used directly rather than divided by 8.
        state[self.byte] |= 1 << self.bit;
    }

    pub fn clear(&self) {
        let mut state = self.state.lock().unwrap();
        state[self.byte] &= !(1 << self.bit);
    }
}
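With N chained registers there are 8 * N outputs, and each RegisterPin stores a (byte, bit) pair with byte in 0..N and bit in 0..8. If you would rather number the outputs globally, a hypothetical helper like locate below (not part of the original code) converts with output / 8 and output % 8. Which physical chip byte 0 ends up in depends on how the chain is wired and on the order the bytes are written, so treat this as a sketch of the arithmetic only:

```rust
/// Hypothetical helper: converts a global output number in a chain of `N`
/// 8-bit shift registers into the `(byte, bit)` pair a `RegisterPin` stores.
/// Returns `None` for outputs past the end of the chain.
pub const fn locate<const N: usize>(output: usize) -> Option<(usize, usize)> {
    if output < 8 * N {
        // 8 outputs per register: the quotient picks the byte,
        // the remainder picks the bit inside it.
        Some((output / 8, output % 8))
    } else {
        None
    }
}
```

For a chain of 5 registers, output 13 maps to byte 1, bit 5, and output 40 does not exist.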

What is the byte field? In your example code, the Parts type doesn't have a byte field, so it's unclear what you want to do.

If you are talking about an array like [Parts::<0>{ ... }, Parts::<1>{ ... }, ... , Parts::<N>{ ... }], then it's not possible: the Parts type is parameterized over the const N, so Parts::<0> and Parts::<1> are different types, and Rust arrays are homogeneous. You cannot have an array of different types.

Sorry, I wasn't clear. The code I have at the moment just creates a single instance of Parts. What I wanted was an array of Parts; however, the byte field in each RegisterPin needs to increment.

The result would look like this:

[
    Parts {
        Q0: RegisterPin {
            state: Arc::clone(&self.state),
            byte: 0,
            bit: 0,
        },
        Q1: RegisterPin {
            state: Arc::clone(&self.state),
            byte: 0,
            bit: 1,
        },

        ...
    },
    Parts {
        Q0: RegisterPin {
            state: Arc::clone(&self.state),
            byte: 1,
            bit: 0,
        },
        Q1: RegisterPin {
            state: Arc::clone(&self.state),
            byte: 1,
            bit: 1,
        },

        ...
    },
    ...
    Parts {
        Q0: RegisterPin {
            state: Arc::clone(&self.state),
            byte: N-1,
            bit: 0,
        },
        Q1: RegisterPin {
            state: Arc::clone(&self.state),
            byte: N-1,
            bit: 1,
        },

        ...
    }
]

Hope this clarifies things a bit.

So it's the RegisterPin type, not Parts, then.

I guess you are hoping a macro can reduce the repetition of Q0 through Q7, but unfortunately declarative (macro_rules!) macros cannot generate new identifiers like that (even the nightly-only concat_idents!() is very limited).
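What a macro_rules! macro can do is take the field names as input and stamp out the repetitive RegisterPin bodies. A minimal sketch, assuming a hypothetical parts! macro and make_parts helper (neither is part of the original code); note that Q0 => 0 through Q7 => 7 still have to be listed by hand:

```rust
use std::sync::{Arc, Mutex};

pub struct RegisterPin<const N: usize> {
    state: Arc<Mutex<[u8; N]>>,
    byte: usize,
    bit: usize,
}

#[allow(non_snake_case)]
pub struct Parts<const N: usize> {
    pub Q0: RegisterPin<N>,
    pub Q1: RegisterPin<N>,
    pub Q2: RegisterPin<N>,
    pub Q3: RegisterPin<N>,
    pub Q4: RegisterPin<N>,
    pub Q5: RegisterPin<N>,
    pub Q6: RegisterPin<N>,
    pub Q7: RegisterPin<N>,
}

// Hypothetical macro: the caller supplies each field identifier and its bit;
// the macro only expands the repeated `RegisterPin { .. }` initializers.
macro_rules! parts {
    ($state:expr, $byte:expr; $($field:ident => $bit:expr),* $(,)?) => {
        Parts {
            $(
                $field: RegisterPin {
                    state: Arc::clone($state),
                    byte: $byte,
                    bit: $bit,
                },
            )*
        }
    };
}

/// Hypothetical helper: builds the `Parts` for one byte of the chain.
pub fn make_parts<const N: usize>(state: &Arc<Mutex<[u8; N]>>, byte: usize) -> Parts<N> {
    parts!(state, byte; Q0 => 0, Q1 => 1, Q2 => 2, Q3 => 3, Q4 => 4, Q5 => 5, Q6 => 6, Q7 => 7)
}
```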

IMO, macros won't save much in this case compared to regular functions. For example, with a helper function like this:

// a convenient `Parts` constructor using an array,
// since arrays are easier to deal with;
// an alternative to this is to use the `From` trait
impl<const N: usize> Parts<N> {
    // `Q0`..`Q7` mirror the field names, so silence the naming lint
    #[allow(non_snake_case)]
    fn from_pin_array(pins: [RegisterPin<N>; 8]) -> Self {
        let [Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7] = pins;
        Self { Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7 }
    }
}

The split() function is then not difficult to implement, thanks to std::array::from_fn() (stable since Rust 1.63):

fn split(self) -> [Parts<N>; N] {
    std::array::from_fn(|byte| {
        Parts::from_pin_array(std::array::from_fn(|bit| {
            RegisterPin {
                state: Arc::clone(&self.state),
                byte,
                bit,
            }
        }))
    })
}
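The comment on from_pin_array mentions the From trait as an alternative. A minimal sketch of that variant, assuming a hypothetical free function split_parts that takes the shared state instead of self; the inner array is annotated so the From<[RegisterPin<N>; 8]> impl is picked unambiguously:

```rust
use std::sync::{Arc, Mutex};

pub struct RegisterPin<const N: usize> {
    state: Arc<Mutex<[u8; N]>>,
    byte: usize,
    bit: usize,
}

#[allow(non_snake_case)]
pub struct Parts<const N: usize> {
    pub Q0: RegisterPin<N>,
    pub Q1: RegisterPin<N>,
    pub Q2: RegisterPin<N>,
    pub Q3: RegisterPin<N>,
    pub Q4: RegisterPin<N>,
    pub Q5: RegisterPin<N>,
    pub Q6: RegisterPin<N>,
    pub Q7: RegisterPin<N>,
}

// `From` version of the array constructor: destructure, then rebuild.
impl<const N: usize> From<[RegisterPin<N>; 8]> for Parts<N> {
    #[allow(non_snake_case)]
    fn from(pins: [RegisterPin<N>; 8]) -> Self {
        let [Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7] = pins;
        Self { Q0, Q1, Q2, Q3, Q4, Q5, Q6, Q7 }
    }
}

/// Hypothetical helper: one `Parts` per byte, every pin sharing `state`.
pub fn split_parts<const N: usize>(state: &Arc<Mutex<[u8; N]>>) -> [Parts<N>; N] {
    std::array::from_fn(|byte| {
        let pins: [RegisterPin<N>; 8] = std::array::from_fn(|bit| RegisterPin {
            state: Arc::clone(state),
            byte,
            bit,
        });
        Parts::from(pins)
    })
}
```

Whether From or a named constructor reads better is mostly taste; From lets callers write .into() when the target type is already known.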

I wasn't so much after auto-generating the individual RegisterPin fields (Q0..Q7); I mainly wanted to avoid repeating each element in the array of Parts.

I adopted your solution as below:

    pub fn split(self) -> ([Parts<N>; N], ShiftRegister<SPI, N>) {
        (
            std::array::from_fn(|byte| Parts {
                Q0: RegisterPin {
                    state: Arc::clone(&self.state),
                    byte,
                    bit: 0,
                },
                Q1: RegisterPin {
                    state: Arc::clone(&self.state),
                    byte,
                    bit: 1,
                },
                Q2: RegisterPin {
                    state: Arc::clone(&self.state),
                    byte,
                    bit: 2,
                },
                Q3: RegisterPin {
                    state: Arc::clone(&self.state),
                    byte,
                    bit: 3,
                },
                Q4: RegisterPin {
                    state: Arc::clone(&self.state),
                    byte,
                    bit: 4,
                },
                Q5: RegisterPin {
                    state: Arc::clone(&self.state),
                    byte,
                    bit: 5,
                },
                Q6: RegisterPin {
                    state: Arc::clone(&self.state),
                    byte,
                    bit: 6,
                },
                Q7: RegisterPin {
                    state: Arc::clone(&self.state),
                    byte,
                    bit: 7,
                },
            }),
            ShiftRegister {
                spi: self.spi,
                state: self.state,
            },
        )
    }

The only question I have: does this work in no_std builds? I'm hoping to use this driver on an STM32H753-based board that chains 5 shift registers.

Once again, thanks for your help!

It's just an array, so yes. As long as you only import from core and no_std crates, Rust code will work on no_std.

Yes, std::array is just a re-export of core::array, so you can replace std::array::from_fn() with core::array::from_fn().
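For instance, the nested from_fn pattern from split() works unchanged through the core path. This hypothetical pin_indices helper only builds the (byte, bit) pairs, so it has no dependency on Arc and would compile in a #![no_std] crate:

```rust
/// Hypothetical helper: the `(byte, bit)` pair for every output in a chain
/// of `N` registers, built with `core::array::from_fn` so no `std` is needed.
pub fn pin_indices<const N: usize>() -> [[(usize, usize); 8]; N] {
    // Outer closure runs once per byte, inner closure once per bit.
    core::array::from_fn(|byte| core::array::from_fn(|bit| (byte, bit)))
}
```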

I used std in the example because you are using Arc, which is not available in no_std anyway, so I assumed std was available to you.

Nitpick: Arc is defined/available in alloc, so it can be used in any environment that has a global allocator and supports atomic pointers. The original code used std paths, of course, but you can use Arc without touching std if you want to.
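A minimal sketch of that, with a hypothetical SharedBits type: Arc comes from alloc::sync, and because std's Mutex is not available in alloc or core, the bytes are core::sync::atomic::AtomicU8 instead (a swapped-in technique, not the original design). It assumes Relaxed ordering is sufficient since each call only touches its own bit:

```rust
// `alloc` is not in the extern prelude, so it is declared explicitly; in a
// `#![no_std]` crate it also needs a `#[global_allocator]` (not shown here).
extern crate alloc;

use alloc::sync::Arc;
use core::sync::atomic::{AtomicU8, Ordering};

/// Hypothetical shared pin state: `Arc` from `alloc`, atomics from `core`.
#[derive(Clone)]
pub struct SharedBits<const N: usize> {
    bytes: Arc<[AtomicU8; N]>,
}

impl<const N: usize> SharedBits<N> {
    pub fn new() -> Self {
        Self {
            bytes: Arc::new(core::array::from_fn(|_| AtomicU8::new(0))),
        }
    }

    /// Sets one output with a single atomic read-modify-write.
    pub fn set(&self, byte: usize, bit: usize) {
        self.bytes[byte].fetch_or(1 << bit, Ordering::Relaxed);
    }

    /// Clears one output with an atomic `fetch_and`.
    pub fn clear(&self, byte: usize, bit: usize) {
        self.bytes[byte].fetch_and(!(1 << bit), Ordering::Relaxed);
    }

    /// Reads one byte, e.g. while filling an SPI transmit buffer.
    pub fn load(&self, byte: usize) -> u8 {
        self.bytes[byte].load(Ordering::Relaxed)
    }
}
```

Cloning a SharedBits only bumps the Arc's reference count; every clone sees the same bytes.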


Oh, I didn't realise that. Thanks.

OK, so this throws a bit of a spanner in the works :)

So, what I'm trying to achieve is this:

  1. I want to write a driver that manages a chain (N) of 74hc594d shift registers via an SPI peripheral.
  2. The driver maintains a u8 array of the state that can be shifted out.
  3. The split API returns a bunch of RegisterPins, each controlling one bit in the u8 array, and a ShiftRegister.
  4. The RegisterPins will be owned by various client components (possibly across different threads)
  5. The owner of the ShiftRegister can shift out the u8 array at regular intervals.

I want the driver to be as generic as possible, i.e. usable in environments with or without threads.

Without using Arc... how could I share a u8 array across multiple threads?

Thank you.

Abstracting away thread safety is quite difficult. You need Send and/or Sync for multi-threaded access, but on the other hand you don't want to add unnecessary restrictions or overhead when they are not needed. For example, Arc has more overhead than Rc (because its reference count is updated with atomic operations), but Rc is explicitly !Send and !Sync, so it cannot be accessed from multiple threads.

Personally, I don't like to over-engineer for a "truly generic design"; I'd rather keep the design simple and let the user wrap it in Arc and Mutex when it cannot be trivially made thread safe.
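A minimal sketch of that simpler design, assuming a hypothetical PinState type with plain &mut self methods and no internal locking; only code that actually needs cross-thread access pays for Arc and Mutex, via the hypothetical SharedPinState alias:

```rust
use std::sync::{Arc, Mutex};

/// Hypothetical simplified driver state: no `Arc`, no `Mutex` inside.
pub struct PinState<const N: usize> {
    bytes: [u8; N],
}

impl<const N: usize> PinState<N> {
    pub fn new() -> Self {
        Self { bytes: [0; N] }
    }

    pub fn set(&mut self, byte: usize, bit: usize) {
        self.bytes[byte] |= 1 << bit;
    }

    pub fn clear(&mut self, byte: usize, bit: usize) {
        self.bytes[byte] &= !(1 << bit);
    }

    /// The bytes to hand to the SPI write.
    pub fn bytes(&self) -> &[u8; N] {
        &self.bytes
    }
}

/// The wrapping is left to the user: single-threaded code uses `PinState`
/// directly, multi-threaded code opts into this.
pub type SharedPinState<const N: usize> = Arc<Mutex<PinState<N>>>;
```

With this split, single-threaded or interrupt-free code never touches an atomic or a lock.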

As @kpreid mentioned, Arc is actually defined in alloc, and std simply re-exports it. It can be used even in no_std, as long as a global allocator is available. Moreover, if you are concerned about threading, you are already on std (or at least alloc) anyway.
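If even alloc is unavailable, one option (swapped in here, not something suggested above) is to keep the state in an array of core::sync::atomic::AtomicU8 and let each pin borrow its byte. A minimal sketch with hypothetical AtomicPin, split_pins and snapshot names; it assumes the target supports atomic read-modify-write (Cortex-M7 parts such as the STM32H753 should, thumbv6m targets do not) and that Relaxed ordering is enough because each pin only touches its own bit:

```rust
use core::sync::atomic::{AtomicU8, Ordering};

/// Hypothetical pin type: borrows its byte of the shared state instead of
/// owning an `Arc`, so no allocator is involved.
pub struct AtomicPin<'a> {
    byte: &'a AtomicU8,
    mask: u8,
}

impl AtomicPin<'_> {
    pub fn set(&self) {
        self.byte.fetch_or(self.mask, Ordering::Relaxed);
    }

    pub fn clear(&self) {
        self.byte.fetch_and(!self.mask, Ordering::Relaxed);
    }
}

/// Hands out one pin per output, indexed `[byte][bit]`, all borrowing `state`.
pub fn split_pins<const N: usize>(state: &[AtomicU8; N]) -> [[AtomicPin<'_>; 8]; N] {
    core::array::from_fn(move |byte| {
        core::array::from_fn(move |bit| AtomicPin {
            byte: &state[byte],
            mask: 1 << bit,
        })
    })
}

/// Copies the current bits out, e.g. just before an SPI write.
pub fn snapshot<const N: usize>(state: &[AtomicU8; N]) -> [u8; N] {
    core::array::from_fn(|i| state[i].load(Ordering::Relaxed))
}
```

Putting the state in a static (for example static STATE: [AtomicU8; 5]) makes the pins 'static, and since AtomicU8 is Sync, each AtomicPin<'static> can be moved to another thread. Unlike the Mutex version, snapshot reads each byte separately, so it is not an atomic view of the whole chain.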