Text metrics in Rust


#1

Instead of using my blog, this time I’m posting an article here.

I’ve seen a nice article that shows imperative code in pure Haskell to compute text metrics:
https://markkarpov.com/post/migrating-text-metrics.html

So I’ve tried to translate a few of the Haskell functions to Rust. This is just demo code, so I have only tested it lightly.

The original C language version:

unsigned int tmetrics_hamming (unsigned int len, uint16_t *a, uint16_t *b)
{
  unsigned int acc = 0, i;
  for (i = 0; i < len; i++)
    {
      if (*(a + i) != *(b + i)) acc++;
    }
  return acc;
}

The Haskell language port version:

hamming :: Text -> Text -> Maybe Int
hamming a b =
  if T.length a == T.length b
    then Just (go 0 0 0)
    else Nothing
  where
    go !na !nb !r =
      let !(TU.Iter cha da) = TU.iter a na
          !(TU.Iter chb db) = TU.iter b nb
      in if | na  == len -> r
            | cha /= chb -> go (na + da) (nb + db) (r + 1)
            | otherwise  -> go (na + da) (nb + db) r
    len = TU.lengthWord16 a

And my straightforward Rust version:

fn hamming_distance_strs(s1: &str, s2: &str) -> Option<usize> {
    let mut s1c = s1.chars();
    let mut s2c = s2.chars();
    let mut count = 0;

    loop {
        match (s1c.next(), s2c.next()) {
            (Some(c1), Some(c2)) => { if c1 != c2 { count += 1; } },
            (Some(_), None) | (None, Some(_)) => { return None; },
            (None, None) => { return Some(count); }
        }
    }
}

Unlike the Haskell version, this scans the strings only once (the Haskell version’s T.length performs a first pass), and it works on UTF-8 instead of UTF-16. Like the Haskell code, it should handle variable-length UTF characters correctly.
Perhaps you can write a faster Rust version that uses peek(), but this one is more readable.
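
For example, a pair of strings containing multi-byte UTF-8 characters still gets compared per character (a quick check of my own, not part of the test suite below):

// "é" and "è" are 2-byte UTF-8 sequences, but each counts as one character.
assert_eq!(hamming_distance_strs("héllo", "hèllo"), Some(1));
assert_eq!(hamming_distance_strs("héllo", "hello"), Some(1));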

Later I also converted the Levenshtein distance function. The original C code:

unsigned int tmetrics_levenshtein (unsigned int la, uint16_t *a, unsigned int lb, uint16_t *b)
{
  if (la == 0) return lb;
  if (lb == 0) return la;

  unsigned int v_len = lb + 1, *v0, *v1, i, j;

  if (v_len > VLEN_MAX)
    {
      v0 = malloc(sizeof(unsigned int) * v_len);
      v1 = malloc(sizeof(unsigned int) * v_len);
    }
  else
    {
      v0 = alloca(sizeof(unsigned int) * v_len);
      v1 = alloca(sizeof(unsigned int) * v_len);
    }

  for (i = 0; i < v_len; i++)
    v0[i] = i;

  for (i = 0; i < la; i++)
    {
      v1[0] = i + 1;

      for (j = 0; j < lb; j++)
        {
          unsigned int cost = *(a + i) == *(b + j) ? 0 : 1;
          unsigned int x = *(v1 + j) + 1;
          unsigned int y = *(v0 + j + 1) + 1;
          unsigned int z = *(v0 + j) + cost;
          *(v1 + j + 1) = MIN(x, MIN(y, z));
        }

      unsigned int *ptr = v0;
      v0 = v1;
      v1 = ptr;
    }

  unsigned int result = *(v0 + lb);

  if (v_len > VLEN_MAX)
    {
      free(v0);
      free(v1);
    }

  return result;
}

The Haskell version:

levenshtein :: Text -> Text -> Int
levenshtein a b
  | T.null a = lenb
  | T.null b = lena
  | otherwise = runST $ do
      let v_len = lenb + 1
      v <- VUM.unsafeNew (v_len * 2)
      let gov !i =
            when (i < v_len) $ do
              VUM.unsafeWrite v i i
              gov (i + 1)
          goi !i !na !v0 !v1 = do
            let !(TU.Iter ai da) = TU.iter a na
                goj !j !nb =
                  when (j < lenb) $ do
                    let !(TU.Iter bj db) = TU.iter b nb
                        cost = if ai == bj then 0 else 1
                    x <- (+ 1) <$> VUM.unsafeRead v (v1 + j)
                    y <- (+ 1) <$> VUM.unsafeRead v (v0 + j + 1)
                    z <- (+ cost) <$> VUM.unsafeRead v (v0 + j)
                    VUM.unsafeWrite v (v1 + j + 1) (min x (min y z))
                    goj (j + 1) (nb + db)
            when (i < lena) $ do
              VUM.unsafeWrite v v1 (i + 1)
              goj 0 0
              goi (i + 1) (na + da) v1 v0
      gov 0
      goi 0 0 0 v_len
      VUM.unsafeRead v (lenb + if even lena then 0 else v_len)
  where
    lena = T.length a
    lenb = T.length b

And my Rust version, adapted from the C code:

use std::cmp::min;

fn levenshtein_distance_strs(t1: &str, t2: &str) -> usize {
    use std::mem::swap;

    const VLEN_MAX: usize = 15;

    fn inner(s1: &str, s2: &str, a: &mut [u32], b: &mut [u32]) -> usize {
        let (mut v0, mut v1) = (a, b);

        for i in 0 .. v0.len() {
            v0[i] = i as u32;
        }

        for (i, c1) in s1.chars().enumerate() {
            v1[0] = i as u32 + 1;

            for (j, c2) in s2.chars().enumerate() {
                let cost = if c1 == c2 { 0 } else { 1 };
                let x = v1[j] + 1;
                let y = v0[j + 1] + 1;
                let z = v0[j] + cost;
                v1[j + 1] = min(x, min(y, z));
            }

            swap(&mut v0, &mut v1);
        }

        v0[v0.len() - 1] as usize
    }

    let (mut s1, mut s2) = (t1, t2);

    if s1.is_empty() { return s2.chars().count(); }
    if s2.is_empty() { return s1.chars().count(); }

    let mut len1 = s1.chars().count();
    let mut len2 = s2.chars().count();

    if len2 > len1 {
        swap(&mut s1, &mut s2);
        swap(&mut len1, &mut len2);
    }

    let vlen = len2 + 1;
    if vlen > VLEN_MAX {
        inner(s1, s2, &mut vec![0; vlen], &mut vec![0; vlen])
    } else {
        inner(s1, s2, &mut [0; VLEN_MAX][.. vlen], &mut [0; VLEN_MAX][.. vlen])
    }
}

Like the C version, this Rust code uses stack-allocated arrays when the shorter input string is small. But Rust doesn’t have an equivalent of alloca() yet, so I’ve used fixed-size arrays and sliced them. This is probably acceptable.
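
An alternative to the fixed-array trick would be a crate like smallvec, which keeps small buffers inline on the stack and spills to the heap otherwise. A hypothetical sketch of what the end of the function could look like with it (not used in the code above, and untested):

// Hypothetical replacement for the if/else at the end of the function, using
// the external smallvec crate: buffers of up to VLEN_MAX elements stay on
// the stack, longer ones fall back to a heap allocation.
use smallvec::SmallVec;

let mut v0: SmallVec<[u32; VLEN_MAX]> = SmallVec::from_elem(0, vlen);
let mut v1: SmallVec<[u32; VLEN_MAX]> = SmallVec::from_elem(0, vlen);
inner(s1, s2, &mut v0, &mut v1)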

Using chars().count() on both strings is not a big problem here because the algorithm itself is O(n^2).

The swap of s1 and s2 ensures the arrays are sized by the shorter string, which matters when the two lengths differ a lot. Unlike the Haskell version, this Rust version contains no unsafe code. If performance is not sufficient you can replace some slice accesses with the unchecked ones (get_unchecked and get_unchecked_mut), but the assembly of this function (compiled with -C opt-level=3 -C target-cpu=native) already contains efficient vectorized parts like this, so perhaps there’s no need for unsafe code:

.LBB2_51:
    vpmovzxbd   -28(%rsi), %xmm7
    vpmovzxbd   -24(%rsi), %xmm0
    vpmovzxbd   -20(%rsi), %xmm1
    vpmovzxbd   -16(%rsi), %xmm2
    vpand   %xmm4, %xmm7, %xmm7
    vpand   %xmm4, %xmm0, %xmm0
    vpand   %xmm4, %xmm1, %xmm1
    vpand   %xmm4, %xmm2, %xmm2
    vpcmpeqd    %xmm5, %xmm7, %xmm7
    vpmovzxdq   %xmm7, %ymm7
    vpand   %ymm6, %ymm7, %ymm7
    vpcmpeqd    %xmm5, %xmm0, %xmm0
    vpmovzxdq   %xmm0, %ymm0
    vpand   %ymm6, %ymm0, %ymm0
    vpcmpeqd    %xmm5, %xmm1, %xmm1
    vpmovzxdq   %xmm1, %ymm1
    vpand   %ymm6, %ymm1, %ymm1
    vpcmpeqd    %xmm5, %xmm2, %xmm2
    vpmovzxdq   %xmm2, %ymm2
    vpand   %ymm6, %ymm2, %ymm2
    vpaddq  %ymm8, %ymm7, %ymm7
    vpaddq  %ymm9, %ymm0, %ymm9
    vpaddq  %ymm10, %ymm1, %ymm10
    vpaddq  %ymm3, %ymm2, %ymm11
    vpmovzxbd   -12(%rsi), %xmm3
    vpmovzxbd   -8(%rsi), %xmm0
    vpmovzxbd   -4(%rsi), %xmm1
    vpmovzxbd   (%rsi), %xmm2
    vpand   %xmm4, %xmm3, %xmm3
    vpand   %xmm4, %xmm0, %xmm0
    vpand   %xmm4, %xmm1, %xmm1
    vpand   %xmm4, %xmm2, %xmm2
    vpcmpeqd    %xmm5, %xmm3, %xmm3
    vpmovzxdq   %xmm3, %ymm3
    vpand   %ymm6, %ymm3, %ymm3
    vpcmpeqd    %xmm5, %xmm0, %xmm0
    vpmovzxdq   %xmm0, %ymm0
    vpand   %ymm6, %ymm0, %ymm0
    vpcmpeqd    %xmm5, %xmm1, %xmm1
    vpmovzxdq   %xmm1, %ymm1
    vpand   %ymm6, %ymm1, %ymm1
    vpcmpeqd    %xmm5, %xmm2, %xmm2
    vpmovzxdq   %xmm2, %ymm2
    vpand   %ymm6, %ymm2, %ymm2
    vpaddq  %ymm7, %ymm3, %ymm8
    vpaddq  %ymm9, %ymm0, %ymm9
    vpaddq  %ymm10, %ymm1, %ymm10
    vpaddq  %ymm11, %ymm2, %ymm3
    addq    $32, %rsi
    addq    $-32, %rdx
    jne .LBB2_51
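
For reference, the unchecked-indexing variant of the inner loop might look like this (just a sketch of the idea mentioned above; it’s worth trying only if profiling shows the bounds checks matter):

for (j, c2) in s2.chars().enumerate() {
    let cost = if c1 == c2 { 0 } else { 1 };
    // Same logic as the safe version, with the bounds checks removed.
    unsafe {
        let x = *v1.get_unchecked(j) + 1;
        let y = *v0.get_unchecked(j + 1) + 1;
        let z = *v0.get_unchecked(j) + cost;
        *v1.get_unchecked_mut(j + 1) = min(x, min(y, z));
    }
}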

The whole Haskell code I’ve used for the comparison:

{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MultiWayIf   #-}

import Control.Monad (when)
import Control.Monad.ST (runST)
import Data.Map.Strict (Map)
import Data.Text (Text)
import GHC.Exts (inline)
import qualified Data.Text                   as T
import qualified Data.Text.Unsafe            as TU
import qualified Data.Vector.Unboxed.Mutable as VUM
import Criterion.Main (Benchmark, defaultMain, nf, bench, bgroup, env)
import Control.DeepSeq (NFData)


-- | /O(n)/ Return Hamming distance between two 'Text' values. Hamming
-- distance is defined as the number of positions at which the corresponding
-- symbols are different. The input 'Text' values should be of equal length
-- or 'Nothing' will be returned.
--
-- See also: <https://en.wikipedia.org/wiki/Hamming_distance>.
--
-- __Heads up__, before version /0.3.0/ this function returned @'Maybe'
-- 'Data.Numeric.Natural'@.

hamming :: Text -> Text -> Maybe Int
hamming a b =
  if T.length a == T.length b
    then Just (go 0 0 0)
    else Nothing
  where
    go !na !nb !r =
      let !(TU.Iter cha da) = TU.iter a na
          !(TU.Iter chb db) = TU.iter b nb
      in if | na  == len -> r
            | cha /= chb -> go (na + da) (nb + db) (r + 1)
            | otherwise  -> go (na + da) (nb + db) r
    len = TU.lengthWord16 a


-- | Return Levenshtein distance between two 'Text' values. Classic
-- Levenshtein distance between two strings is the minimal number of
-- operations necessary to transform one string into another. For
-- Levenshtein distance allowed operations are: deletion, insertion, and
-- substitution.
--
-- See also: <https://en.wikipedia.org/wiki/Levenshtein_distance>.
--
-- __Heads up__, before version /0.3.0/ this function returned
-- 'Data.Numeric.Natural'.

levenshtein :: Text -> Text -> Int
levenshtein a b = fst (levenshtein_ a b)

-- | An internal helper, returns Levenshtein distance as the first element
-- of the tuple and max length of the two inputs as the second element of
-- the tuple.

levenshtein_ :: Text -> Text -> (Int, Int)
levenshtein_ a b
  | T.null a = (lenb, lenm)
  | T.null b = (lena, lenm)
  | otherwise = runST $ do
      let v_len = lenb + 1
      v <- VUM.unsafeNew (v_len * 2)
      let gov !i =
            when (i < v_len) $ do
              VUM.unsafeWrite v i i
              gov (i + 1)
          goi !i !na !v0 !v1 = do
            let !(TU.Iter ai da) = TU.iter a na
                goj !j !nb =
                  when (j < lenb) $ do
                    let !(TU.Iter bj db) = TU.iter b nb
                        cost = if ai == bj then 0 else 1
                    x <- (+ 1) <$> VUM.unsafeRead v (v1 + j)
                    y <- (+ 1) <$> VUM.unsafeRead v (v0 + j + 1)
                    z <- (+ cost) <$> VUM.unsafeRead v (v0 + j)
                    VUM.unsafeWrite v (v1 + j + 1) (min x (min y z))
                    goj (j + 1) (nb + db)
            when (i < lena) $ do
              VUM.unsafeWrite v v1 (i + 1)
              goj 0 0
              goi (i + 1) (na + da) v1 v0
      gov 0
      goi 0 0 0 v_len
      ld <- VUM.unsafeRead v (lenb + if even lena then 0 else v_len)
      return (ld, lenm)
  where
    lena = T.length a
    lenb = T.length b
    lenm = max lena lenb
{-# INLINE levenshtein_ #-}


stdSeries :: [Int]
stdSeries = [5,10,20,40,80,160]

testData :: Int -> Text
testData n = T.pack . take n . drop (n `mod` 4) . cycle $ ['a'..'z']

btmetric :: NFData a => String -> (Text -> Text -> a) -> Benchmark
btmetric name f = bgroup name (bs <$> stdSeries)
  where
    bs n = env (return (testData n, testData n)) (bench (show n) . nf (uncurry f))


main = defaultMain [btmetric "hamming"     hamming,
                    btmetric "levenshtein" levenshtein]

And the whole Rust module (suggestions for improvements are welcome of course):

#![feature(test, inclusive_range_syntax)]

extern crate test;

use test::Bencher;

fn hamming_distance_strs(s1: &str, s2: &str) -> Option<usize> {
    let mut s1c = s1.chars();
    let mut s2c = s2.chars();
    let mut count = 0;

    loop {
        match (s1c.next(), s2c.next()) {
            (Some(c1), Some(c2)) => { if c1 != c2 { count += 1; } },
            (Some(_), None) | (None, Some(_)) => { return None; },
            (None, None) => { return Some(count); }
        }
    }
}


fn levenshtein_distance_strs(t1: &str, t2: &str) -> usize {
    use std::mem::swap;
    use std::cmp::min;

    const VLEN_MAX: usize = 15;

    fn inner(s1: &str, s2: &str, a: &mut [u32], b: &mut [u32]) -> usize {
        let (mut v0, mut v1) = (a, b);

        for i in 0 .. v0.len() {
            v0[i] = i as u32;
        }

        for (i, c1) in s1.chars().enumerate() {
            v1[0] = i as u32 + 1;

            for (j, c2) in s2.chars().enumerate() {
                let cost = if c1 == c2 { 0 } else { 1 };
                let x = v1[j] + 1;
                let y = v0[j + 1] + 1;
                let z = v0[j] + cost;
                v1[j + 1] = min(x, min(y, z));
            }

            swap(&mut v0, &mut v1);
        }

        v0[v0.len() - 1] as usize
    }

    let (mut s1, mut s2) = (t1, t2);

    if s1.is_empty() { return s2.chars().count(); }
    if s2.is_empty() { return s1.chars().count(); }

    let mut len1 = s1.chars().count();
    let mut len2 = s2.chars().count();

    if len2 > len1 {
        swap(&mut s1, &mut s2);
        swap(&mut len1, &mut len2);
    }

    let vlen = len2 + 1;
    if vlen > VLEN_MAX {
        inner(s1, s2, &mut vec![0; vlen], &mut vec![0; vlen])
    } else {
        inner(s1, s2, &mut [0; VLEN_MAX][.. vlen], &mut [0; VLEN_MAX][.. vlen])
    }
}

#[test]
fn it_works() {
    assert_eq!(hamming_distance_strs("helloo", "bell"), None);
    assert_eq!(hamming_distance_strs("hello", "bello"), Some(2));

    assert_eq!(levenshtein_distance_strs("helloo", "bell"), 3);
    assert_eq!(levenshtein_distance_strs("hello", "bello"), 1);
}

fn test_data(n: usize) -> String {
    (b'a' ... b'z').cycle().skip(n % 4).take(n).map(char::from).collect()
}

macro_rules! generate_test {
    ($test_name:ident, $func_name:ident, $n:expr) => (
        #[bench]
        fn $test_name(b: &mut Bencher) {
            let s = test_data($n);
            b.iter(|| $func_name(&s, &s));
        }
    )
}

generate_test!(hamming_005, hamming_distance_strs, 5);
generate_test!(hamming_010, hamming_distance_strs, 10);
generate_test!(hamming_020, hamming_distance_strs, 20);
generate_test!(hamming_040, hamming_distance_strs, 40);
generate_test!(hamming_080, hamming_distance_strs, 80);
generate_test!(hamming_160, hamming_distance_strs, 160);

generate_test!(levenshtein_005, levenshtein_distance_strs, 5);
generate_test!(levenshtein_010, levenshtein_distance_strs, 10);
generate_test!(levenshtein_020, levenshtein_distance_strs, 20);
generate_test!(levenshtein_040, levenshtein_distance_strs, 40);
generate_test!(levenshtein_080, levenshtein_distance_strs, 80);
generate_test!(levenshtein_160, levenshtein_distance_strs, 160);

// generate_tests!(hamming, 5, 10, 20, 40, 80, 160);
// generate_tests!(levenshtein, 5, 10, 20, 40, 80, 160);

I compile and run that Rust module with:

rustc --test -C opt-level=3 -C target-cpu=native test2.rs
test2 --bench

The timings on my i7 PC are:

test hamming_005     ... bench:          14 ns/iter (+/- 0)
test hamming_010     ... bench:          26 ns/iter (+/- 1)
test hamming_020     ... bench:          55 ns/iter (+/- 3)
test hamming_040     ... bench:         100 ns/iter (+/- 3)
test hamming_080     ... bench:         194 ns/iter (+/- 15)
test hamming_160     ... bench:         396 ns/iter (+/- 28)
test levenshtein_005 ... bench:          85 ns/iter (+/- 14)
test levenshtein_010 ... bench:         259 ns/iter (+/- 30)
test levenshtein_020 ... bench:       1,054 ns/iter (+/- 928)
test levenshtein_040 ... bench:       3,704 ns/iter (+/- 95)
test levenshtein_080 ... bench:      13,822 ns/iter (+/- 512)
test levenshtein_160 ... bench:      53,457 ns/iter (+/- 2,692)

The output of the more refined Haskell Criterion benchmarking system is:

benchmarking hamming/5
time                 45.61 ns   (45.48 ns .. 45.74 ns)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 45.61 ns   (45.55 ns .. 45.66 ns)
std dev              170.0 ps   (135.2 ps .. 232.6 ps)

benchmarking hamming/10
time                 69.41 ns   (69.29 ns .. 69.54 ns)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 69.19 ns   (69.08 ns .. 69.29 ns)
std dev              361.5 ps   (305.3 ps .. 435.6 ps)

benchmarking hamming/20
time                 105.6 ns   (105.3 ns .. 105.8 ns)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 105.6 ns   (105.4 ns .. 105.7 ns)
std dev              430.1 ps   (334.3 ps .. 601.1 ps)

benchmarking hamming/40
time                 192.8 ns   (191.8 ns .. 193.5 ns)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 190.9 ns   (190.1 ns .. 192.2 ns)
std dev              3.466 ns   (2.512 ns .. 5.250 ns)
variance introduced by outliers: 23% (moderately inflated)

benchmarking hamming/80
time                 314.8 ns   (314.4 ns .. 315.1 ns)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 313.8 ns   (313.5 ns .. 314.1 ns)
std dev              1.167 ns   (954.6 ps .. 1.389 ns)

benchmarking hamming/160
time                 556.3 ns   (554.4 ns .. 557.7 ns)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 551.8 ns   (549.8 ns .. 553.4 ns)
std dev              5.981 ns   (5.032 ns .. 7.317 ns)

benchmarking levenshtein/5
time                 182.1 ns   (181.9 ns .. 182.5 ns)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 181.5 ns   (181.2 ns .. 181.7 ns)
std dev              875.3 ps   (710.2 ps .. 1.116 ns)

benchmarking levenshtein/10
time                 515.9 ns   (514.3 ns .. 517.6 ns)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 514.6 ns   (513.5 ns .. 515.9 ns)
std dev              3.969 ns   (3.253 ns .. 5.409 ns)

benchmarking levenshtein/20
time                 2.021 us   (2.017 us .. 2.025 us)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 2.013 us   (2.010 us .. 2.016 us)
std dev              10.57 ns   (8.704 ns .. 14.00 ns)

benchmarking levenshtein/40
time                 7.039 us   (7.032 us .. 7.045 us)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 7.014 us   (7.005 us .. 7.022 us)
std dev              27.50 ns   (22.99 ns .. 34.74 ns)

benchmarking levenshtein/80
time                 25.62 us   (25.55 us .. 25.67 us)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 25.50 us   (25.43 us .. 25.56 us)
std dev              199.9 ns   (170.6 ns .. 237.2 ns)

benchmarking levenshtein/160
time                 97.19 us   (97.09 us .. 97.30 us)
                     1.000 R²   (1.000 R² .. 1.000 R²)
mean                 96.82 us   (96.67 us .. 96.96 us)
std dev              477.3 ns   (373.4 ns .. 588.5 ns)

The comparisons are not entirely fair because the Rust code works on UTF-8 while the Haskell works on UTF-16, so for long English strings the Rust version touches half as much memory. This should reduce cache misses.

Edit: this benchmark is simplistic because it doesn’t exercise all control paths in the code: the input strings are identical, so the distance is always 0, and they contain only ASCII characters.
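
One possible tweak (my own sketch, reusing the test_data idea from the module above) is to benchmark against a shifted second string, so the inputs have equal length but differ at nearly every position:

// Hypothetical variation of the benchmark data: a second string shifted by
// one position in the cycle, so the two inputs have the same length but
// disagree almost everywhere.
fn test_data_shifted(n: usize) -> String {
    "abcdefghijklmnopqrstuvwxyz".chars().cycle().skip(n % 4 + 1).take(n).collect()
}

// In the bench macro the iteration would then become something like:
//     let (s1, s2) = (test_data($n), test_data_shifted($n));
//     b.iter(|| $func_name(&s1, &s2));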


#2

New discussions on the same topic:

https://two-wrongs.com/on-competing-with-c-using-haskell

That suggests further improvements:

We could imagine comparing pairs of 16-bit words until one of them is a surrogate pair, in which case we let the text library deal with it, and then go back to comparing 16-bit values. It may sound expensive to check for surrogates every iteration, but the idea is that if the string consists mostly of characters that aren’t made up of surrogate pairs, the branch predictor should let the cpu thunder through the condition to the 16-bit value case anyway. And we only pay the price for surrogate pairs when we actually use surrogate pairs. Doing so in fact yields code that performs closer to the raw 16-bit version than the otherwise fastest tail recursion implementation.
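
A rough UTF-8 analogue of that idea, sketched against the hamming function above (untested; the byte fast path and the fallback decoding are my own adaptation, not from the quoted post):

// Sketch: compare raw bytes while both sides are ASCII, and fall back to
// decoding a full char only when a multi-byte sequence appears. The branch
// predictor should make the ASCII path cheap on mostly-ASCII text, as the
// quoted post argues for UTF-16.
fn hamming_ascii_fast_path(s1: &str, s2: &str) -> Option<usize> {
    let (a, b) = (s1.as_bytes(), s2.as_bytes());
    let (mut i, mut j, mut count) = (0, 0, 0);
    loop {
        match (a.get(i), b.get(j)) {
            (None, None) => return Some(count),
            (Some(&x), Some(&y)) if x < 0x80 && y < 0x80 => {
                // Fast path: one ASCII byte is one char on both sides.
                if x != y { count += 1; }
                i += 1;
                j += 1;
            }
            (Some(_), Some(_)) => {
                // Slow path: decode one char from each side.
                let c1 = s1[i..].chars().next()?;
                let c2 = s2[j..].chars().next()?;
                if c1 != c2 { count += 1; }
                i += c1.len_utf8();
                j += c2.len_utf8();
            }
            // One string ran out of characters before the other.
            _ => return None,
        }
    }
}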

The new discussion thread:


#3

The Haskell code is obviously quite fast, but it’s also not exactly idiomatic. This is pretty neat, and also in line with my own experience: Haskell is usually about half as fast as Rust, if you write it in a rather Spartan style.


#4

I’m not sure whether I’m misunderstanding the problem, but this could be implemented fairly easily using iterators. I wonder how this idiomatic version compares. With the -C target-cpu=native option it can even use AVX. And you could probably speed this up even further for very long arrays by using rayon.

pub fn tmetrics_hamming(a: &[u16], b: &[u16]) -> usize {
    // Count the positions where the code units differ.
    a.iter().zip(b).filter(|&(x, y)| x != y).count()
}

I decided to use the [u16] type to match the C benchmark.
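
For long inputs, the rayon idea mentioned above might look roughly like this (a sketch; it needs the rayon crate and I haven’t benchmarked it):

// Hypothetical parallel version using the rayon crate (unbenchmarked sketch).
use rayon::prelude::*;

pub fn tmetrics_hamming_par(a: &[u16], b: &[u16]) -> usize {
    a.par_iter().zip(b.par_iter()).filter(|&(x, y)| x != y).count()
}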


#5

You can’t use the std library zip because the Hamming function should return some kind of error if the two input strings are of different length.


#6

In that case…

pub fn tmetrics_hamming(a: &[u16], b: &[u16]) -> Option<usize> {
    if a.len() == b.len() {
        Some(a.iter().zip(b).filter(|&(x, y)| x != y).count())
    } else {
        None
    }
}
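
A quick sanity check of this version (my own, not exhaustive):

// "hello" vs "bello" as UTF-16 code units: only the first unit differs.
assert_eq!(tmetrics_hamming(&[104, 101, 108, 108, 111],
                            &[ 98, 101, 108, 108, 111]), Some(1));
assert_eq!(tmetrics_hamming(&[1, 2, 3], &[1, 2]), None);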

#7

The problem is set on UTF-16 (or UTF-8 in my code) text, not on arrays. So you can’t use len(), you have to use something like chars().count(), this means you decode the UTF two times. This is not a performance improvement over my version of the code…