Is there a better way to write a backtracking algorithm?

TL;DR: Sometimes we create many `&mut` references to a shared status. Is there any way to avoid the 8-byte overhead of storing multiple `&mut` pointers?

minimal example:

/// Undo guard over a shared counter.
///
/// Field `.0` is the shared running total and `.1` is the step this guard added to it.
/// Dropping a `Counter` subtracts its step again. Guards created with [`Counter::forward`]
/// mutably reborrow their parent, so they are dropped (and undone) in reverse order of creation.
#[derive(Debug)]
struct Counter<'a>(&'a mut i32, i32);

impl Counter<'_> {
    /// Adds `step` to the shared total and returns a guard that removes it again when dropped.
    ///
    /// The returned guard reborrows `self`, so `self` cannot be used until the guard is gone.
    #[must_use = "dropping the returned guard immediately undoes the step"]
    fn forward(&mut self, step: i32) -> Counter<'_> {
        *self.0 += step;
        Counter(self.0, step)
    }
}

impl Drop for Counter<'_> {
    /// Removes this guard's step from the shared total.
    fn drop(&mut self) {
        *self.0 -= self.1;
    }
}
/// Demonstrates nested `Counter` guards.
///
/// Each inner guard adds its step when created and removes it when dropped; `dbg!` prints the
/// intermediate states.
fn silly_usage() {
    let mut a = 0i32;
    let mut counter0 = Counter(&mut a, 0);
    {
        let mut counter2 = counter0.forward(2);
        println!("counter2 is available in this scope, value is {:?}, while counter0 is not available in this scope", counter2);
        // counter0 and counter2 each store a `&mut a`, so the pointer is held twice.
        // Only counter2 is usable here: it mutably reborrows counter0, which stays frozen
        // until counter2 is dropped.
        {
            let counter5 = counter2.forward(3);
            dbg!(counter5);
        }
        dbg!(counter2);
    }
    dbg!(counter0);
}

A real-world program (backtracking) that may create multiple `&mut` references to a shared status:

The backtracking algorithm I currently have passes around a "status" and looks like this:

/// Depth-first search from `status`.
///
/// Prints the slots and returns as soon as `check` accepts the state. A solution found deeper
/// is returned without being undone; otherwise every applied op is undone before returning.
fn dfs<const N: usize>(mut status: T<N>) -> T<N>
where
    [(); 2 * N + 1]:,
{
    if check::<N>(&status) {
        println!("{:?}", &status[1..]);
        return status;
    }
    for op in get_all_ops::<N>(&status) {
        let mut accepted = true;
        status = forward::<N>(status, op, &mut accepted);
        if !accepted {
            continue;
        }
        status = dfs::<N>(status);
        if check::<N>(&status) {
            return status;
        }
        status = back_tracking::<N>(status, op);
    }
    status
}
(whole program)

