Resolve mutable and immutable references

Hello guys. I'm still learning Rust and have some stupid questions.

I'm trying to write a simple topological sort the way I would in JS or Python, and, of course, I get compilation errors :(

Here's the code (Rust Playground):

use std::collections::HashSet;
use std::collections::HashMap;

struct Solution {}

impl Solution {
    pub fn find_order(_num_courses: i32, prerequisites: Vec<Vec<i32>>) -> Vec<i32> {
        let mut courses: HashMap<i32, HashSet<i32>> = HashMap::new();
        let mut result: Vec<i32> = Vec::new();
        
        prerequisites.iter().for_each(|x| {
            courses
                .entry(x[0])
                .or_insert(HashSet::new())
                .insert(x[1]);
        });
    
        while courses.len() > 0 {
            for (key, val) in courses.iter() {
                if val.len() == 0 {
                    courses
                        .iter_mut()
                        .for_each(
                            |(_, xval)| {
                                xval.remove(key);
                            }
                        );
                    courses.remove(key);
                    result.push(*key);
                    break;
                }
            }
        }
        
        return result;
    }
}

fn main() {
    let prerequisites: Vec<Vec<i32>> = vec![vec![1,0],vec![2,0],vec![3,1],vec![3,2]];
    let res = Solution::find_order(4, prerequisites);
    println!("{:?}", res);
}

I'm a bit stuck at this point and would be grateful for any suggestions.

Many thanks

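Yeah, that won't fly. You're iterating over courses looking for a value of length 0, and when you find one, you try to modify courses inside that same loop. courses.iter() keeps courses borrowed immutably for the whole loop (and key is a reference into it), so the compiler rejects the mutable borrows taken by iter_mut() and remove(). Modifying a collection while iterating over it raises questions of well-definedness in general, although I suspect your algorithm is correct.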
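To see the borrow rule in isolation, here is a minimal sketch on a plain HashMap<i32, usize> (the function name remove_zero_entries is hypothetical, not from the thread). The commented-out loop is the shape the compiler rejects; the working version splits the job into a read-only pass that copies keys out and a separate mutating pass.

```rust
use std::collections::HashMap;

/// Hypothetical helper: removes every entry whose value is 0, in two phases
/// so the shared borrow from iter() has ended before the map is mutated.
fn remove_zero_entries(map: &mut HashMap<i32, usize>) {
    // Rejected by the borrow checker (E0502): map.iter() keeps map borrowed
    // immutably for the whole loop, so map.remove cannot borrow it mutably.
    //
    // for (key, val) in map.iter() {
    //     if *val == 0 {
    //         map.remove(key);
    //     }
    // }

    // Phase 1: read-only scan. Copying the keys out (*key) means the Vec
    // holds plain i32 values, not references into map.
    let zero_keys: Vec<i32> = map
        .iter()
        .filter(|(_, val)| **val == 0)
        .map(|(key, _)| *key)
        .collect();

    // Phase 2: the shared borrow is over, so mutation is allowed.
    for key in zero_keys {
        map.remove(&key);
    }
}

fn main() {
    let mut map: HashMap<i32, usize> = HashMap::from([(1, 0), (2, 3), (3, 0)]);
    remove_zero_entries(&mut map);
    println!("{:?}", map); // only key 2 remains
}
```

For this particular filter, map.retain(|_, val| *val != 0) does the same in one call; the two-phase form is the more general pattern when the mutation depends on what the scan found.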

There's a way out, though, since you're breaking out of the loop anyway: push the key to result inside your if (*key is a plain i32 copy, so nothing borrowed from courses outlives the loop), then break, and run your mutable iteration outside the for loop.

Like this. I've put the commented condition in an if let, in proper Rust fashion; otherwise you'd need an auxiliary variable to track whether a push occurred. I hope the comments make it clear; otherwise, just ask 🙂
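The linked code isn't reproduced here, so below is a sketch of the same "find first, then mutate" idea; it uses Iterator::find plus if let instead of a for loop with break. One more thing is worth fixing along the way: in the original, course 0 only ever appears as a prerequisite, so it never becomes a key in courses, no entry ever has an empty set, and once the borrow error is gone the while loop would spin forever. The sketch therefore assumes every course in 0..num_courses gets an entry, and it returns an empty Vec when a cycle makes an ordering impossible; both are assumptions, not something the thread specifies.

```rust
use std::collections::{HashMap, HashSet};

// Sketch only. Assumptions: every course 0..num_courses gets an entry, and
// an empty Vec is returned when a cycle makes an ordering impossible.
fn find_order(num_courses: i32, prerequisites: Vec<Vec<i32>>) -> Vec<i32> {
    // course -> set of courses that must be taken before it
    let mut courses: HashMap<i32, HashSet<i32>> = HashMap::new();
    for course in 0..num_courses {
        courses.entry(course).or_default();
    }
    for pair in &prerequisites {
        courses.entry(pair[0]).or_default().insert(pair[1]);
    }

    let mut result = Vec::new();
    while !courses.is_empty() {
        // Immutable pass: find a course with no remaining prerequisites.
        // *key copies the i32 out, so ready does not borrow courses.
        let ready = courses
            .iter()
            .find(|(_, deps)| deps.is_empty())
            .map(|(key, _)| *key);

        if let Some(key) = ready {
            // Mutable pass: the immutable borrow has already ended here.
            courses.remove(&key);
            for deps in courses.values_mut() {
                deps.remove(&key);
            }
            result.push(key);
        } else {
            // Every remaining course still waits on another one: a cycle.
            return Vec::new();
        }
    }
    result
}

fn main() {
    let prerequisites = vec![vec![1, 0], vec![2, 0], vec![3, 1], vec![3, 2]];
    // Prints [0, 1, 2, 3] or [0, 2, 1, 3], depending on HashMap iteration order.
    println!("{:?}", find_order(4, prerequisites));
}
```

Each round still scans the whole map, which is fine for small inputs; the point here is only how the borrows are separated.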


Here's my implementation of Kahn's algorithm:

#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
struct Edge<V> {
    from: V,
    to: V,
}

#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct Graph<V> {
    edges: Vec<Edge<V>>,
}

#[derive(Clone, Debug)]
struct TopologicalSortError;

impl std::fmt::Display for TopologicalSortError {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        std::fmt::Display::fmt("graph contains at least one cycle", f)
    }
}

impl std::error::Error for TopologicalSortError {}

// helper function to swap-remove one element from vector
fn swap_remove<T, Q>(vec: &mut Vec<T>, element: &Q) -> Option<T>
where
    T: std::borrow::Borrow<Q>,
    Q: Eq,
{
    // find the index first, so the shared borrow of vec ends before swap_remove
    let index = vec.iter().position(|v| v.borrow() == element)?;
    Some(vec.swap_remove(index))
}

// Kahn's algorithm <https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm>
fn topological_sort<V>(graph: Graph<V>) -> Result<Vec<V>, TopologicalSortError>
where
    V: Clone + Eq + std::hash::Hash,
{
    let mut edges = graph.edges;
    let nodes: std::collections::HashSet<_> = edges
        .iter()
        .flat_map(|edge| [edge.from.clone(), edge.to.clone()])
        .collect();

    let mut result = vec![];
    // all nodes with no incoming edge
    let mut start_nodes: Vec<_> = nodes
        .iter()
        .filter(|&node| edges.iter().all(|edge| edge.to != *node))
        .cloned()
        .collect();

    while let Some(node) = start_nodes.pop() {
        result.push(node.clone());

        let outgoing_edges: Vec<_> = edges
            .iter()
            .filter(|&edge| edge.from == node)
            .cloned()
            .collect();
        for edge in outgoing_edges {
            swap_remove(&mut edges, &edge);

            let next_node = edge.to;
            if edges.iter().all(|edge| edge.to != next_node) {
                start_nodes.push(next_node);
            }
        }
    }

    if edges.is_empty() {
        Ok(result)
    } else {
        Err(TopologicalSortError)
    }
}


Changes I made:

  • remove the unused struct Solution;

  • define specific types to represent edges and graphs;

  • instead of using for_each on an iterator, use for loops, which is (hopefully) more readable;

  • use generics to accept different vertex types;

  • report error if cycles are present in the graph.

This may all be overkill for a small coding exercise, though.
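One cost of the version above: after removing each edge it rescans the whole edge list (edges.iter().all(...)), so the running time grows roughly quadratically with the number of edges. A common variant of Kahn's algorithm keeps an in-degree count per node instead, which makes it roughly linear in nodes plus edges (expected, since it relies on hashing). The sketch below is a hypothetical kahn function that takes plain (from, to) pairs rather than the Graph/Edge types above and returns None on a cycle; those choices are assumptions of the sketch.

```rust
use std::collections::{HashMap, VecDeque};
use std::hash::Hash;

// Kahn's algorithm with in-degree counting. Edges are (from, to) pairs,
// an assumption of this sketch. Returns None if the graph has a cycle.
fn kahn<V: Clone + Eq + Hash>(edges: &[(V, V)]) -> Option<Vec<V>> {
    let mut successors: HashMap<V, Vec<V>> = HashMap::new();
    let mut in_degree: HashMap<V, usize> = HashMap::new();
    for (from, to) in edges {
        successors.entry(from.clone()).or_default().push(to.clone());
        in_degree.entry(from.clone()).or_insert(0);
        *in_degree.entry(to.clone()).or_insert(0) += 1;
    }

    // Start with every node that has no incoming edge.
    let mut queue: VecDeque<V> = in_degree
        .iter()
        .filter(|(_, deg)| **deg == 0)
        .map(|(node, _)| node.clone())
        .collect();

    let mut result = Vec::with_capacity(in_degree.len());
    while let Some(node) = queue.pop_front() {
        if let Some(next_nodes) = successors.get(&node) {
            for next in next_nodes {
                // "Removing" an edge is just decrementing its target's count.
                let deg = in_degree.get_mut(next).expect("every target was counted");
                *deg -= 1;
                if *deg == 0 {
                    queue.push_back(next.clone());
                }
            }
        }
        result.push(node);
    }

    // Nodes that never reached in-degree 0 lie on or behind a cycle.
    if result.len() == in_degree.len() {
        Some(result)
    } else {
        None
    }
}

fn main() {
    let edges = [(0, 1), (0, 2), (1, 3), (2, 3)];
    println!("{:?}", kahn(&edges)); // prints Some([0, 1, 2, 3])
}
```

As with the Graph type above, a node that appears in no edge is not represented; a real API might take an explicit node list as well.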


Try out something like this.


Wow, guys, you're incredible. Thank you so much for the quick replies! Now I'll need some time to digest your solutions :)
