TL;DR: Sometimes we create many `&mut` references to the same shared status. Is there any way to avoid wasting 8 bytes on each additional `&mut` pointer?
Minimal example:
/// One step of a running tally.
///
/// Field `.0` is the shared accumulator (the counting result); field `.1` is the
/// step length this counter added to it. Dropping the counter subtracts that
/// step again.
#[derive(Debug)]
struct Counter<'a>(&'a mut i32, i32);

impl Counter<'_> {
    /// Adds `step` to the shared total and returns a child counter.
    ///
    /// The child reborrows the total from `self`, so `self` cannot be used
    /// until the child has been dropped.
    fn forward(&mut self, step: i32) -> Counter<'_> {
        let total: &mut i32 = &mut *self.0;
        *total += step;
        Counter(total, step)
    }
}

impl Drop for Counter<'_> {
    /// Removes this counter's step from the shared total.
    fn drop(&mut self) {
        let Counter(total, step) = self;
        **total -= *step;
    }
}
/// Demonstrates nested `Counter` scopes.
///
/// Each `forward` adds its step to `a`, and dropping the returned counter
/// (here via `dbg!`, which takes it by value) subtracts the step again.
fn silly_usage() {
    let mut a = 0i32;
    let mut counter0 = Counter(&mut a, 0);
    {
        let mut counter2 = counter0.forward(2);
        println!("counter2 is available in this scope, value is {:?}, while counter0 is not available in this scope", counter2);
        // Here both counter0 and counter2 hold a `&mut` pointing at `a`, so the
        // same pointer is stored twice (an extra 8 bytes on 64-bit targets).
        // Is it possible to eliminate this extra 8-byte cost?
        // LLVM may optimize it away here, but that is not guaranteed: once the
        // struct gains more fields, counter0 and counter2 may hold different values.
        {
            // Not `mut`: `dbg!` only moves the counter (and drops it afterwards).
            let counter5 = counter2.forward(3);
            dbg!(counter5);
        }
        dbg!(counter2);
    }
    dbg!(counter0);
}
Real-world program (backtracking) that may create multiple `&mut` references to a shared status:
The current backtracking algorithm I wrote keeps a "status" and looks like this:
/// Recursive backtracking search over `status`.
///
/// As soon as `check` succeeds, `status[1..]` is printed and `status` is
/// returned without undoing the placements that led to it. Otherwise every
/// accepted `forward` step is undone by `back_tracking` before the next one
/// is tried.
fn dfs<const N: usize>(mut status: T<N>) -> T<N>
where
    [(); 2 * N + 1]:,
{
    if check::<N>(&status) {
        println!("{:?}", &status[1..]);
        return status;
    }
    for op in get_all_ops::<N>(&status) {
        let mut accepted = true;
        status = forward::<N>(status, op, &mut accepted);
        if !accepted {
            continue;
        }
        status = dfs::<N>(status);
        if check::<N>(&status) {
            return status;
        }
        status = back_tracking::<N>(status, op);
    }
    status
}
(whole program)
#![feature(generic_const_exprs)]
type T=[i32;2*N+1];
fn dfs(mut status: T) -> T where [(); 2*N+1]:{
//println!("{status:?}");
if check::(&status) {println!("{:?}",&status[1..]);return status}
for op in get_all_ops::(&status){
let mut legal=true;
status=forward::(status,op,&mut legal);
if legal{
status=dfs::(status);
if check::(&status) {return status}
status=back_tracking::(status,op)
}
}
status
}
#[inline(always)]
fn check(s:&T)->bool where [(); 2*N+1]:{
s[0]==1
}
fn forward(mut s:T,op:(i32,i32),legal:&mut bool)->T where [(); 2*N+1]:{
//print!("op={op:?} ");
if s[op.1 as usize] + s[(1+op.0+op.1) as usize]==0 /*should delete while scanning for N=4k+1 and N=4k+2*/&& (1+op.0+op.1) != 2*N as i32 -1{
s[op.1 as usize]=op.0;
s[(1+op.0+op.1) as usize]=op.0;
s[0]-=1;
} else {*legal=false}
//println!("fw:{s:?},legal={legal}");
s
}
fn back_tracking(mut s:T,op:(i32,i32))->T where [(); 2*N+1]:{
s[op.1 as usize]=0;
s[(1+op.0+op.1) as usize]=0;
s[0]+=1;
s
}
fn get_all_ops(s:&T)->impl IntoIterator where [(); 2*N+1]:{
let a=s[0];
(1..2*N as i32-a).map(move |x|(a,x))
}
fn test() where [(); 2*N+1]:{
let mut s:T=std::array::from_fn(|x|((x as i32)-2*(N as i32)).max(0));
s[0]=N as i32;
dfs::(s);
}
fn main(){
// test::<3>();
// test::<4>();
test::<5>();
test::<6>();
// test::<7>();
// test::<8>();
test::<9>();
test::<10>();
// test::<11>();
// test::<12>();
test::<13>();
test::<14>();
// test::<15>();
// test::<16>();
test::<17>();
test::<18>();
// test::<19>();
// test::<20>();
test::<21>();
test::<22>();
// test::<23>();
// test::<24>();
test::<25>();
test::<26>();
// test::<27>();
// test::<28>();
test::<29>();
test::<30>();
}
It feels clumsy to write the forward and back_tracking functions by hand, especially since I want to write several DFS algorithms. In this case there are two programs: one for n = 4k and n = 4k+3, which searches for an array containing exactly two copies of each of 1..n such that there are exactly i numbers between the two i's; and a second one that relaxes this restriction, since there is no solution for n = 4k+1 and n = 4k+2.
So I decided to write a struct-based backtracking algorithm that uses Rust's ownership to perform the forward and back-tracking steps automatically.
#![feature(generic_const_exprs)]
/// A node of a backtracking search whose steps are undone automatically.
///
/// Every node holds a `&mut` to the same search `status`. A child created by
/// `Test::forward` reborrows it from its parent, so the parent cannot be used
/// while the child is alive. Dropping a node calls `back_tracking(status, op)`
/// to undo the step that created it.
///
/// Callbacks:
/// - `forward(status, op)` performs `op` and returns `true` if it was accepted.
///   When it returns `false`, no child node is created, so nothing is undone for it.
/// - `back_tracking(status, op)` undoes `op`.
/// - `check(status)` returns `true` when `status` is a solution.
/// - `get_all_ops(status)` lists the operations to try from `status`.
///
/// Note: the root node built by `Test::new` also runs `back_tracking` with its
/// `op` when dropped, even though that `op` was never applied through `forward`.
struct Test<
    'a,
    C,
    O: Copy,
    T: FnOnce(&mut C, O) -> bool,
    U: FnMut(&mut C, O),
    V: FnOnce(&C) -> bool,
    W: FnOnce(&C) -> X,
    X: IntoIterator<Item = O>,