#![feature(generic_const_exprs)]
type T=[i32;2*N+1];
fn dfs(mut status: T) -> T where [(); 2*N+1]:{
    //println!("{status:?}");
    if check::(&status) {println!("{:?}",&status[1..]);return status}
    for op in get_all_ops::(&status){
        let mut legal=true;
        status=forward::(status,op,&mut legal);
        if legal{
            status=dfs::(status);
            if check::(&status) {return status}
            status=back_tracking::(status,op)
        }
    }
    status
}
#[inline(always)]
fn check(s:&T)->bool where [(); 2*N+1]:{
    s[0]==1
}
fn forward(mut s:T,op:(i32,i32),legal:&mut bool)->T where [(); 2*N+1]:{
    //print!("op={op:?}  ");
    if s[op.1 as usize] + s[(1+op.0+op.1) as usize]==0 /*should delete while scanning for N=4k+1 and N=4k+2*/&& (1+op.0+op.1) != 2*N as i32 -1{
        s[op.1 as usize]=op.0;
        s[(1+op.0+op.1) as usize]=op.0;
        s[0]-=1;
    } else {*legal=false}
    //println!("fw:{s:?},legal={legal}");
    s
}
fn back_tracking(mut s:T,op:(i32,i32))->T where [(); 2*N+1]:{
    s[op.1 as usize]=0;
    s[(1+op.0+op.1) as usize]=0;
    s[0]+=1;
    s
}
fn get_all_ops(s:&T)->impl IntoIterator where [(); 2*N+1]:{
    let a=s[0];
    (1..2*N as i32-a).map(move |x|(a,x))
}
fn test() where [(); 2*N+1]:{
    let mut s:T=std::array::from_fn(|x|((x as i32)-2*(N as i32)).max(0));
    s[0]=N as i32;
    dfs::(s);
}
fn main(){
//    test::<3>();
//    test::<4>();
    test::<5>();
    test::<6>();
//    test::<7>();
//    test::<8>();
    test::<9>();
    test::<10>();
//    test::<11>();
//    test::<12>();
    test::<13>();
    test::<14>();
//    test::<15>();
//    test::<16>();
    test::<17>();
    test::<18>();
//    test::<19>();
//    test::<20>();
    test::<21>();
    test::<22>();
//    test::<23>();
//    test::<24>();
    test::<25>();
    test::<26>();
//    test::<27>();
//    test::<28>();
    test::<29>();
    test::<30>();
}

It feels clumsy to write the forward and back_tracking functions by hand, especially if I want to write several DFS algorithms. In this case there are two programs:

- The first handles n = 4k and n = 4k+3. It searches for an array that contains each of 1..n exactly twice, with exactly i numbers between the two copies of i.
- The second relaxes this restriction, since there is no solution for n = 4k+1 and n = 4k+2.

So I decided to write a struct-based backtracking algorithm that uses Rust's ownership rules to perform the forward and back-tracking steps automatically.

#![feature(generic_const_exprs)]
/// One node of a backtracking search: a step `op` that has been applied to the shared `status`.
///
/// Child nodes are created with [`Test::forward`], which applies an op and returns a node that
/// reborrows `status`. Dropping a node calls `back_tracking(status, op)`. Because every child
/// mutably reborrows its parent, nodes are dropped (and their steps undone) in reverse order of
/// creation. Each node stores its own `&mut C`, but only the innermost one is usable at a time.
///
/// The closure fields (typically zero-sized) are copied into every child node.
struct Test<
    'a,
    C,
    O: Copy,
    T: FnOnce(&mut C, O) -> bool,
    U: FnMut(&mut C, O),
    V: FnOnce(&C) -> bool,
    W: FnOnce(&C) -> X,
    X: IntoIterator<Item = O>,
> {
    /// The shared search state.
    status: &'a mut C,
    /// Tries to apply an op and returns whether it was legal.
    ///
    /// When it returns `false`, no node is created and nothing is undone, so it should leave
    /// the state unchanged in that case.
    forward: T,
    /// Undoes an op; called from `Drop` with this node's `op`.
    back_tracking: U,
    /// Returns `true` when the state is a solution.
    check: V,
    /// Produces the candidate ops for the current state.
    get_all_ops: W,
    /// The op this node stands for; passed to `back_tracking` when the node is dropped.
    op: O,
}

