Question related to recursion and lifetimes

Hi, all. I want to ask two questions related to the code below:

  1. Is it possible to merge dfs1() and dfs2() into a single generic dfs()?
  2. If so, how can I modify my code to resolve the compilation errors below?

Thanks!

use std::cell::RefCell;
use std::iter::Rev;
use std::rc::Rc;
use std::slice::Iter;

struct DfsContext<'a> {
    depth: i32,
    // sum of distances between all nodes visited and current node.
    prev_sum: i32,
    // counter of visited nodes.
    visited_nodes: i32,
    parent: u16,
    res: &'a mut Vec<i32>,
    adj_list: &'a Vec<Vec<u16>>,
}

trait DfsStage<'a> {
    type NeighIter: Iterator<Item = &'a u16>;

    fn modify_iter(iter: Iter<'a, u16>) -> Self::NeighIter;
    fn res_update_before_recur(ctx: &mut DfsContext, neigh: u16) {}
    fn res_update_after_recur(ctx: &mut DfsContext, neigh: u16, child_sum: i32) {}
}

struct DfsStage1 {}

impl<'a> DfsStage<'a> for DfsStage1 {
    type NeighIter = Iter<'a, u16>;

    fn modify_iter(iter: Iter<'a, u16>) -> Self::NeighIter {
        iter
    }

    fn res_update_after_recur(ctx: &mut DfsContext, neigh: u16, child_sum: i32) {
        ctx.res[neigh as usize] = ctx.prev_sum + child_sum;
    }
}

struct DfsStage2 {}

impl<'a> DfsStage<'a> for DfsStage2 {
    type NeighIter = Rev<Iter<'a, u16>>;

    fn modify_iter(iter: Iter<'a, u16>) -> Self::NeighIter {
        iter.rev()
    }

    fn res_update_before_recur(ctx: &mut DfsContext, neigh: u16) {
        let delta = ctx.prev_sum - ctx.depth * (ctx.depth + 1) / 2;
        ctx.res[neigh as usize] += delta;
    }
}

impl Solution {
    // Return sum of distances between all children of node and node.
    fn dfs<'a, S: DfsStage<'a>>(ctx: &'a mut DfsContext, node: u16) -> i32 {
        let my_parent = ctx.parent;
        ctx.parent = node;

        let prev_sum_save = ctx.prev_sum;
        // Related to node's children.
        ctx.visited_nodes += 1;
        ctx.prev_sum += ctx.visited_nodes;

        // Related to node.
        let mut total_child_sum = 0;

        let adj_list_iter = ctx.adj_list[node as usize].iter();
        let mut adj_list_iter = S::modify_iter(adj_list_iter);

        ctx.depth += 1;
        for neigh in adj_list_iter {
            if *neigh == my_parent {
                continue;
            }

            S::res_update_before_recur(ctx, *neigh);

            let visited_nodes_before = ctx.visited_nodes;

            // Related to each child of node.
            let child_sum = Self::dfs::<'a, S>(ctx, *neigh);

            let child_count = ctx.visited_nodes - visited_nodes_before;
            total_child_sum += child_sum + child_count - 1 + 1;

            S::res_update_after_recur(ctx, *neigh, child_sum);

            // add previous siblings' distances.
            ctx.prev_sum += child_sum + (child_count - 1) * 2 + 2;
        }
        ctx.depth -= 1;

        ctx.prev_sum = prev_sum_save;
        ctx.parent = my_parent;

        total_child_sum
    }

    // Return sum of distances between all children of node and node.
    fn dfs1(ctx: &mut DfsContext, node: u16) -> i32 {
        let my_parent = ctx.parent;
        ctx.parent = node;

        let prev_sum_save = ctx.prev_sum;
        // Related to node's children.
        ctx.visited_nodes += 1;
        ctx.prev_sum += ctx.visited_nodes;

        // Related to node.
        let mut total_child_sum = 0;

        ctx.depth += 1;
        ctx.adj_list[node as usize].iter().for_each(|neigh| {
            if *neigh == my_parent {
                return;
            }

            let visited_nodes_before = ctx.visited_nodes;

            // Related to each child of node.
            let child_sum = Self::dfs1(ctx, *neigh);

            let child_count = ctx.visited_nodes - visited_nodes_before;
            total_child_sum += child_sum + child_count - 1 + 1;

            ctx.res[*neigh as usize] = ctx.prev_sum + child_sum;

            // add previous siblings' distances.
            ctx.prev_sum += child_sum + (child_count - 1) * 2 + 2;
        });
        ctx.depth -= 1;

        ctx.prev_sum = prev_sum_save;
        ctx.parent = my_parent;

        total_child_sum
    }

    fn dfs2(ctx: &mut DfsContext, node: u16) -> i32 {
        let my_parent = ctx.parent;
        ctx.parent = node;

        let prev_sum_save = ctx.prev_sum;
        // Related to node's children.
        ctx.visited_nodes += 1;
        ctx.prev_sum += ctx.visited_nodes;

        // Related to node.
        let mut total_child_sum = 0;

        ctx.depth += 1;
        ctx.adj_list[node as usize].iter().rev().for_each(|neigh| {
            if *neigh == my_parent {
                return;
            }

            // Add prev_sum without all ancestor.
            let delta = ctx.prev_sum - ctx.depth * (ctx.depth + 1) / 2;
            ctx.res[*neigh as usize] += delta;

            let visited_nodes_before = ctx.visited_nodes;

            // Related to each child of node.
            let child_sum = Self::dfs2(ctx, *neigh);

            let child_count = ctx.visited_nodes - visited_nodes_before;
            total_child_sum += child_sum + child_count - 1 + 1;

            // add previous siblings' distances.
            ctx.prev_sum += child_sum + (child_count - 1) * 2 + 2;
        });
        ctx.depth -= 1;

        ctx.parent = my_parent;
        ctx.prev_sum = prev_sum_save;

        total_child_sum
    }

    pub fn sum_of_distances_in_tree(n: i32, edges: Vec<Vec<i32>>) -> Vec<i32> {
        let mut adj_list: Vec<Vec<u16>> = vec![Vec::new(); n as usize];
        edges.iter().for_each(|edge| {
            adj_list[edge[0] as usize].push(edge[1] as u16);
            adj_list[edge[1] as usize].push(edge[0] as u16);
        });

        let mut res = vec![0; n as usize];
        let mut ctx = DfsContext {
            depth: 0,
            prev_sum: 0,
            visited_nodes: 0,
            parent: u16::MAX,
            adj_list: &adj_list,
            res: &mut res,
        };

        let root_sum = Self::dfs::<DfsStage1>(&mut ctx, 0);
        ctx.res[0] = root_sum;

        ctx.prev_sum = 0;
        ctx.visited_nodes = 0;
        Self::dfs::<DfsStage2>(&mut ctx, 0);

        res
    }
}

#[cfg(test)]
mod test {
    use super::*;

    fn do_case(n: i32, edges: &[[i32; 2]], expect: &[i32]) {
        let edges: Vec<Vec<i32>> = edges.iter().map(|edge| vec![edge[0], edge[1]]).collect();
        let res = Solution::sum_of_distances_in_tree(n, edges);
        assert_eq!(res, expect);
    }

    #[test]
    fn case1() {
        do_case(
            6,
            &[[0, 1], [0, 2], [2, 3], [2, 4], [2, 5]],
            &[8, 12, 6, 10, 10, 10],
        );
    }

    #[test]
    fn case2() {
        /*
        do_case(
            7,
            &[[0, 1], [0, 2], [2, 3], [2, 4], [2, 5], [4, 6]],
            &[8, 12, 6, 10, 10, 10],
        );
        */
    }

    #[test]
    fn case3() {
        do_case(1, &[], &[0]);
    }

    #[test]
    fn case4() {
        do_case(2, &[[1, 0]], &[1, 1]);
    }

    #[test]
    fn case5() {}
}

#[allow(unused)]
struct Solution {}

I get the following compilation errors:

PS C:\Users\user\rs\rust-test> cargo test
   Compiling rust-test v0.1.0 (C:\Users\user\rs\rust-test)
warning: cannot specify lifetime arguments explicitly if late bound lifetime parameters are present
  --> src\main.rs:82:41
   |
56 |     fn dfs<'a, S: DfsStage<'a>>(ctx: &'a mut DfsContext, node: u16) -> i32 {
   |                                              ---------- the late bound lifetime parameter is introduced here
...
82 |             let child_sum = Self::dfs::<'a, S>(ctx, *neigh);
   |                                         ^^
   |
   = warning: this was previously accepted by the compiler but is being phased out; it will become a hard error in a future release!
   = note: for more information, see issue #42868 <https://github.com/rust-lang/rust/issues/42868>
   = note: `#[warn(late_bound_lifetime_arguments)]` on by default

warning: unused variable: `ctx`
  --> src\main.rs:21:32
   |
21 |     fn res_update_before_recur(ctx: &mut DfsContext, neigh: u16) {}
   |                                ^^^ help: if this is intentional, prefix it with an underscore: `_ctx`
   |
   = note: `#[warn(unused_variables)]` on by default

warning: unused variable: `neigh`
  --> src\main.rs:21:54
   |
21 |     fn res_update_before_recur(ctx: &mut DfsContext, neigh: u16) {}
   |                                                      ^^^^^ help: if this is intentional, prefix it with an underscore: `_neigh`

warning: unused variable: `ctx`
  --> src\main.rs:22:31
   |
22 |     fn res_update_after_recur(ctx: &mut DfsContext, neigh: u16, child_sum: i32) {}
   |                               ^^^ help: if this is intentional, prefix it with an underscore: `_ctx`

warning: unused variable: `neigh`
  --> src\main.rs:22:53
   |
22 |     fn res_update_after_recur(ctx: &mut DfsContext, neigh: u16, child_sum: i32) {}
   |                                                     ^^^^^ help: if this is intentional, prefix it with an underscore: `_neigh`

warning: unused variable: `child_sum`
  --> src\main.rs:22:65
   |
22 |     fn res_update_after_recur(ctx: &mut DfsContext, neigh: u16, child_sum: i32) {}
   |                                                                 ^^^^^^^^^ help: if this is intentional, prefix it with an underscore: `_child_sum`

warning: variable does not need to be mutable
  --> src\main.rs:69:13
   |
69 |         let mut adj_list_iter = S::modify_iter(adj_list_iter);
   |             ----^^^^^^^^^^^^^
   |             |
   |             help: remove this `mut`
   |
   = note: `#[warn(unused_mut)]` on by default

error[E0503]: cannot use `ctx.visited_nodes` because it was mutably borrowed
  --> src\main.rs:84:31
   |
56 |     fn dfs<'a, S: DfsStage<'a>>(ctx: &'a mut DfsContext, node: u16) -> i32 {
   |            -- lifetime `'a` defined here
...
82 |             let child_sum = Self::dfs::<'a, S>(ctx, *neigh);
   |                             -------------------------------
   |                             |                  |
   |                             |                  `*ctx` is borrowed here
   |                             argument requires that `*ctx` is borrowed for `'a`
83 |
84 |             let child_count = ctx.visited_nodes - visited_nodes_before;
   |                               ^^^^^^^^^^^^^^^^^ use of borrowed `*ctx`

error[E0499]: cannot borrow `*ctx` as mutable more than once at a time
  --> src\main.rs:87:39
   |
56 |     fn dfs<'a, S: DfsStage<'a>>(ctx: &'a mut DfsContext, node: u16) -> i32 {
   |            -- lifetime `'a` defined here
...
82 |             let child_sum = Self::dfs::<'a, S>(ctx, *neigh);
   |                             -------------------------------
   |                             |                  |
   |                             |                  first mutable borrow occurs here
   |                             argument requires that `*ctx` is borrowed for `'a`
...
87 |             S::res_update_after_recur(ctx, *neigh, child_sum);
   |                                       ^^^ second mutable borrow occurs here

error[E0503]: cannot use `ctx.prev_sum` because it was mutably borrowed
  --> src\main.rs:90:13
   |
56 |     fn dfs<'a, S: DfsStage<'a>>(ctx: &'a mut DfsContext, node: u16) -> i32 {
   |            -- lifetime `'a` defined here
...
82 |             let child_sum = Self::dfs::<'a, S>(ctx, *neigh);
   |                             -------------------------------
   |                             |                  |
   |                             |                  `*ctx` is borrowed here
   |                             argument requires that `*ctx` is borrowed for `'a`
...
90 |             ctx.prev_sum += child_sum + (child_count - 1) * 2 + 2;
   |             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ use of borrowed `*ctx`

Some errors have detailed explanations: E0499, E0503.
For more information about an error, try `rustc --explain E0499`.
warning: `rust-test` (bin "rust-test" test) generated 7 warnings
error: could not compile `rust-test` (bin "rust-test" test) due to 3 previous errors; 7 warnings emitted
PS C:\Users\user\rs\rust-test> 

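You need a &mut DfsContext<'a>, not a &'a mut DfsContext<'_>. With &'a mut, the recursive call Self::dfs::<'a, S>(ctx, *neigh) has to borrow *ctx for the whole of 'a, so ctx is still mutably borrowed after the call returns; that is what the E0499 and E0503 errors are reporting.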

-    fn dfs<'a, S: DfsStage<'a>>(ctx: &'a mut DfsContext, node: u16) -> i32 {
+    fn dfs<'a, S: DfsStage<'a>>(ctx: &mut DfsContext<'a>, node: u16) -> i32 {

And to appease a forward-compatibility warning:

             // Related to each child of node.
-            let child_sum = Self::dfs::<'a, S>(ctx, *neigh);
+            let child_sum = Self::dfs::<S>(ctx, *neigh);
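To show both changes together, here is a minimal, self-contained sketch of the corrected pattern. It assumes a trimmed-down context that only records the visit order instead of computing distance sums; the names Ctx, Stage, Forward, Backward and visit_order are hypothetical and not taken from the question.

```rust
use std::iter::Rev;
use std::slice::Iter;

// Hypothetical, trimmed-down stand-in for the question's DfsContext:
// it only records the order in which nodes are visited.
struct Ctx<'a> {
    adj_list: &'a [Vec<u16>],
    order: &'a mut Vec<u16>,
}

// Same shape as DfsStage: each stage decides how neighbours are iterated.
trait Stage<'a> {
    type NeighIter: Iterator<Item = &'a u16>;
    fn modify_iter(iter: Iter<'a, u16>) -> Self::NeighIter;
}

struct Forward;
impl<'a> Stage<'a> for Forward {
    type NeighIter = Iter<'a, u16>;
    fn modify_iter(iter: Iter<'a, u16>) -> Self::NeighIter {
        iter
    }
}

struct Backward;
impl<'a> Stage<'a> for Backward {
    type NeighIter = Rev<Iter<'a, u16>>;
    fn modify_iter(iter: Iter<'a, u16>) -> Self::NeighIter {
        iter.rev()
    }
}

// The outer &mut gets its own anonymous lifetime, so each recursive call
// reborrows *ctx only for the duration of that call.
fn dfs<'a, S: Stage<'a>>(ctx: &mut Ctx<'a>, node: u16, parent: u16) {
    ctx.order.push(node);
    // ctx.adj_list is a shared &'a reference, so this iterator borrows the
    // adjacency list for 'a rather than borrowing ctx itself.
    for &neigh in S::modify_iter(ctx.adj_list[node as usize].iter()) {
        if neigh != parent {
            // Only the type argument is written: the elided &mut lifetime is
            // late-bound, so lifetime arguments cannot be given explicitly.
            dfs::<S>(ctx, neigh, node);
        }
    }
}

// Runs one stage from root 0 and returns the visit order.
fn visit_order<S: for<'a> Stage<'a>>(adj_list: &[Vec<u16>]) -> Vec<u16> {
    let mut order = Vec::new();
    let mut ctx = Ctx { adj_list, order: &mut order };
    dfs::<S>(&mut ctx, 0, u16::MAX);
    order
}
```

For the tree 0-1, 0-2, 2-3, visit_order::<Forward> yields [0, 1, 2, 3] and visit_order::<Backward> yields [0, 2, 3, 1]. The higher-ranked bound for<'a> Stage<'a> is only needed by this sketch's generic helper; the question's sum_of_distances_in_tree names each stage directly.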

Does that answer your questions?


Using <'a> for everything is unclear, and can easily backfire when you make it refer to multiple incompatible things. You can use more descriptive names to describe which loaned data the lifetimes refer to, like <'res, 'adj>.
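To make the <'res, 'adj> suggestion concrete, here is a small sketch with hypothetical names (a two-field DfsContext plus neighbors, record and root_neighbors, none of which come from the question). Giving each loan its own lifetime lets a method return data tied only to the adjacency list, so that data can outlive the mutable result buffer.

```rust
// Hypothetical context with one lifetime per loan: 'res for the output
// buffer, 'adj for the read-only adjacency list.
struct DfsContext<'res, 'adj> {
    res: &'res mut Vec<i32>,
    adj_list: &'adj [Vec<u16>],
}

impl<'res, 'adj> DfsContext<'res, 'adj> {
    // The returned slice is tied to 'adj, not to self or to 'res.
    fn neighbors(&self, node: u16) -> &'adj [u16] {
        &self.adj_list[node as usize]
    }

    fn record(&mut self, node: u16, value: i32) {
        self.res[node as usize] = value;
    }
}

// Builds a short-lived context over a local result buffer, then returns
// data that borrows only from the caller's adjacency list.
fn root_neighbors(adj_list: &[Vec<u16>]) -> &[u16] {
    let mut res = vec![0; adj_list.len()];
    let mut ctx = DfsContext { res: &mut res, adj_list };
    let first = ctx.neighbors(0);
    // Mutating ctx is fine here: `first` does not borrow ctx.
    ctx.record(0, first.len() as i32);
    first
}
```

With a single 'a shared by both fields, root_neighbors would be rejected, because the returned slice would then also be tied to the local res buffer.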

Add this to the top of your lib.rs and main.rs:

#![warn(elided_lifetimes_in_paths)]

and fix the warnings it reports. For backwards compatibility, Rust allows an implicit syntax (such as writing DfsContext instead of DfsContext<'_>) that hides which structs are temporary views into data they don't own.
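As a rough illustration of what the lint asks for, here is a sketch with made-up names (lint_demo, degree). The attribute is applied to a module only so the snippet is self-contained; the advice above is to use the crate-level #![warn(elided_lifetimes_in_paths)] form.

```rust
#[warn(elided_lifetimes_in_paths)]
mod lint_demo {
    pub struct DfsContext<'a> {
        pub adj_list: &'a [Vec<u16>],
    }

    // Writing `ctx: &DfsContext` here would trigger the lint, because the
    // signature would no longer show that DfsContext borrows data.
    // `DfsContext<'_>` keeps the lifetime elided but visible.
    pub fn degree(ctx: &DfsContext<'_>, node: u16) -> usize {
        ctx.adj_list[node as usize].len()
    }
}
```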

Also, never reuse lifetime labels of &mut loans. These loans are exclusive, and &mut T is invariant over T (including any lifetimes inside T), which makes them maximally strict and inflexible about their scope. When you slap the same 'a on multiple things, you're saying they all have to be compatible: they must have been borrowed from the same place at the same time, and they have to stay exclusively locked together for exactly as long. This is usually incorrect and far too strict, and it can paralyze the whole code into uselessness. On the other hand, reusing the same 'a on shared (&'a) loans is generally okay, because shared references are covariant and the compiler is allowed to unify and shorten such lifetimes where necessary.
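A minimal sketch of the difference, assuming a hypothetical Counter struct that holds an exclusive loan. The problematic &'a mut Counter<'a> signature is left as a comment because the second call in bump_twice would not compile with it.

```rust
// Hypothetical struct holding an exclusive loan.
struct Counter<'a> {
    total: &'a mut i32,
}

// Good: the outer &mut gets its own short, anonymous lifetime.
fn bump(c: &mut Counter<'_>) {
    *c.total += 1;
}

// Anti-pattern:
//     fn bump_locked<'a>(c: &'a mut Counter<'a>) { *c.total += 1; }
// Because &mut T is invariant in T, the inner 'a cannot shrink, so a call
// borrows c mutably for the entire lifetime of the loan stored inside c;
// any later call or use of c is then rejected.

fn bump_twice(total: &mut i32) {
    let mut c = Counter { total };
    bump(&mut c);
    bump(&mut c); // each call reborrows c only for that call
}
```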


Thanks for your advice. The Rust book should mention these caveats. 😀