> {
    /// The shared search state, reborrowed by every node on the current path.
    status: &'a mut C,
    forward: T,
    back_tracking: U,
    check: V,
    get_all_ops: W,
    /// The step passed to `back_tracking` when this node is dropped.
    op: O,
}

impl<
        'a,
        C: std::fmt::Debug,
        O: Copy,
        // The callbacks must be `Copy`: `forward` duplicates each of them into
        // the child node, and calling an `FnOnce` stored in a field consumes it.
        T: Copy + FnOnce(&mut C, O) -> bool,
        U: Copy + FnMut(&mut C, O),
        V: Copy + FnOnce(&C) -> bool,
        W: Copy + FnOnce(&C) -> X,
        X: IntoIterator<Item = O>,
    > Test<'a, C, O, T, U, V, W, X>
{
    /// Creates the root node.
    ///
    /// `op` is what `back_tracking` receives when the root itself is dropped.
    fn new(
        status: &'a mut C,
        forward: T,
        back_tracking: U,
        check: V,
        get_all_ops: W,
        op: O,
    ) -> Self {
        Self {
            status,
            forward,
            back_tracking,
            check,
            get_all_ops,
            op,
        }
    }

    /// Explores the whole search tree below this node and prints every status
    /// accepted by `check`. A node that satisfies `check` is not expanded further.
    fn dfs(&mut self) {
        self.search(false);
    }

    /// Like `dfs`, but stops at the first solution.
    ///
    /// Returns whether a solution was found. The solution is printed, but the
    /// child nodes are still dropped on the way back. Their steps are therefore
    /// undone (assuming `back_tracking` reverses `forward`).
    fn dfs_early_stop(&mut self) -> bool {
        self.search(true)
    }

    /// Shared DFS body for `dfs` and `dfs_early_stop`.
    ///
    /// Returns `true` if this node satisfies `check`, or if `stop_at_first` is
    /// set and a descendant did.
    fn search(&mut self, stop_at_first: bool) -> bool {
        if (self.check)(self.status) {
            println!("{:?}", self.status);
            return true;
        }
        for op in (self.get_all_ops)(self.status) {
            if let Some(mut next) = self.forward(op) {
                // `next` is dropped when this block ends (or on `return`), which
                // runs `back_tracking` before the next `op` is tried.
                if next.search(stop_at_first) && stop_at_first {
                    return true;
                }
            }
        }
        false
    }

    /// Tries to apply `op` to the shared status.
    ///
    /// On success, returns a child node that reborrows `status` from `self`
    /// and undoes `op` when dropped. Returns `None` if the `forward` callback
    /// rejects `op`.
    fn forward(&mut self, op: O) -> Option<Test<'_, C, O, T, U, V, W, X>> {
        if (self.forward)(self.status, op) {
            Some(Test {
                status: &mut *self.status,
                forward: self.forward,
                back_tracking: self.back_tracking,
                check: self.check,
                get_all_ops: self.get_all_ops,
                op,
            })
        } else {
            None
        }
    }
}