impl<
        'a,
        C: std::fmt::Debug,
        O: Copy,
        T: Copy + FnOnce(&mut C, O) -> bool,
        U: Copy + FnMut(&mut C, O),
        V: Copy + FnOnce(&C) -> bool,
        // `Copy` is needed because calling an `FnOnce` stored in a field moves it out of
        // `self`, and `forward` also copies every closure into the child node.
        W: Copy + FnOnce(&C) -> X,
        X: IntoIterator<Item = O>,
    > Test<'a, C, O, T, U, V, W, X>
{
    /// Builds the root node of a search over `status`.
    ///
    /// `op` is stored as this node's step. Like every other node, the root passes it to
    /// `back_tracking` when dropped, even though `new` never applied it.
    fn new(
        status: &'a mut C,
        forward: T,
        back_tracking: U,
        check: V,
        get_all_ops: W,
        op: O,
    ) -> Self {
        Self {
            status,
            forward,
            back_tracking,
            check,
            get_all_ops,
            op,
        }
    }

    /// Explores the whole tree below this node and prints every state accepted by `check`.
    fn dfs(&mut self) {
        self.search(false);
    }

    /// Explores the tree below this node and stops at the first state accepted by `check`.
    ///
    /// Returns `true` if such a state was found (and printed). By the time this returns, every
    /// step taken below this node has been undone again.
    fn dfs_early_stop(&mut self) -> bool {
        self.search(true)
    }

    /// Shared body of [`Test::dfs`] and [`Test::dfs_early_stop`].
    ///
    /// Returns `true` if a state accepted by `check` was reached. With `stop_at_first` set,
    /// returns as soon as one is found; otherwise it keeps exploring every remaining branch.
    fn search(&mut self, stop_at_first: bool) -> bool {
        if (self.check)(self.status) {
            println!("{:?}", self.status);
            return true;
        }
        let mut found = false;
        for op in (self.get_all_ops)(self.status) {
            // `next` already records `op`. Dropping it at the end of this block undoes the
            // step before the next candidate is tried.
            if let Some(mut next) = self.forward(op) {
                found |= next.search(stop_at_first);
                if found && stop_at_first {
                    return true;
                }
            }
        }
        found
    }

    /// Applies `op` to the shared state.
    ///
    /// If the `forward` closure accepts it, returns a child node that undoes the step when
    /// dropped. Returns `None` (with nothing to undo) otherwise.
    #[must_use = "dropping the returned node immediately undoes the step"]
    fn forward(&mut self, op: O) -> Option<Test<'_, C, O, T, U, V, W, X>> {
        if (self.forward)(self.status, op) {
            Some(Test {
                status: self.status,
                forward: self.forward,
                back_tracking: self.back_tracking,
                check: self.check,
                get_all_ops: self.get_all_ops,
                op,
            })
        } else {
            None
        }
    }
}

impl<
        'a,
        C,
        O: Copy,
        T: FnOnce(&mut C, O) -> bool,
        U: FnMut(&mut C, O),
        V: FnOnce(&C) -> bool,
        W: FnOnce(&C) -> X,
        X: IntoIterator<Item = O>,
    > Drop for Test<'a, C, O, T, U, V, W, X>
{
    /// Undoes this node's step so the shared state is restored when the node goes out of scope.
    fn drop(&mut self) {
        (self.back_tracking)(self.status, self.op)
    }
}

/// Builds the pairing problem for size `N`, then runs a full search and an early-stop search.
/// Finally, prints the size of the root search node.
fn test<const N: usize>()
where
    [(); 2 * N + 1]:,
{
    // Index 0 is the number still to be placed; indices 1..=2N are the slots (0 = empty).
    let mut board = [0; 2 * N + 1];
    board[0] = N as i32;

    // Place number `op.0` in slots `op.1` and `op.1 + op.0 + 1` if both are empty.
    // Disabled extra restriction: `&& second != 2 * N as i32 - 1` (author's note: "should
    // delete while scanning for N=4k+1 and N=4k+2").
    let place = |s: &mut [i32; 2 * N + 1], op: (i32, i32)| {
        let (num, first) = op;
        let second = 1 + num + first;
        let free = s[first as usize] + s[second as usize] == 0;
        if free {
            s[first as usize] = num;
            s[second as usize] = num;
            s[0] -= 1;
        }
        free
    };
    // Inverse of `place` for an op it accepted.
    let unplace = |s: &mut [i32; 2 * N + 1], op: (i32, i32)| {
        let (num, first) = op;
        s[first as usize] = 0;
        s[(1 + num + first) as usize] = 0;
        s[0] += 1;
    };
    let solved = |s: &[i32; 2 * N + 1]| s[0] == 0;
    let candidates = |s: &[i32; 2 * N + 1]| {
        let num = s[0];
        (1..2 * N as i32 - num).map(move |x| (num, x))
    };

    // The root op `(1, 0)` is never applied, but `root` still passes it to `unplace` when it
    // is dropped at the end of this function.
    let mut root = Test::new(&mut board, place, unplace, solved, candidates, (1, 0));
    root.dfs();
    root.dfs_early_stop();
    println!("{}", std::mem::size_of_val(&root));
}

