Is there a convenient crate that makes it easy to run SIMD code in parallel?

I am trying to write a Rust program that calculates a large number of marginal likelihoods, which means executing the following code repeatedly:

/// Toy stand-in for the real `mvbvn` function: it is called many times, so a
/// parallelized version would be worthwhile.
///
/// Depending on which condition holds for `data`, combines one, two or four
/// calls to `body`.
fn demo(data: (i32, i32)) -> i32 {
    let (first, second) = data;

    // Four-term case.
    if demo_condition1(data) {
        return body(first) + body(second) - body(first + second) - body(first - second);
    }

    // Two-term case.
    if demo_condition2(data) {
        return body(first) + body(second);
    }

    // Fallback: a single evaluation on the sum.
    body(first + second)
}
/// Toy stand-in for the real `mvbvu` function.
///
/// Picks exactly one of the `calc*` kernels for `data`: the first condition
/// that holds wins, and `calc3` is the fallback (the case the original author
/// labelled `condition3`).
fn body(data: i32) -> i32 {
    match data {
        value if condition1(value) => calc1(value),
        value if condition2(value) => calc2(value),
        // Everything else, i.e. the implicit `condition3(data)` case.
        value => calc3(value),
    }
}

All of the `calc*` functions could be rewritten as SIMD code, and the conditions are cheap to evaluate.

Running the code in parallel with Rayon is not difficult for me, but this program would also benefit from a SIMD speedup. Is there a crate that provides both?

If there is no such crate, what should I do, and what should I watch out for, in order to write one?

Full code: there are three functions here. The entry point is `parallel_batch_likelihood`, which calls `mvbvn` many times.

`mvbvn` is a wrapper that calls `mvbvu` one to four times,