impl<
        'a,
        C,
        O: Copy,
        T: FnOnce(&mut C, O) -> bool,
        U: FnMut(&mut C, O),
        V: FnOnce(&C) -> bool,
        W: FnOnce(&C) -> X,
        X: IntoIterator<Item = O>,
    > Drop for Test<'a, C, O, T, U, V, W, X>
{
    /// Undoes the step this node represents by calling `back_tracking(status, op)`.
    fn drop(&mut self) {
        (self.back_tracking)(self.status, self.op)
    }
}
/// Builds the root `Test` node for size `N` and runs both searches on it,
/// then prints the node's size in bytes.
///
/// `s[0]` holds the next value to place (starting at `N`). `s[1..]` are the
/// 2N positions, with 0 meaning empty. A move `(value, pos)` puts `value` at
/// `pos` and at `pos + value + 1`.
fn test<const N: usize>()
where
    [(); 2 * N + 1]:,
{
    let mut s = [0; 2 * N + 1];
    s[0] = N as i32;

    // Place `value` in both slots if they currently sum to 0.
    // The relaxed (N = 4k+1 / 4k+2) variant would additionally require
    // `1 + value + pos != 2 * N as i32 - 1`; that guard is disabled here.
    let forward = |s: &mut [i32; 2 * N + 1], (value, pos): (i32, i32)| -> bool {
        let (lo, hi) = (pos as usize, (1 + value + pos) as usize);
        if s[lo] + s[hi] != 0 {
            return false;
        }
        s[lo] = value;
        s[hi] = value;
        s[0] -= 1;
        true
    };
    // Undo an accepted move: clear both slots and restore the next value.
    let back_tracking = |s: &mut [i32; 2 * N + 1], (value, pos): (i32, i32)| {
        s[pos as usize] = 0;
        s[(1 + value + pos) as usize] = 0;
        s[0] += 1;
    };
    // Solved once every value down to 1 has been placed.
    let check = |s: &[i32; 2 * N + 1]| s[0] == 0;
    // Every first position whose partner slot stays within the array.
    let get_all_ops = |s: &[i32; 2 * N + 1]| {
        let value = s[0];
        (1..2 * N as i32 - value).map(move |pos| (value, pos))
    };

    // `(1, 0)` is the root's op; `back_tracking` receives it when `root` is dropped.
    let mut root = Test::new(&mut s, forward, back_tracking, check, get_all_ops, (1, 0));
    root.dfs();
    root.dfs_early_stop();
    println!("{}", std::mem::size_of_val(&root));
}
fn main() {
test::<3>();
test::<4>();
test::<7>();
test::<8>();
}
As you can see, `size_of_val(&a)` is 8 bytes larger than expected, because every `Test` struct stores its own copy of the reference to the common status.
Is there any way to avoid this waste?
Further, why do we have to mark the closures (which are ZSTs) as `Copy`? Why isn't that derived automatically?
What's more, can we prevent some `drop` functions from running (e.g., when we want to keep the status for further exploration) while still dropping the current `Test` struct, so that no memory is leaked?