fn main() {
    test::<3>();
    test::<4>();
    test::<7>();
    test::<8>();
}

As you can see, `size_of_val(&a)` is 8 bytes larger than expected, because every `Test` struct stores its own copy of the same reference to the common status.

Is there any method to avoid this waste?

Further, why do we have to add `Copy` bounds for the closures (which are ZSTs)? Why isn't this inferred automatically?

What's more, could we prevent some drop functions from running (e.g., when we want to keep the status for further exploration) while still dropping the current Test struct, so as to avoid a possible memory leak?

Never at the same time though. That's kind of core to what &mut means.
Given that, there's no real waste to speak of.

Your example code doesn't compile as written, but even if it did, `counter` would be invalidated as soon as `counter2` is created. (`counter2` should be declared with a `let`, by the way, or else be assigned outside the block it's assigned in.) Therefore any code trying to use `counter` afterwards would fail to compile.

I suspect you have the wrong model of unique/mutable borrows.
In particular, they can't be used as long-lived pointers (unlike raw Rust pointers and C/C++ pointers), and you can't have more than one unique borrow of something at a time.
Trying to hold on to a unique borrow after its immediate use is generally a good way to inflict pain on yourself.

Because Copy semantics are not always desirable. Silly example: if you have a token of some kind (which may or may not be ZST) it should be neither Copy nor Clone, because tokens that are easily reproduced that way can't really prove anything, and thus are useless.

That's entirely up to you. If you model the code such that such a Test value contains all state, then dropping it will release the memory it consumes.

Stopping a drop fn from executing isn't the solution here, though: at best it would cause memory leaks, and at worst unsoundness or UB.

The real solution is just extracting the state you want to keep from the Test value before dropping it.

2 Likes

Sorry for not checking the minimal program.

Here we can call `.forward` many times and hold many counters at the same time.
Although only one counter is usable in any given scope, all three counters are stored in memory, so there could be 16 bytes of waste.

/// Undo guard over a shared counter.
///
/// Field `.0` is the shared running total and `.1` is the step this guard added to it.
/// Dropping a `Counter` subtracts its step again. Guards created with [`Counter::forward`]
/// mutably reborrow their parent, so they are dropped (and undone) in reverse order of creation.
#[derive(Debug)]
struct Counter<'a>(&'a mut i32, i32);

impl Counter<'_> {
    /// Adds `step` to the shared total and returns a guard that removes it again when dropped.
    ///
    /// The returned guard reborrows `self`, so `self` cannot be used until the guard is gone.
    #[must_use = "dropping the returned guard immediately undoes the step"]
    fn forward(&mut self, step: i32) -> Counter<'_> {
        *self.0 += step;
        Counter(self.0, step)
    }
}

impl Drop for Counter<'_> {
    /// Removes this guard's step from the shared total.
    fn drop(&mut self) {
        *self.0 -= self.1;
    }
}
/// Demonstrates nested `Counter` guards.
///
/// Each inner guard adds its step when created and removes it when dropped; `dbg!` prints the
/// intermediate states.
fn silly_usage() {
    let mut a = 0i32;
    let mut counter0 = Counter(&mut a, 0);
    {
        let mut counter2 = counter0.forward(2);
        println!("counter2 is available in this scope, value is {:?}, while counter0 is not available in this scope", counter2);
        // counter0 and counter2 each store a `&mut a`, so the pointer is held twice.
        // Only counter2 is usable here: it mutably reborrows counter0, which stays frozen
        // until counter2 is dropped.
        {
            let counter5 = counter2.forward(3);
            dbg!(counter5);
        }
        dbg!(counter2);
    }
    dbg!(counter0);
}

fn main() {
    // Run the nested-guard demonstration once.
    silly_usage();
}