Speed up solution to 3 Sum problem

Looking for suggestions to speed up my solution to the 3Sum problem on LeetCode. The goal is to find all unique combinations of 3 values from a vec of ints that add up to 0. My solution runs in about 340ms on the LeetCode site. My assumption is that collecting into a HashSet is faster than deduping a Vec, but I'm also sure there must be a way to short-circuit the introduction of duplicates in the first place.

    use std::collections::HashSet;

    pub fn three_sum(nums: Vec<i32>) -> Vec<Vec<i32>> {
        if nums.len() < 3 { return vec![] }
        let mut nums = nums;

        nums.sort();

        nums.iter()
            .take_while(|target| **target <= 0)
            .enumerate()
            .filter(|(i, target)| *i == 0 || nums[i - 1] != **target)
            .flat_map(|(i, target)| Solution::two_sum(nums[i + 1..].to_vec(), *target))
            .collect::<Vec<Vec<i32>>>()
    }
    
    pub fn two_sum(nums: Vec<i32>, target: i32) -> HashSet<Vec<i32>> {
        let mut memo: HashSet<i32> = HashSet::new();
        let mut res: HashSet<Vec<i32>> = HashSet::new();

        for n in nums.iter() {
            match memo.get(&(-(target + n))) {
                None => { memo.insert(*n); },
                Some(&j) => { res.insert(vec![target, *n, j]); },
            }
        }

        res
    }

On first glance...

Use tuples instead of vecs for storing the groups of the elements, to avoid an extra heap allocation.

You may see a performance improvement by using the HashMap::entry API instead of get followed by insert.
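In case it helps, the entry pattern looks roughly like this (a sketch with a made-up `count_values` function, not your actual memo logic):

```rust
use std::collections::HashMap;

// Count occurrences with a single hash probe per key, instead of
// a `get` followed by an `insert` (two probes).
fn count_values(nums: &[i32]) -> HashMap<i32, usize> {
    let mut counts = HashMap::new();
    for &n in nums {
        // `entry` finds or creates the slot in one lookup.
        *counts.entry(n).or_insert(0) += 1;
    }
    counts
}
```

`entry` does the find-or-create in a single lookup, where `get` followed by `insert` hashes the key twice.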

If you preprocess the input vector, numbers, to only hold at most 3 of any given value, then sort the vector, I think you can just do:

let mut solutions = vec![];

for i in 0..numbers.len() {
    for j in (i + 1)..numbers.len() {
        for k in (j + 1)..numbers.len() {
            let (i, j, k) = (numbers[i], numbers[j], numbers[k]);
            if (i + j + k) == 0 {
                solutions.push((i, j, k));
            }
        }
    }
}

By preprocessing, you can make it so the invariants are guaranteed to be upheld, then just do trivially simple and fast operations from that point on.
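That preprocessing step might look something like this (just a sketch; `limit_to_three` is a made-up name, and it assumes the input is already sorted):

```rust
// Keep at most 3 copies of each value in a sorted slice. Three is
// enough because a triple can use the same value at most 3 times
// (e.g. [0, 0, 0]).
fn limit_to_three(sorted: &[i32]) -> Vec<i32> {
    let mut out: Vec<i32> = Vec::with_capacity(sorted.len());
    for &n in sorted {
        let len = out.len();
        // Skip if the last three pushed values are already `n`.
        if len >= 3 && out[len - 1] == n && out[len - 2] == n && out[len - 3] == n {
            continue;
        }
        out.push(n);
    }
    out
}
```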

Something like that. This is just a sketch of the idea.

--- ah maybe it doesn't work. But I think there's a kernel of truth in it. I'll come back to it.

I don't think this will be faster, because it's three nested loops over nearly the whole sequence -- O(n^3) I think.

Good idea re: entry(). Not sure about tuples, as the problem specifies the return type as Vec<Vec<i32>>.

To get every triple, you have to get every triple.

But yeah I will come back with the corrected version of what I was thinking. I don't think you need any allocations.

But you can skip some triples. For instance if nums[i] == nums[j] then you can skip the value of i' == j. (Sorry if I'm butchering the notation, but hopefully it's clear what I mean.)

So not just the triples are unique, but also the elements within any given triple?

I.e. you can't have (1, 1, 2) or such.

No the triple [1,1,-2] would be valid, but the pool of potential candidates may be:
nums = [1,1,1,-2]. So (nums[0], nums[1], nums[3]) is valid but equivalent to (nums[1], nums[2], nums[3]). Since nums[0] == nums[1] if a solution is found for nums[0], you can skip nums[1] as the initial value of the triple as any solution will be a duplicate.


Thanks for clarifying. I still think there's a way around that.

You're right. You could do something with a simple guard clause:

for i in 0..nums.len()-2 {
    // skip duplicate starting values
    if i > 0 && nums[i] == nums[i-1] { continue }

    for j in i+1..nums.len()-1 {
        for k in j+1..nums.len() {
            // valid_triple is a placeholder for the check-and-collect step
            valid_triple(nums[i], nums[j], nums[k])
        }
    }
}

but I'm stubbornly trying to find an iterator based solution.

40-50% improvement by using with_capacity to avoid reallocations for my HashSets. Because HashSet does not have the entry method that HashMap does, I have not used that approach yet.

    pub fn two_sum(nums: Vec<i32>, target: i32) -> HashSet<Vec<i32>> {
        let mut memo: HashSet<i32> = HashSet::with_capacity(nums.len());
        let mut res: HashSet<Vec<i32>> = HashSet::with_capacity(nums.len() / 3);

        for n in nums.iter() {
            match memo.get(&(-(target + n))) {
                None => { memo.insert(*n); },
                Some(&j) => { res.insert(vec![target, *n, j]); },
            }
        }

        res
    }

Runtime is down to 184ms.
Unfortunately, initializing the final result with Vec::with_capacity(nums.len()/3) and then appending the two_sum results to it does not yield much further improvement.


Managed to get significant improvement abandoning the HashSet and using a sort of sliding window approach:

use std::cmp::Ordering;

pub fn three_sum(nums: Vec<i32>) -> Vec<Vec<i32>> {
    if nums.len() < 3 { return vec![] }
    let mut nums = nums;

    nums.sort();

    let mut res: Vec<Vec<i32>> = Vec::with_capacity(nums.len() / 3);

    nums.iter()
        .take_while(|target| **target <= 0)
        .enumerate()
        .filter(|(i, target)| *i == 0 || nums[i - 1] != **target)
        .for_each(|(i, target)| {
            let mut left_index = i + 1;
            let mut right_index = nums.len() - 1;
            while left_index < right_index {
                let left = nums[left_index];
                let right = nums[right_index];
                match (target + left + right).cmp(&0) {
                    Ordering::Less => left_index += 1,
                    Ordering::Greater => right_index -= 1,
                    Ordering::Equal => {
                        res.push(vec![*target, left, right]);
                        left_index += 1;
                        right_index -= 1;
                    },
                }
            }
        });

    res.dedup();
    res
}

Down to 24ms but need to figure out how to avoid that dedup step.

Where can I find the test cases and benchmarks? I tested this with a vector of my own, but it could be subtly wrong.

I like imperative code.

I used some special cases which duplicated some code, and I'm sure all this could fit into some higher ideal of code beauty or whatever, but I was aiming for speed. Should be O(n*log2(n) + (n^3)/6): sorting, plus checking every index triple i < j < k, and n-choose-3 is roughly n^3/6 (the volume of the triangularly-based pyramid when you take 3 points on a cube).

    // assumes `numbers` is a mutable Vec<i32> with at least 3 elements
    numbers.sort();

    let mut solutions = vec![];

    let mut i = 0;

    while i < (numbers.len() - 2) {
        let (x, y, z) = (numbers[i], numbers[i + 1], numbers[i + 2]);

        if (x == y) && (y == z) {
            if (x + y + z) == 0 {
                solutions.push((x, y, z));
            }

            // bounds check needed, or this can index past the end of the vec
            while (i + 2) < numbers.len() && numbers[i + 2] == x {
                i += 1;
            }
        } else if (x == y) {
            let mut k = i + 2;

            while k < numbers.len() {
                let z = numbers[k];

                if (k + 1) < numbers.len() && numbers[k + 1] == z {
                    // skip duplicates of z
                } else if (x + y + z) == 0 {
                    solutions.push((x, y, z));
                }

                k += 1;
            }
            i += 1;
        } else {
            let mut j = i + 1;

            while j < (numbers.len() - 1) {
                let y = numbers[j];

                if numbers[j + 1] != y {
                    let mut k = j + 1;

                    while k < numbers.len() {
                        let z = numbers[k];
                        let h = k + 1;

                        if h < numbers.len() && numbers[h] == z {
                            // skip duplicates of z
                        } else if (x + y + z) == 0 {
                            solutions.push((x, y, z));
                        }

                        k += 1;
                    }
                }
                j += 1;
            }

            i += 1;
        }
    }

The basic idea is to recognize runs of 3 of the same value, handle them specially, then skip ahead till you only have 2 of that same value in a row. Then you handle cases of 2 in a row specially, by iterating k but not iterating j. After that, you handle the other cases.

Or even more tl;dr, skip the duplicates ahead of time.
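The skip-ahead could be pulled into a tiny helper, something like this (a sketch; `next_distinct` is an invented name):

```rust
// Given a sorted slice and an index `i`, return the index of the
// first element past the run of duplicates starting at `i`.
fn next_distinct(sorted: &[i32], i: usize) -> usize {
    let mut j = i + 1;
    while j < sorted.len() && sorted[j] == sorted[i] {
        j += 1;
    }
    j
}
```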

You'll be able to run your code against their 300+ tests but you cannot see the tests. The benchmarks are a comparison to all valid submissions to the site for that problem and language.

Requires an account? Lame. Meh, I'm satisfied. Thanks for the challenge and I hope I contributed at least an idea.

Is your solution already O(n^2 * hashlookup) time? It looks like to solve a+b+c = 0, you're

(1) computing a+b for all pairs (a,b) // this is O(n^2)
(2) doing a hash lookup to see if -(a+b) is in the vec // cost here is hashlookup

If so -- are you looking for algorithmic improvements or constant factor shaving?

I think you are correct about my original solution, but the latest solution I don't think is quite O(n^2) because I'm skipping possible duplicate pairs.

Isn't res automatically ordered on the tuples it contains? So you'd only have to compare with the last element before pushing to keep it free of duplicates? You could just filter left the way you filter target, then you can be sure you don't have any duplicates.

Oh, and to shave off a tiny bit, you could use sort_unstable.

You're right about eliminating the dedup step by filtering out left duplicates early. Surprisingly not much if any speed improvement.
Added this to the left_index increment on finding a solution:

while left_index < right_index && nums[left_index] == left {
    left_index += 1;
}

Using sort_unstable didn't really improve time either.

The only way I can see this being valid is if you can use the exact same integer (at the same index) more than once in a triple. Otherwise, the best algorithm I can envision is worst-case (n*(n/4)*log2n) and best case n, depending on the density of unique values in the vector as well as the proportion of negatives to positives.

I assumed it was the latter case from the wording of the problem, that you can only use the same integer twice in a triple if it appears twice in the input vector, or three times if it appears three times.

To guarantee that you don't reuse the same integer, I would take a hybrid approach of your original algorithm and the skipping logic I proposed above. Walk both ends of the vector inward, and perform a binary search between the left and right indexes while searching. For both the left and right indexes, skip duplicates while iterating.

The algorithm with the above high-level description would also not need a deduplication step.
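Roughly what I mean, as a sketch (`three_sum_hybrid` is a made-up name; it assumes a sorted input and returns tuples rather than the Vec<Vec<i32>> the problem wants):

```rust
// Fix the outer pair by walking both ends inward, skipping duplicate
// values at each end, then binary-search the region strictly between
// them for the third value. Assumes `nums` is sorted ascending.
fn three_sum_hybrid(nums: &[i32]) -> Vec<(i32, i32, i32)> {
    let mut res = Vec::new();
    let mut i = 0;
    while i + 2 < nums.len() {
        let mut k = nums.len() - 1;
        while k > i + 1 {
            let need = -(nums[i] + nums[k]);
            // Searching only between i and k guarantees the same
            // index is never reused within a triple.
            if nums[i + 1..k].binary_search(&need).is_ok() {
                res.push((nums[i], need, nums[k]));
            }
            // Skip duplicates of nums[k] moving inward.
            let v = nums[k];
            while k > i + 1 && nums[k] == v {
                k -= 1;
            }
        }
        // Skip duplicates of nums[i] moving forward.
        let v = nums[i];
        while i < nums.len() && nums[i] == v {
            i += 1;
        }
    }
    res
}
```

Since both ends only ever visit distinct values, the output needs no dedup step.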


If my understanding of the problem is correct, how are you preventing using the same integer at the same index more than once?