Knuth-Morris-Pratt algorithm with C++ and Rust [Solved]

#1

I’m trying to solve a string-searching problem using Rust on this site, so I’ve tried to “translate” my previously accepted C++ code to Rust. The C++ answer passes all of the test cases.

But with Rust, I can’t pass the largest corner-case test (finding a 10,000-character pattern in a 1,000,000-character string). I keep getting Time Limit Exceeded, even when I try someone else’s algorithm.

So am I doing something wrong in the Rust code? Or is this just not a good problem for Rust?

This is my code in C++:

#include <iostream>
#include <string>
#include <vector>
//#include <chrono>

void computeLPSArray(std::string * pattern, int M, int * lps) {
  int len = 0;
  lps[0] = 0;
  int i = 1;
  while (i < M) {
    if ((*pattern)[i] == (*pattern)[len]) {
      len++;
      lps[i] = len;
      i++;
    } else {
      if (len != 0) {
        len = lps[len - 1];
      } else {
        lps[i] = 0;
        i++;
      }
    }
  }
}

void KMPSearch(std::string *pattern, std::string *string) {
  int M = pattern->length();
  int N = string->length();
  // std::vector instead of `int lps[M]`: variable-length arrays are not standard C++.
  std::vector<int> lps(M);
  computeLPSArray(pattern, M, lps.data());
  int i = 0;
  int j = 0;
  while (i < N) {
    if ((*pattern)[j] == (*string)[i]) {
      j++;
      i++;
    }
    if (j == M) {
      std::cout << i - j << std::endl;
      j = lps[j - 1];
    }else if (i < N && (*pattern)[j] != (*string)[i]) {
      if (j != 0){
        j = lps[j - 1];
      }else{
        i = i + 1;
      }
    }
  }
}

int main() {
  std::string pattern, string;
  getline(std::cin, string);
  getline(std::cin, pattern);
  //std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now();
  KMPSearch(&pattern, &string);
  //std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
  //std::cout << std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count() << "micros" <<std::endl;
  return 0;
}

My Rust code:

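fn read_line() -> String {
    let mut return_ = String::new();
    std::io::stdin().read_line(&mut return_).ok();
    // Strip only a trailing "\n" or "\r\n"; a bare pop() would drop a real
    // character if the last line has no newline.
    let len = return_.trim_end_matches(&['\r', '\n'][..]).len();
    return_.truncate(len);
    return_
}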

fn main() {
    let string: Vec<char> = read_line().chars().collect();
    let pattern: Vec<char> = read_line().chars().collect();
    //let start = std::time::Instant::now();
    kmp(&string, &pattern);
    //println!("{:?}", start.elapsed());
}

fn lps(pattern: &Vec<char>, m: usize, mut kmp: Vec<usize>) -> Vec<usize> {
    let mut len: usize = 0;
    kmp[0] = 0;
    let mut i: usize = 1;
    while i < m {
        if pattern[i] == pattern[len] {
            len = len + 1;
            kmp[i] = len;
            i = i + 1;
        } else {
            if len != 0 {
                len = kmp[len - 1];
            } else {
                kmp[i] = 0;
                i = i + 1;
            }
        }
    }
    kmp
}

fn kmp(string: &Vec<char>, pattern: &Vec<char>) -> () {
    let n = string.len();
    let m = pattern.len();
    let kmp = lps(&pattern, m, vec![0usize; m]);
    let mut i: usize = 0;
    let mut j: usize = 0;
    while i < n {
        if pattern[j] == string[i] {
            i = i + 1;
            j = j + 1;
        }
        if j == m {
            println!("{:?}", i - j);
            j = kmp[j - 1];
        } else if i < n && pattern[j] != string[i] {
            if j != 0 {
                j = kmp[j - 1];
            } else {
                i = i + 1;
            }
        }
    }
}

Solved code:

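fn read_line() -> String {
    let mut return_ = String::new();
    std::io::stdin().read_line(&mut return_).ok();
    // Strip only a trailing "\n" or "\r\n"; a bare pop() would drop a real
    // character if the last line has no newline.
    let len = return_.trim_end_matches(&['\r', '\n'][..]).len();
    return_.truncate(len);
    return_
}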

fn main() {
    let string: Vec<u8> = read_line().as_bytes().to_vec();
    let pattern: Vec<u8> = read_line().as_bytes().to_vec();
    //let start = std::time::Instant::now();
    kmp(&string, &pattern);
    //println!("{:?}", start.elapsed());
}

fn kmp_table(pattern: &Vec<u8>, m: usize, mut kmp: Vec<usize>) -> Vec<usize> {
    let mut len: usize = 0;
    kmp[0] = 0;
    let mut i: usize = 1;
    while i < m {
        if pattern[i] == pattern[len] {
            len = len + 1;
            kmp[i] = len;
            i = i + 1;
        } else {
            if len != 0 {
                len = kmp[len - 1];
            } else {
                kmp[i] = 0;
                i = i + 1;
            }
        }
    }
    kmp
}

fn kmp(string: &Vec<u8>, pattern: &Vec<u8>) -> () {
    let mut out: String = String::new();
    let n = string.len();
    let m = pattern.len();
    let kmp = kmp_table(&pattern, m, vec![0usize; m]);
    let mut i: usize = 0;
    let mut j: usize = 0;
    while i < n {
        if pattern[j] == string[i] {
            i = i + 1;
            j = j + 1;
        }
        if j == m {
            out.push_str(&(i - j).to_string());
            out.push('\n');
            j = kmp[j - 1];
        } else if i < n && pattern[j] != string[i] {
            if j != 0 {
                j = kmp[j - 1];
            } else {
                i = i + 1;
            }
        }
    }
    print!("{}", out);
}

#2

Rust supports Unicode, so its char is 4 bytes large. That is, a Vec<char> is four times larger than a std::string holding the same ASCII text.

It’s better to work with the underlying bytes directly: try using [u8] instead of Vec<char>. To get bytes from a String, use the as_bytes method:

https://doc.rust-lang.org/beta/std/string/struct.String.html#method.as_bytes
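A minimal sketch of the size difference described above. The helper `payload_sizes` is hypothetical, not from the thread, and it counts only the element payload, not the `Vec` header or spare capacity. One caveat: a search over `as_bytes()` reports byte offsets, which equal character indices only for ASCII input.

```rust
use std::mem::size_of;

/// Returns (payload bytes of the text as Vec<char>, bytes of its UTF-8 encoding).
fn payload_sizes(line: &str) -> (usize, usize) {
    // `chars()` decodes UTF-8; every `char` is stored as a 4-byte Unicode scalar value.
    let chars: Vec<char> = line.chars().collect();
    // `as_bytes()` borrows the existing UTF-8 buffer without copying anything.
    let bytes: &[u8] = line.as_bytes();
    (chars.len() * size_of::<char>(), bytes.len())
}

fn main() {
    assert_eq!(size_of::<char>(), 4);
    let (char_bytes, utf8_bytes) = payload_sizes("abracadabra");
    // prints "Vec<char>: 44 bytes, as_bytes(): 11 bytes"
    println!("Vec<char>: {} bytes, as_bytes(): {} bytes", char_bytes, utf8_bytes);
}
```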


#3

Still getting Time Limit Exceeded…

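fn read_line() -> String {
    let mut return_ = String::new();
    std::io::stdin().read_line(&mut return_).ok();
    // Strip only a trailing "\n" or "\r\n"; a bare pop() would drop a real
    // character if the last line has no newline.
    let len = return_.trim_end_matches(&['\r', '\n'][..]).len();
    return_.truncate(len);
    return_
}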

fn main() {
    //let string: &[u8] = read_line().as_bytes();
    //let pattern: &[u8] = read_line().as_bytes();
    //let start = std::time::Instant::now();
    //kmp(&string, &pattern);
    kmp(read_line().as_bytes(), read_line().as_bytes());
    //println!("{:?}", start.elapsed());
}

fn kmp_table(pattern: &[u8], m: usize, mut kmp: Vec<usize>) -> Vec<usize> {
    let mut len: usize = 0;
    kmp[0] = 0;
    let mut i: usize = 1;
    while i < m {
        if pattern[i] == pattern[len] {
            len = len + 1;
            kmp[i] = len;
            i = i + 1;
        } else {
            if len != 0 {
                len = kmp[len - 1];
            } else {
                kmp[i] = 0;
                i = i + 1;
            }
        }
    }
    kmp
}

fn kmp(string: &[u8], pattern: &[u8]) -> () {
    let n = string.len();
    let m = pattern.len();
    let kmp = kmp_table(&pattern, m, vec![0usize; m]);
    let mut i: usize = 0;
    let mut j: usize = 0;
    while i < n {
        if pattern[j] == string[i] {
            i = i + 1;
            j = j + 1;
        }
        if j == m {
            println!("{:?}", i - j);
            j = kmp[j - 1];
        } else if i < n && pattern[j] != string[i] {
            if j != 0 {
                j = kmp[j - 1];
            } else {
                i = i + 1;
            }
        }
    }
}

#4

Another thing which looks suspicious is

println!("{:?}", i - j);

This locks stdout and, because stdout is line-buffered, flushes it on every newline, which is slow: each flush does a write syscall, and syscalls are slow because you need to go from user space to kernel space. I think returning a Vec<usize> from kmp, formatting it into one big String in main, and doing a single println should be much faster for cases where every position is a match. Note that the C++ version does std::cout << i - j << std::endl; which suffers from a similar problem, since std::endl also flushes the stream.
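A minimal sketch of that suggestion, assuming the input is one text line followed by one pattern line, as in the thread’s code. `kmp` returns the match offsets as a `Vec<usize>`, and `main` formats all of them into one `String` before a single `print!`. The function names follow the thread, but the bodies are a rewrite: the failure table uses the common `for`-loop form instead of the `while` loop above (it computes the same table), and empty patterns simply yield no matches.

```rust
use std::fmt::Write;
use std::io::{self, BufRead};

/// KMP failure table: `lps[i]` is the length of the longest proper prefix of
/// `pattern[..=i]` that is also a suffix of it.
fn kmp_table(pattern: &[u8]) -> Vec<usize> {
    let mut lps = vec![0usize; pattern.len()];
    let mut len = 0;
    for i in 1..pattern.len() {
        // Fall back to shorter borders until pattern[i] extends one (or none is left).
        while len > 0 && pattern[i] != pattern[len] {
            len = lps[len - 1];
        }
        if pattern[i] == pattern[len] {
            len += 1;
        }
        lps[i] = len;
    }
    lps
}

/// Returns the start offset of every (possibly overlapping) match of `pattern` in `text`.
fn kmp(text: &[u8], pattern: &[u8]) -> Vec<usize> {
    let mut matches = Vec::new();
    if pattern.is_empty() {
        return matches; // avoids indexing an empty table
    }
    let lps = kmp_table(pattern);
    let mut j = 0; // number of pattern bytes currently matched
    for (i, &byte) in text.iter().enumerate() {
        while j > 0 && byte != pattern[j] {
            j = lps[j - 1];
        }
        if byte == pattern[j] {
            j += 1;
        }
        if j == pattern.len() {
            matches.push(i + 1 - j); // the match ends at i, so it starts at i + 1 - len
            j = lps[j - 1];
        }
    }
    matches
}

fn main() {
    let stdin = io::stdin();
    let mut lines = stdin.lock().lines();
    // `lines()` already strips the trailing "\n" or "\r\n".
    let text = lines.next().and_then(Result::ok).unwrap_or_default();
    let pattern = lines.next().and_then(Result::ok).unwrap_or_default();

    // Format every offset into one String, then write it with a single print!.
    let mut out = String::new();
    for p in kmp(text.as_bytes(), pattern.as_bytes()) {
        writeln!(out, "{}", p).unwrap(); // writing into a String cannot fail
    }
    print!("{}", out);
}
```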


#5

Rust’s string formatting is also fairly slow (I don’t actually know why). So the combination of flushing stdout (which the C++ solution also suffers from) and formatting on every match (which the C++ solution does not suffer from) is probably the problem.
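For the flushing part, one common alternative to building the output String by hand is to wrap a locked stdout in a `std::io::BufWriter`. A minimal sketch: `write_positions` is a hypothetical helper and the hard-coded offsets stand in for what `kmp` would find. It still formats every number, so it only removes the per-line flush.

```rust
use std::io::{self, BufWriter, Write};

/// Writes one position per line. BufWriter collects the output in memory and
/// passes it to `out` in large chunks rather than once per line.
fn write_positions<W: Write>(out: W, positions: &[usize]) -> io::Result<()> {
    let mut out = BufWriter::new(out);
    for p in positions {
        writeln!(out, "{}", p)?;
    }
    // Flush explicitly so a write error is reported instead of ignored on drop.
    out.flush()
}

fn main() -> io::Result<()> {
    // Hypothetical match offsets, standing in for the ones `kmp` would return.
    let positions = [0usize, 2, 4];
    let stdout = io::stdout();
    write_positions(stdout.lock(), &positions)
}
```

Taking any `W: Write` also makes the helper easy to test against an in-memory `Vec<u8>`.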


#6

Thank you so much for the information! It’s solved.


#7

What did you do to solve it?

Another thing I realized while looking at this is that Rust is doing bounds checking, possibly on every array index operation. These checks can be eliminated as an optimization if LLVM can prove that they’re unnecessary, but if that’s not happening, they could cause quite a bit of slowdown. This is avoidable using unsafe unchecked indexing (get_unchecked).
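To make that suggestion concrete, here is a minimal sketch that compares ordinary indexing with `slice::get_unchecked`. It uses a hypothetical helper that counts equal bytes at matching positions, not code from the thread. The `unsafe` version is sound only because the loop bound is the shorter of the two lengths. In a KMP loop, the indices come from the failure table, which makes that kind of argument harder. Whether LLVM already removes the checks from the safe version is not guaranteed either way and would need measuring.

```rust
/// Counts positions where `a` and `b` hold the same byte (up to the shorter length).
/// Each `a[i]` / `b[i]` is bounds-checked unless the optimizer proves `i` is in range.
fn count_equal_checked(a: &[u8], b: &[u8]) -> usize {
    let n = a.len().min(b.len());
    let mut count = 0;
    for i in 0..n {
        if a[i] == b[i] {
            count += 1;
        }
    }
    count
}

/// The same computation with unchecked indexing.
fn count_equal_unchecked(a: &[u8], b: &[u8]) -> usize {
    let n = a.len().min(b.len());
    let mut count = 0;
    for i in 0..n {
        // SAFETY: i < n, n <= a.len() and n <= b.len(), so both reads are in bounds.
        let equal = unsafe { *a.get_unchecked(i) == *b.get_unchecked(i) };
        if equal {
            count += 1;
        }
    }
    count
}

fn main() {
    let (a, b) = (b"kitten", b"sitting");
    assert_eq!(count_equal_checked(a, b), count_equal_unchecked(a, b));
    println!("{}", count_equal_checked(a, b)); // prints 4
}
```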


#8

I just removed the println!, created an out variable as a temporary string buffer, appended everything I wanted to print to it, and printed that single string at the end using print!. So yes, as @matklad said, calling println! many times was slowing down my code.

See the Solved code in the first post.


#9

Heh, I think “bounds checks can be eliminated with unsafe” is a pretty dangerous piece of advice to give, at least without dumping a ton of context and nuance along with it.

Accessing a slice out of bounds is undefined behavior (UB) and a security vulnerability, and humans are pretty bad at making sure indexing operations are correct. So get_unchecked has a very high cost.

It’s important to understand that the cost of the bounds check itself is negligible. It’s a trivially predicted branch with a trivial condition. That is, the CPU assumes (via branch prediction and speculative execution) that the index is in bounds and only has to do extra work if it isn’t.
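Conceptually, a checked slice index is just a comparison followed by either a read or a panic. A minimal sketch of that shape, using a hypothetical helper that is not the standard library’s actual implementation:

```rust
/// Roughly what `v[i]` does for a slice: one comparison, then a read or a panic.
fn index_like_checked(v: &[u8], i: usize) -> u8 {
    if i < v.len() {
        // The common case: the CPU predicts this branch as taken and just reads.
        // SAFETY: the comparison above guarantees i is in bounds.
        unsafe { *v.get_unchecked(i) }
    } else {
        panic!("index out of bounds: the len is {} but the index is {}", v.len(), i);
    }
}

fn main() {
    let data = b"abc";
    println!("{}", index_like_checked(data, 1) as char); // prints b
}
```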

What can be costly is the missed opportunity for optimization. For example, bounds checking can prevent autovectorization from occurring. In these cases, though, get_unchecked is usually the wrong hammer to swing! If you rely on SIMD for speed, it’s better to make this explicit and use one of the SIMD abstraction libraries. If you feel uncomfortable using them, autovectorization can sometimes be made more reliable by using iterators instead of manual indexing.
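As an illustration of the iterator style, here is a minimal sketch of a hypothetical helper that counts positions where two byte slices agree, written with `zip` instead of manual indexing. `zip` stops at the shorter slice, so there is no index to bounds-check. Whether the loop actually gets autovectorized depends on the compiler version and target, so treat that as something to verify (for example, by inspecting the generated assembly), not a guarantee.

```rust
/// Counts positions where `a` and `b` hold the same byte, up to the shorter length.
/// `zip` ends at the shorter slice, so no index is ever out of range.
fn count_equal(a: &[u8], b: &[u8]) -> usize {
    a.iter().zip(b).filter(|(x, y)| x == y).count()
}

fn main() {
    println!("{}", count_equal(b"kitten", b"sitting")); // prints 4
}
```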


#10

Okay, fair enough; sorry! I think it’s somewhat reasonable in code like this where the checked version is already known to work.

I understand what you’re saying, but I’d like to see some numbers. I thought I remembered Joe Duffy writing that un-elided bounds checks were a fairly major source of performance difficulties in the Midori project, but unfortunately I can’t find the post.


#11

but I’d like to see some numbers.

Good call, here are some benchmarks: https://github.com/matklad/bounds-check-cost


#12

Probably http://joeduffyblog.com/2015/12/19/safe-native-code/.

The truth is that both Joe and @matklad are right. Besides potentially killing loop optimizations, C# has the added issue of needing a memory load to get an array’s length; if that load misses in cache, you stall badly. Rust is better in this regard because slices carry their length in the fat pointer, which will likely be kept in a register or, at worst, spilled to a stack slot.

But I agree with @matklad: the biggest issue will be missed optimizations in some cases, not too dissimilar to inlining failures. With those, it’s rarely the actual call (frame setup/teardown) that’s costly, but rather the further optimizations that are missed.
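A quick way to see the fat pointer mentioned above: a minimal sketch showing that a `&[u8]` occupies two machine words (data pointer plus length), while a reference to a fixed-size array needs only one. `reference_sizes_in_words` is a hypothetical helper for illustration.

```rust
use std::mem::size_of;

/// Returns (size of a slice reference, size of an array reference) in machine words.
fn reference_sizes_in_words() -> (usize, usize) {
    let word = size_of::<usize>();
    // `&[u8]` stores the length next to the data pointer (a "fat" pointer).
    let slice_ref = size_of::<&[u8]>() / word;
    // `&[u8; 16]` has its length in the type, so only the pointer is stored.
    let array_ref = size_of::<&[u8; 16]>() / word;
    (slice_ref, array_ref)
}

fn main() {
    let text = b"abracadabra";
    let window: &[u8] = &text[2..7];
    // len() reads the length carried by the slice reference itself.
    println!("window len = {}", window.len()); // prints window len = 5
    println!("{:?}", reference_sizes_in_words()); // prints (2, 1)
}
```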


#13

I made a branch using the ‘solved’ code above: https://github.com/BatmanAoD/bounds-check-cost/tree/knuth_morris_pratt

Based on the results on my machine, it looks like the cost of bounds checking is indeed not significant in this case.