and `mvbvu` could be rewritten as SIMD code.

    // NOTE(review): this listing is damaged. The text between the last `assume`
    // call and the `f64 {` that opens the next function body is missing, so the
    // end of `parallel_batch_likelihood` and the signature of the following
    // function (presumably `mvbvn`, per the surrounding prose) are not visible.
    // The return type `Vec` also appears to have lost its type argument.
    /// Parallel version of `batch_likelihood` function
    /// see `batch_likelihood` function for more informations.
    /// @export
    #[extendr]
        fn parallel_batch_likelihood(
        lower: &[f64],
        upper: &[f64],
        corr: &[f64],
        i: &[i32],
        j: &[i32],
    ) -> Vec {
        // Iterate `corr`, `i` and `j` together as a parallel iterator; each item
        // is one correlation `c` with its pair of indices `(i, j)`.
        (corr, i, j)
            .into_par_iter()
            .map(|(&c, &i, &j)| {
                // Optimizer hints that `i` and `j` are valid indices into
                // `lower` and `upper`. Nothing visible here checks that, so an
                // out-of-range (or negative, after the `as usize` cast) index
                // would make these assumptions false.
                // NOTE(review): confirm the caller guarantees the bounds.
                unsafe{
                    std::intrinsics::assume((i as usize)<lower.len());
                    std::intrinsics::assume((j as usize)<lower.len());
                    std::intrinsics::assume((i as usize)<upper.len());
                    // NOTE(review): the next line is truncated in the listing;
                    // everything after `(j as usize)` up to the return type of
                    // the next function is lost.
                    std::intrinsics::assume((j as usize) f64 {
        // From here on this is the body of the next function (presumably
        // `mvbvn`); `li`, `ui`, `lj`, `uj` and `corr` look like mutable
        // parameters - presumably the lower/upper bounds for the two
        // coordinates and their correlation; verify against the full source.
        //
        // An empty interval in either coordinate yields 0.
        if li >= ui || lj >= uj {
            0.0
        } else {
            // Reflect the first interval (negate and swap its bounds) when its
            // lower bound has the larger magnitude; the correlation changes
            // sign along with it.
            if li.abs() > ui.abs() {
                (li, ui, corr) = (-ui, -li, -corr)
            }
            // Same reflection for the second interval.
            if lj.abs() > uj.abs() {
                (lj, uj, corr) = (-uj, -lj, -corr)
            }
            //         if li==f64::NEG_INFINITY {MVPHI(uj)-MVPHI(lj)}
            // Combine `mvbvu` at the four corners of the rectangle:
            // (li, lj) - (li, uj) - (ui, lj) + (ui, uj).
            // A term is replaced by 0.0 without calling `mvbvu` when the upper
            // bound it uses is +infinity.
            mvbvu(li, lj, corr)
                - if uj != f64::INFINITY {
                    mvbvu(li, uj, corr)
                } else {
                    0.0
                }
                - if ui != f64::INFINITY {
                    mvbvu(ui, lj, corr)
                } else {
                    0.0
                }
                // `ui + uj` is +infinity as soon as either upper bound is.
                + if ui + uj != f64::INFINITY {
                    mvbvu(ui, uj, corr)
                } else {
                    0.0
                }
        }
    }
    fn mvbvu(sh: f64, sk: f64, r: f64) -> f64 {
        //      DOUBLE PRECISION BVN, SH, SK, R, ZERO, TWOPI
        //     let mut bvn=0f64;
        //      INTEGER I, LG, NG
        let lg;
        let ng;
        //      PARAMETER ( ZERO = 0, TWOPI = 6.283185307179586D0 )
        const TWOPI: f64 = 6.283185307179586;
        //      DOUBLE PRECISION X(10,3), W(10,3), AS, A, B, C, D, RS, XS
        //      DOUBLE PRECISION MVPHI, SN, ASR, H, K, BS, HS, HK
        //      SAVE X, W
        // *     Gauss Legendre Points and Weights, N =  6
        //       DATA ( W(I,1), X(I,1), I = 1, 3 ) /
        //      *  0.1713244923791705D+00,-0.9324695142031522D+00,
        //      *  0.3607615730481384D+00,-0.6612093864662647D+00,
        //      *  0.4679139345726904D+00,-0.2386191860831970D+00/
        // *     Gauss Legendre Points and Weights, N = 12
        //       DATA ( W(I,2), X(I,2), I = 1, 6 ) /
        //      *  0.4717533638651177D-01,-0.9815606342467191D+00,
        //      *  0.1069393259953183D+00,-0.9041172563704750D+00,
        //      *  0.1600783285433464D+00,-0.7699026741943050D+00,
        //      *  0.2031674267230659D+00,-0.5873179542866171D+00,
        //      *  0.2334925365383547D+00,-0.3678314989981802D+00,
        //      *  0.2491470458134029D+00,-0.1252334085114692D+00/
        // *     Gauss Legendre Points and Weights, N = 20
        //       DATA ( W(I,3), X(I,3), I = 1, 10 ) /
        //      *  0.1761400713915212D-01,-0.9931285991850949D+00,
        //      *  0.4060142980038694D-01,-0.9639719272779138D+00,
        //      *  0.6267204833410906D-01,-0.9122344282513259D+00,
        //      *  0.8327674157670475D-01,-0.8391169718222188D+00,
        //      *  0.1019301198172404D+00,-0.7463319064601508D+00,
        //      *  0.1181945319615184D+00,-0.6360536807265150D+00,
        //      *  0.1316886384491766D+00,-0.5108670019508271D+00,
        //      *  0.1420961093183821D+00,-0.3737060887154196D+00,
        //      *  0.1491729864726037D+00,-0.2277858511416451D+00,
        //      *  0.1527533871307259D+00,-0.7652652113349733D-01/

        const W: [[f64; 10]; 3] = [
            [
                0.1713244923791705,
                0.3607615730481384,
                0.4679139345726904,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
            ],
            [
                0.04717533638651177,
                0.1069393259953183,
                0.1600783285433464,
                0.2031674267230659,
                0.2334925365383547,
                0.2491470458134029,
                0.0,
                0.0,
                0.0,
                0.0,
            ],
            [
                0.01761400713915212,
                0.04060142980038694,
                0.06267204833410905,
                0.08327674157670475,
                0.1019301198172404,
                0.1181945319615184,
                0.1316886384491766,
                0.1420961093183821,
                0.1491729864726037,
                0.1527533871307259,
            ],
        ]; // python script (ignore the leading //): [[float(i[0]) for i in i]+[0.0]*(10-len(i)) for i in [[i.strip().replace('D','e').split(',') for i in i.split('*')[1:] for i in i.strip().split('\n')] for i in "\n".join([a for a in a.split('\n') if 'D+' in a]).split('/')[:3]]]
        const X: [[f64; 10]; 3] = [
            [
                -0.9324695142031522,
                -0.6612093864662647,
                -0.238619186083197,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
                0.0,
            ],
            [
                -0.9815606342467191,
                -0.904117256370475,
                -0.769902674194305,
                -0.5873179542866171,
                -0.3678314989981802,
                -0.1252334085114692,
                0.0,
                0.0,
                0.0,
                0.0,
            ],
            [
                -0.9931285991850949,
                -0.9639719272779138,
                -0.912234428251326,
                -0.8391169718222188,
                -0.7463319064601508,
                -0.636053680726515,
                -0.5108670019508271,
                -0.3737060887154196,
                -0.2277858511416451,
                -0.07652652113349732,
            ],
        ]; // python script (ignore the leading //): [[float(i[1]) for i in i]+[0.0]*(10-len(i)) for i in [[i.strip().replace('D','e').split(',') for i in i.split('*')[1:] for i in i.strip().split('\n')] for i in "\n".join([a for a in a.split('\n') if 'D+' in a]).split('/')[:3]]]
        //       IF ( ABS(R) .LT. 0.3 ) THEN
        //          NG = 1
        //          LG = 3
        //       ELSE IF ( ABS(R) .LT. 0.75 ) THEN
        //          NG = 2
        //          LG = 6
        //       ELSE
        //          NG = 3
        //          LG = 10
        //       ENDIF
        if r.abs() < 0.3 {
            ng = 0;
            lg = 3;
        } else if r.abs() < 0.75 {
            ng = 1;
            lg = 6;
        } else {
            ng = 2;
            lg = 10;
        }
        //       H = SH
        //       K = SK
        let h = sh;
        let mut k = sk;
        //       HK = H*K
        let mut hk = h * k;
        //       BVN = 0
        //       IF ( ABS(R) .LT. 0.925 ) THEN
        if r.abs() < 0.925 {
            //          HS = ( H*H + K*K )/2
            let hs = (h * h + k * k) / 2.0;
            //          ASR = ASIN(R)
            let asr = r.asin();
            //          DO I = 1, LG
            //         for i in 0..lg{
            // //             SN = SIN(ASR*( X(I,NG)+1 )/2)
            //             sn=(asr*(( X[ng][i]+1f64 )/2f64)).sin();
            // //             BVN = BVN + W(I,NG)*EXP( ( SN*HK - HS )/( 1 - SN*SN ) )
            //             bvn+=W[ng][i]*((sn*hk-hs)/(1-sn*sn)).exp()
            // //             SN = SIN(ASR*(-X(I,NG)+1 )/2)
            // //             BVN = BVN + W(I,NG)*EXP( ( SN*HK - HS )/( 1 - SN*SN ) )
            // //          END DO
            //         }
            //          BVN = BVN*ASR/(2*TWOPI) + MVPHI(-H)*MVPHI(-K)
            (0..lg)
                .map(|i| {
                    let sn1 = (asr * ((X[ng][i] + 1.0) / 2.0)).sin();
                    let sn2 = (asr * ((-X[ng][i] + 1.0) / 2.0)).sin();
                    W[ng][i]
                        * (((sn1 * hk - hs) / (1.0 - sn1 * sn1)).exp()
                            + ((sn2 * hk - hs) / (1.0 - sn2 * sn2)).exp())
                })
                .sum::()
                * asr
                / (TWOPI * 2.0)
                + mvphi(-h) * mvphi(-k)
        } else {
            //       ELSE
            //          IF ( R .LT. 0 ) THEN
            //             K = -K
            //             HK = -HK
            //          ENDIF
            if r < 0.0 {
                k = -k;
                hk = -hk;
            }
            //          IF ( ABS(R) .LT. 1 ) THEN
            let bvn = if r.abs() < 1.0 {
                //             AS = ( 1 - R )*( 1 + R )
                let r#as = (1.0 - r) * (1.0 + r);
                //             A = SQRT(AS)
                let a = r#as.sqrt();
                //             BS = ( H - K )**2
                let bs = (h - k) * (h - k);
                //             C = ( 4 - HK )/8
                //             D = ( 12 - HK )/16
                let c = (4.0 - hk) / 8.0;
                let d = (12.0 - hk) / 16.0;
                //             BVN = A*EXP( -(BS/AS + HK)/2 )
                //      +             *( 1 - C*(BS - AS)*(1 - D*BS/5)/3 + C*D*AS*AS/5 )
                let mut bvn = a
                    * (-(bs / r#as + hk) / 2.0).exp()
                    * (1.0 - c * (bs - r#as) * (1.0 - d * bs / 5.0) / 3.0 + c * d * r#as * r#as / 5.0);
                //             IF ( HK .GT. -160 ) THEN
                if hk > -160.0 {
                    //                B = SQRT(BS)
                    let b = bs.sqrt();
                    //                BVN = BVN - EXP(-HK/2)*SQRT(TWOPI)*MVPHI(-B/A)*B
                    //      +                    *( 1 - C*BS*( 1 - D*BS/5 )/3 )
                    bvn -= (-hk / 2.0).exp()
                        * TWOPI.sqrt()
                        * mvphi(-b / a)
                        * b
                        * (1.0 - c * bs * (1.0 - d * bs / 5.0) / 3.0);
                    //             ENDIF
                }
                //             A = A/2
                let a = a / 2.0;
                //             DO I = 1, LG
                //                XS = ( A*(X(I,NG)+1) )**2
                //                RS = SQRT( 1 - XS )
                //                BVN = BVN + A*W(I,NG)*
                //      +              ( EXP( -BS/(2*XS) - HK/(1+RS) )/RS
                //      +              - EXP( -(BS/XS+HK)/2 )*( 1 + C*XS*( 1 + D*XS ) ) )
                //                XS = AS*(-X(I,NG)+1)**2/4
                //                RS = SQRT( 1 - XS )
                //                BVN = BVN + A*W(I,NG)*EXP( -(BS/XS + HK)/2 )
                //      +                    *( EXP( -HK*(1-RS)/(2*(1+RS)) )/RS
                //      +                       - ( 1 + C*XS*( 1 + D*XS ) ) )
                //             END DO
                -(bvn
                    + (0..lg)
                        .map(|i| {
                            let xs = (a * (X[ng][i as usize] + 1.0)) * (a * (X[ng][i as usize] + 1.0));
                            let rs = (1.0 - xs).sqrt();
                            let bvn = a
                                * W[ng][i as usize]
                                * ((-bs / (2.0 * xs) - hk / (1.0 + rs)).exp() / rs
                                    - (-(bs / xs + hk) / 2.0).exp() * (1.0 + c * xs * (1.0 + d * xs)));
                            let xs =
                                r#as * (-X[ng][i as usize] + 1.0) * (-X[ng][i as usize] + 1.0) / 4.0;
                            let rs = (1.0 - xs).sqrt();
                            bvn + a
                                * W[ng][i as usize]
                                * (-(bs / xs + hk) / 2.0).exp()
                                * ((-hk * (1.0 - rs) / (2.0 * (1.0 + rs))).exp() / rs
                                    - (1.0 + c * xs * (1.0 + d * xs)))
                        })
                        .sum::())
                    / TWOPI
            //             BVN = -BVN/TWOPI

            //          ENDIF
            } else {
                0.0
            };
            //          IF ( R .GT. 0 ) BVN =  BVN + MVPHI( -MAX( H, K ) )
            //          IF ( R .LT. 0 ) BVN = -BVN + MAX( ZERO, MVPHI(-H) - MVPHI(-K) )
            if r > 0.0 {
                bvn + mvphi(-(h.max(k))) // mvphi could be calculated by `erf`
            } else {
                -bvn + 0f64.max(mvphi(-h) - mvphi(-k))
            }
            //       ENDIF
        }
        //       MVBVU = BVN
        //       END
    }

rust-lang/packed_simd: Portable Packed SIMD Vectors for Rust standard library

It works well combined with rayon.

It does not fit this case so well.

If you want to do a simple calculation such as `a[i]+b[i]+c[i]`, this crate works fine.

But currently, I might calculate:

b[0]+c[0] // due to a[0]==0
a[1]+b[1] // due to a[1]>0
a[2]+c[2] // due to a[2]<0

In this case, a lot of extra work is needed to use packed SIMD vectors.

I know the `packed_simd` crate is helpful for rewriting the `calc*` functions (or `mvbvu`), but writing those functions is not enough: how to dispatch them and when to call them are also open questions.

This topic was automatically closed 90 days after the last reply. We invite you to open a new topic if you have further questions or comments.