Trouble with lifetimes and a generic function: how can I fix this?

Hi all!

I'd like to define two traits in my codebase (one being a supertrait of the other) and write a generic function bounded by the second one.

Both traits are meant to be implemented by dataset structs. The first one simply defines how to get an iterator over the elements of the dataset. The second one requires the first and defines how to get batches of elements. Because of how they are meant to be used, a lot of references, and therefore lifetimes, are involved. Unfortunately, I can't get those lifetimes right and I can't find a proper way of doing this...
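Stripped of the traits, the batching idea itself is just `take` on a mutable borrow of the zipped iterator: each call yields at most `n` (label, image) pairs and leaves the rest for the next call. A minimal sketch of only that part, assuming the same `ZipChunks` alias as in the code below (the `next_batch` and `batch_labels` helpers are hypothetical, written just for illustration):

```rust
use std::iter::{Cloned, Take, Zip};
use std::slice::{Chunks, Iter};

// Same alias as in the post: labels paired with fixed-size chunks of data.
type ZipChunks<'a, L, D> = Zip<Cloned<Iter<'a, L>>, Chunks<'a, D>>;

// Hypothetical helper: yields at most `n` items, then stops, leaving the
// rest of the underlying iterator untouched for the next call.
fn next_batch<'a, 'b>(
    it: &'b mut ZipChunks<'a, u8, f32>,
    n: usize,
) -> Take<&'b mut ZipChunks<'a, u8, f32>> {
    // `&mut I` is itself an iterator, so `take` consumes only the borrow.
    it.take(n)
}

// Hypothetical driver: splits the whole dataset into batches of labels.
fn batch_labels(labels: &[u8], data: &[f32], chunk: usize, batch: usize) -> Vec<Vec<u8>> {
    let mut it: ZipChunks<'_, u8, f32> = labels.iter().cloned().zip(data.chunks(chunk));
    let mut out = Vec::new();
    loop {
        // The mutable borrow of `it` ends with this statement.
        let b: Vec<u8> = next_batch(&mut it, batch)
            .map(|(label, _pixels)| label)
            .collect();
        if b.is_empty() {
            break;
        }
        out.push(b);
    }
    out
}

fn main() {
    // Three 2-pixel "images" concatenated, one label per image.
    let labels = [1u8, 2, 3];
    let data = [0.1f32, 0.2, 1.1, 1.2, 2.1, 2.2];
    println!("{:?}", batch_labels(&labels, &data, 2, 2)); // [[1, 2], [3]]
}
```

Here the borrow lifetime `'b` is a parameter of a plain function, so the compiler can pick a fresh, short one at every call.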

Here's the code in question:

```rust
use std::{
    iter::{Cloned, Take, Zip},
    slice::{Chunks, Iter},
};

const BATCH_SIZE: usize = 2;
const CHUNK_SIZE: usize = 2;

// Example struct containing a dataset. Note that the data of a single element has
// more than one f32; the `data` vector contains pixels of multiple images
// concatenated while a single `u8` represents the label of a single element.
struct Dataset {
    label: Vec<u8>,
    data: Vec<f32>,
}

// Type alias to simplify the implementation. This is the iterator returned by
// the first trait; a cloned copy of an element's label zipped with a slice of data.
// This slice of data contains the pixels of an image in a large vector containing
// many images concatenated one after the other. This iterator allows us to
// iterate over pairs of label-image in the dataset.
type ZipChunks<'a, L, D> = Zip<Cloned<Iter<'a, L>>, Chunks<'a, D>>;

// Trait to return an iterator over the (label, image) pairs of the dataset
pub trait LabeledDataset<'a> {
    type Label;
    type Data;

    fn iter_train(&'a self) -> ZipChunks<'a, Self::Label, Self::Data>;
}

// Trait to return batches of elements of the dataset
pub trait BatchedLabeledDataset<'a, 'b>: LabeledDataset<'a> {
    fn batched_iter_train(
        &'a self,
        it: &'b mut ZipChunks<'a, Self::Label, Self::Data>,
    ) -> Take<&'b mut ZipChunks<'a, Self::Label, Self::Data>>;
}

impl<'a> LabeledDataset<'a> for Dataset {
    type Label = u8;
    type Data = f32;

    fn iter_train(&'a self) -> ZipChunks<'a, Self::Label, Self::Data> {
        self.label.iter().cloned().zip(self.data.chunks(CHUNK_SIZE))
    }
}

impl<'a, 'b> BatchedLabeledDataset<'a, 'b> for Dataset
where
    'a: 'b,
{
    fn batched_iter_train(
        &'a self,
        it: &'b mut ZipChunks<'a, Self::Label, Self::Data>,
    ) -> Take<&'b mut ZipChunks<'a, Self::Label, Self::Data>> {
        it.by_ref().take(BATCH_SIZE)
    }
}

// Function taking a reference to the dataset; works fine
fn regular_function(ds: &Dataset) {
    for _ in 0..2 {
        let mut ds_iter = ds.iter_train();
        for _ in 0..2 {
            let mut _ds_batched_iter = ds.batched_iter_train(&mut ds_iter);
        }
    }
}

// Generic function with bound on the `BatchedLabeledDataset` trait: ERROR!!!
fn generic_function<'t, 'u, DS>(ds: &'t DS)
where
    DS: BatchedLabeledDataset<'t, 'u>,
    't: 'u, // 't outlives 'u
{
    for _ in 0..2 {
        let mut ds_iter = ds.iter_train();
        for _ in 0..2 {
            let mut _ds_batched_iter = ds.batched_iter_train(&mut ds_iter);
        }
    }
}

fn main() {
    let ds = Dataset {
        label: vec![1, 2, 3],
        data: vec![1.0, 2.0, 3.0],
    };

    // Inside main
    for _ in 0..2 {
        let mut ds_iter = ds.iter_train();
        for _ in 0..2 {
            let mut _ds_batched_iter = ds.batched_iter_train(&mut ds_iter);
        }
    }

    // Inside regular function
    regular_function(&ds);

    // Inside generic function
    generic_function(&ds);
}
```

I get this error:

```text
error[E0597]: `ds_iter` does not live long enough
  --> src/main.rs:81:67
   |
81 |             let mut _ds_batched_iter = ds.batched_iter_train(&mut ds_iter);
   |                                                                   ^^^^^^^ borrowed value does not live long enough
82 |         }
83 |     }
   |     - borrowed value only lives until here
   |
note: borrowed value must be valid for the lifetime 'u as defined on the function body at 73:1...
  --> src/main.rs:73:1
   |
73 | / fn generic_function<'t, 'u, DS>(ds: &'t DS)
74 | | where
75 | |     DS: BatchedLabeledDataset<'t, 'u>,
76 | |     't: 'u, // 't outlives 'u
...  |
83 | |     }
84 | | }
   | |_^

error: aborting due to previous error

For more information about this error, try `rustc --explain E0597`.
```


How can I fix this? I'm out of ideas... :cry:

Thanks!!!

So somebody helped me and suggested this:

It works!
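For anyone landing here with the same error, here is a sketch of one change that compiles (it may differ from the suggestion above). With `BatchedLabeledDataset<'t, 'u>`, `'u` is a lifetime parameter of `generic_function`, so the caller picks it and it must outlive the whole function body; a borrow of the local `ds_iter` can never live that long. In `main` and `regular_function` the concrete impl covers every `'b`, so the compiler chooses a short one per call. Moving `'b` from the trait onto the method gives generic code the same freedom. The return value of `generic_function`, its extra `Clone` bound, the loop that stops on an empty batch, and the six-value `data` vector are my own additions, only there so the batches can be iterated and checked.

```rust
use std::{
    iter::{Cloned, Take, Zip},
    slice::{Chunks, Iter},
};

const BATCH_SIZE: usize = 2;
const CHUNK_SIZE: usize = 2;

struct Dataset {
    label: Vec<u8>,
    data: Vec<f32>,
}

type ZipChunks<'a, L, D> = Zip<Cloned<Iter<'a, L>>, Chunks<'a, D>>;

pub trait LabeledDataset<'a> {
    type Label;
    type Data;

    fn iter_train(&'a self) -> ZipChunks<'a, Self::Label, Self::Data>;
}

// The key change: `'b` is a parameter of the method, not of the trait,
// so each call can use a borrow that lasts just for that call.
pub trait BatchedLabeledDataset<'a>: LabeledDataset<'a> {
    fn batched_iter_train<'b>(
        &'a self,
        it: &'b mut ZipChunks<'a, Self::Label, Self::Data>,
    ) -> Take<&'b mut ZipChunks<'a, Self::Label, Self::Data>>;
}

impl<'a> LabeledDataset<'a> for Dataset {
    type Label = u8;
    type Data = f32;

    fn iter_train(&'a self) -> ZipChunks<'a, Self::Label, Self::Data> {
        self.label.iter().cloned().zip(self.data.chunks(CHUNK_SIZE))
    }
}

// No `'a: 'b` where clause needed: it is implied by the argument type.
impl<'a> BatchedLabeledDataset<'a> for Dataset {
    fn batched_iter_train<'b>(
        &'a self,
        it: &'b mut ZipChunks<'a, Self::Label, Self::Data>,
    ) -> Take<&'b mut ZipChunks<'a, Self::Label, Self::Data>> {
        // `&mut I` is itself an iterator, so `take` works on the borrow directly.
        it.take(BATCH_SIZE)
    }
}

// Only one lifetime left on the function. The `Clone` bound is needed because
// the zipped iterator clones labels, and we actually iterate it here.
fn generic_function<'t, DS>(ds: &'t DS) -> Vec<Vec<DS::Label>>
where
    DS: BatchedLabeledDataset<'t>,
    DS::Label: Clone,
{
    let mut batches = Vec::new();
    let mut ds_iter = ds.iter_train();
    loop {
        // `&mut ds_iter` only has to live until the end of this statement.
        let batch: Vec<DS::Label> = ds
            .batched_iter_train(&mut ds_iter)
            .map(|(label, _pixels)| label)
            .collect();
        if batch.is_empty() {
            break;
        }
        batches.push(batch);
    }
    batches
}

fn main() {
    let ds = Dataset {
        label: vec![1, 2, 3],
        data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
    };
    println!("{:?}", generic_function(&ds)); // [[1, 2], [3]]
}
```

Another option sometimes used for this kind of error is a higher-ranked bound (`for<'u> ...`) on the generic function, but putting the short-lived lifetime on the method keeps the trait simpler.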
