I'm working on porting some code I wrote for my thesis to Rust. Something I need to get right and fast is an enumeration over all non-negative integer-valued vectors x satisfying dot_product(a,x)=k for positive integer k and positive-integer-valued vector a.
Here is my first attempt in Rust, with 3 "pain points" commented:
// The goal is to create a data structure that iterates over non-negative, integer-valued
// points on the plane a1*x1+a2*x2+...=k.
// BIG ASSUMPTION: a1 divides k
struct NonnegativeLattice {
a: Vec<u64>,
x: Vec<u64>,
init: bool,
}
impl NonnegativeLattice {
fn new(a: Vec<u64>, k: u64) -> NonnegativeLattice {
let mut x = vec![0; a.len()];
x[0] = k / a[0];
NonnegativeLattice { a: a, x: x, init: true, }
}
}
impl Iterator for NonnegativeLattice {
type Item = Vec<u64>;
fn next(&mut self) -> Option<Self::Item> {
// Pain point #1: awkward inclusion of lower bound, though how else to do this?
if self.init {
self.init = false;
Some(self.x.clone())
}
else {
let mut success = false;
// Pain point #2: this is glorified C -- what is the idiomatic way to do this?
// I hate the thought of all the bounds checks slowing this down.
// * the "algorithm" works by treating x[0] as a slush fund,
// adding/removing mass as necessary to preserve the dot product
for i in 1..self.x.len() {
if self.x[0] >= self.a[i] {
self.x[0] -= self.a[i];
self.x[i] += 1;
success = true;
break;
}
self.x[0] += self.x[i] * self.a[i];
self.x[i] = 0;
}
if success {
// Pain point #3: it feels weird to clone a vector for each iteration...
// I can promise that the returned value isn't modified and doesn't live past each
// iteration -- would a lifetimed reference work here instead? I couldn't figure it out
Some(self.x.clone())
}
else { // equivalent: all entries in x are now 0
None
}
}
}
}
fn main() {
for c in NonnegativeLattice::new(vec![1,2,3], 10) {
println!("{:?}", c);
}
}
Any advice would be much appreciated. The code works, but now that I'm not working on a time-crunch I want to do things the Right Way. Thanks!
I think the .clone() is required. The promise you describe sounds like a streaming iterator, and those don't really work yet: there are ways to do it, but they are not general enough for std and not ergonomic as a crate.
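To make the "streaming iterator" idea concrete, here is a minimal sketch of such a trait, assuming a compiler with generic associated types (stable since Rust 1.65). The names `LendingIterator`, `Doubling`, and `sum_first` are made up for illustration; nothing like this exists in std.

```rust
/// Hypothetical trait: like `Iterator`, but each item may borrow from the
/// iterator itself, so it has to be dropped before the next call.
trait LendingIterator {
    type Item<'a>
    where
        Self: 'a;

    fn next(&mut self) -> Option<Self::Item<'_>>;
}

/// Toy implementor: lends out a view of one reused buffer.
struct Doubling {
    buf: Vec<u64>,
    steps_left: u32,
}

impl LendingIterator for Doubling {
    type Item<'a> = &'a [u64] where Self: 'a;

    fn next(&mut self) -> Option<Self::Item<'_>> {
        if self.steps_left == 0 {
            return None;
        }
        self.steps_left -= 1;
        // Mutate the buffer in place; no allocation or clone per step.
        for v in self.buf.iter_mut() {
            *v *= 2;
        }
        Some(&self.buf[..])
    }
}

/// Sums the first entry of every lent item, using `while let`
/// because `for` only works with the std `Iterator` trait.
fn sum_first(mut it: Doubling) -> u64 {
    let mut total = 0;
    while let Some(item) = it.next() {
        total += item[0];
    }
    total
}
```

Because each item borrows from the iterator, something like `collect` cannot be written for this trait, which is exactly the promise described in the question.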
For the sake of completeness, here is my current approach, taking into consideration the advice given. I was able to address pain point #1 with a clever wrap-around, #2 with jethrogb's suggestion, and #3 using a raw pointer (?!).
impl NonnegativeLattice {
// We initialize to the largest element, so that it wraps around once
// (this does more work up front but slightly less work per iteration, of which there are many;
// branch prediction probably made pain point #1 moot anyway)
fn new(a: Vec<u64>, mut k: u64) -> NonnegativeLattice {
assert_eq!(a[0], 1); // Assumption: a[0] == 1
let mut x = vec![0; a.len()];
for (n,c) in x.iter_mut().zip(a.iter()).rev() {
*n = k/c;
k -= *n * c;
}
NonnegativeLattice { a: a, x: x, init: true, }
}
}
impl Iterator for NonnegativeLattice {
// good: a raw pointer lets me use the nice iterator ergonomics of rust while keeping everything copy-free
// bad: it's a raw pointer and I'll need unsafe later
type Item = *const u64;
fn next(&mut self) -> Option<Self::Item> {
let mut success = self.init;
{
// per jethrogb's suggestion
let mut tuples = self.x.iter_mut().zip(self.a.iter());
let (x0, _) = tuples.next().expect("at least one dimension");
for (xi, ai) in tuples {
if *x0 >= *ai {
*x0 -= *ai;
*xi += 1;
success = true;
break;
}
*x0 += *xi * *ai;
*xi = 0;
}
}
if success {
self.init = false;
Some(self.x.as_ptr())
}
else {
None
}
}
}
fn main() {
for c_ptr in NonnegativeLattice::new(vec![1,2,3,4], 30) {
// here I pay the raw pointer price, though I think it's worth it
let c = unsafe { std::slice::from_raw_parts(c_ptr, 4) };
println!("{:?}", c);
}
}
Do you really need to be using the Iterator trait, though? Why not just have a custom method fn next(&mut self) -> Option<&[u64]>, no need for raw pointers or clone, and use while let instead of for to iterate?
Admittedly, using Iterator lets you use convenience methods such as map and filter, but, well, it also lets you use methods like collect that are totally incorrect for iterators whose items become invalid after further iteration. The trait's not designed for that.
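A minimal sketch of that suggestion, reusing the wrap-around initialization and zip-based loop from the post above. It keeps the poster's assumption that `a[0] == 1`; the `count_points` helper is hypothetical and only shows the `while let` usage with manual counting.

```rust
struct NonnegativeLattice {
    a: Vec<u64>,
    x: Vec<u64>,
    init: bool,
}

impl NonnegativeLattice {
    /// Starts at the last point of the enumeration so the first call to
    /// `next` wraps around to the first point. Assumes a[0] == 1.
    fn new(a: Vec<u64>, mut k: u64) -> NonnegativeLattice {
        assert_eq!(a[0], 1);
        let mut x = vec![0; a.len()];
        for (n, c) in x.iter_mut().zip(a.iter()).rev() {
            *n = k / c;
            k -= *n * c;
        }
        NonnegativeLattice { a, x, init: true }
    }

    /// Inherent method, not `Iterator::next`: the returned slice borrows
    /// from `self`, so no clone and no raw pointer are needed.
    fn next(&mut self) -> Option<&[u64]> {
        let mut success = self.init;
        {
            let mut tuples = self.x.iter_mut().zip(self.a.iter());
            let (x0, _) = tuples.next().expect("at least one dimension");
            for (xi, ai) in tuples {
                if *x0 >= *ai {
                    // Move mass from the "slush fund" x[0] into x[i].
                    *x0 -= *ai;
                    *xi += 1;
                    success = true;
                    break;
                }
                // Cannot increment x[i]: return its mass to x[0] and carry on.
                *x0 += *xi * *ai;
                *xi = 0;
            }
        }
        if success {
            self.init = false;
            Some(&self.x[..])
        } else {
            None
        }
    }
}

/// Hypothetical helper: counts the lattice points with a `while let` loop.
fn count_points(a: Vec<u64>, k: u64) -> usize {
    let mut lattice = NonnegativeLattice::new(a, k);
    let mut count = 0;
    while let Some(_point) = lattice.next() {
        count += 1;
    }
    count
}
```

The borrow checker now enforces the "doesn't live past each iteration" promise: the slice from one call has to be dropped before `next` can be called again.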
So I used exactly that approach in my C code. I also need to keep track of the integer count, so enumerate() gets me that for free along with the beautiful ergonomics of the Iterator, while the while loop requires me to keep track of everything -- I recall having a bug or two due to this manual book-keeping. I'm on the fence... what I'm doing is definitely "wrong." However, it is very readable and easy to reason about if I promise to only call next() on it -- and the NonnegativeLattice is a private structure used by other public-facing things I will write, so I can make that promise. So, I can use/abuse Rust's iterator ergonomics to make the code more apparently bug-free in one sense, but simultaneously add a surface for other bugs to manifest due to the memory dereferencing, or I can do all of the book-keeping myself and leave myself open to bugs from that. I think you're probably right at the end of the day... I'll code up your suggestion and see how it feels.
Then again, given that the iterator is over *const u64, one needs to use unsafe somewhere. Given that unsafe must be used, and/or collecting into a vector of raw pointers is kind of fishy to begin with, is it really a violation of the spirit of the iterator trait? (I don't know, I am just a newb.)
Using a raw pointer essentially means giving up on Rust's type system, and that's fine; sometimes Rust's type system is not flexible enough.
However, in your case, you are trying to return a borrowed reference. Unfortunately, Iterator doesn't support that use case, but there is an easy way out: don't use the Iterator trait. Sure, it's convenient, but your code doesn't work with it. Instead, provide a next method on your struct. Then you can use it like this.
while let Some(c) = lattice.next() {
println!("{:?}", c);
}
Sure, it's not as pretty as for-in, but it works. If you want to zip that with some other iterator, you can do something like this.
let mut indexes = 0..;
while let Some((c, i)) = lattice.next().iter().zip(&mut indexes).next() {
println!("{} {:?}", i, c);
}
You may want to provide a convenience function for this, because it's not exactly pretty (and Clippy complains about that, suggesting a fix that would break this code, which is a known bug in Clippy).
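One possible shape for such a convenience function, as a rough sketch: a method that takes a closure and does the index book-keeping in one place. `Lending` is a made-up stand-in for NonnegativeLattice (any type whose inherent `next` lends out a slice would do), and `for_each_indexed` is a hypothetical name.

```rust
/// Stand-in for `NonnegativeLattice`: lends out a view of one reused buffer.
struct Lending {
    buf: Vec<u64>,
    next_val: u64,
    limit: u64,
}

impl Lending {
    fn new(limit: u64) -> Lending {
        Lending { buf: vec![0], next_val: 0, limit }
    }

    /// Yields [0], [10], [20], ... for `limit` steps, reusing `buf`.
    fn next(&mut self) -> Option<&[u64]> {
        if self.next_val == self.limit {
            return None;
        }
        self.buf[0] = self.next_val * 10;
        self.next_val += 1;
        Some(&self.buf[..])
    }

    /// Hypothetical convenience wrapper: calls `f(index, item)` for every
    /// item, so callers never maintain the counter themselves.
    fn for_each_indexed<F: FnMut(usize, &[u64])>(&mut self, mut f: F) {
        let mut i = 0;
        while let Some(c) = self.next() {
            f(i, c);
            i += 1;
        }
    }
}
```

A call then looks like `lattice.for_each_indexed(|i, c| println!("{} {:?}", i, c));`, which covers what enumerate() would have provided without exposing the counter to the caller.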
I think you convinced me -- I wanted to use the nice for loop ergonomics, but seeing the while/let and how nice that is too, it seems like I can have my cake and eat it too.