Is there anyone who can review my unsafe usage in mergesort module?
I think my unsafe usage is sound but there might be something wrong that I've missed.
use std::mem::{self, MaybeUninit, ManuallyDrop};
pub fn sort<T: Ord>(s: &mut [T]) {
if s.len() > 0 {
let mut tmp: Vec<ManuallyDrop<T>> = Vec::with_capacity(s.len());
mergesort_recursive(s, &mut tmp, 0, s.len() - 1);
}
}
fn mergesort_recursive<T: Ord>(s: &mut [T], tmp: &mut Vec<ManuallyDrop<T>>, start: usize, end: usize) {
if end.saturating_sub(start) <= 0 {
return;
}
let mid = (start + end) / 2;
mergesort_recursive(s, tmp, start, mid);
mergesort_recursive(s, tmp, mid + 1, end);
merge(s, tmp, start, mid, end);
}
fn merge<T: Ord>(s: &mut [T], tmp: &mut Vec<ManuallyDrop<T>>, start: usize, mid: usize, end: usize) {
assert!(start <= mid && mid + 1 <= end);
tmp.clear();
//It is safe to transmute &mut [T] to &mut [MaybeUninit<T>] because MaybeUninit is transparent.
let s: &mut [MaybeUninit<T>] = unsafe { mem::transmute(s) };
let mut i = start;
let mut j = mid + 1;
while i <= mid && j <= end {
//It may panic while comparing.
//It is safe to painc while comparing. Double drop should't happen since copied element is ManuallyDrop
if unsafe { s[i].assume_init_ref() < s[j].assume_init_ref() } {
tmp.push(unsafe { ManuallyDrop::new(s[i].assume_init_read()) });
i += 1;
} else {
tmp.push(unsafe { ManuallyDrop::new(s[j].assume_init_read()) });
j += 1;
}
}
while i <= mid {
tmp.push(unsafe { ManuallyDrop::new(s[i].assume_init_read()) });
i += 1;
}
while j <= end {
tmp.push(unsafe { ManuallyDrop::new(s[j].assume_init_read()) });
j += 1;
}
//check if length of tmp is equal to input array for safety.
assert_eq!(tmp.len(), end - start + 1);
//move tmp array to original array
for i in 0..tmp.len() {
//write will overwrite previous value without dropping.
//it should not panic while copying.
s[end - i].write(ManuallyDrop::into_inner(tmp.pop().unwrap()));
}
}
It is safe to transmute &mut [T] to &mut [MaybeUninit<T>]
No it isn't. It is generally not safe to transmute references, and even if it technically might be in some cases, it's still bad style, because it is very brittle (it can break without you noticing it due to seemingly irrelevant changes). Cast the pointers instead.
Apart from that, I don't see anything else obviously wrong. There are a lot of conversions back and forth, between dropping and non-dropping types, so it's still kind of confusing.
Thanks for reviewing my code!
I'v removed transmute operation and cleaned up my code so there is less conversions back and forth.
In any case, I think it is best to avoid transmutation if possible.
Here is my fixed code.
use std::mem::ManuallyDrop;
use std::ptr;
pub fn sort<T: Ord>(s: &mut [T]) {
if s.len() > 0 {
let mut tmp: Vec<ManuallyDrop<T>> = Vec::with_capacity(s.len());
mergesort_recursive(s, &mut tmp, 0, s.len() - 1);
}
}
fn mergesort_recursive<T: Ord>(
s: &mut [T],
tmp: &mut Vec<ManuallyDrop<T>>,
start: usize,
end: usize,
) {
if end.saturating_sub(start) <= 0 {
return;
}
let mid = (start + end) / 2;
mergesort_recursive(s, tmp, start, mid);
mergesort_recursive(s, tmp, mid + 1, end);
merge(s, tmp, start, mid, end);
}
fn merge<T: Ord>(
s: &mut [T],
tmp: &mut Vec<ManuallyDrop<T>>,
start: usize,
mid: usize,
end: usize,
) {
assert!(start <= mid && mid + 1 <= end);
tmp.clear();
let mut i = start;
let mut j = mid + 1;
while i <= mid && j <= end {
//It may panic while comparing.
//It is safe to painc while comparing. Double drop should't happen since copied element is ManuallyDrop
if s[i] < s[j] {
tmp.push(unsafe { ManuallyDrop::new(ptr::read(&s[i] as *const T)) });
i += 1;
} else {
tmp.push(unsafe { ManuallyDrop::new(ptr::read(&s[j] as *const T)) });
j += 1;
}
}
while i <= mid {
tmp.push(unsafe { ManuallyDrop::new(ptr::read(&s[i] as *const T)) });
i += 1;
}
while j <= end {
tmp.push(unsafe { ManuallyDrop::new(ptr::read(&s[j] as *const T)) });
j += 1;
}
//check if length of tmp is equal to input array for safety.
assert_eq!(tmp.len(), end - start + 1);
//move tmp array to original array
for (i, e) in tmp.iter().enumerate() {
//write will overwrite previous value without dropping.
//it should not panic while copying.
unsafe {
s.as_mut_ptr()
.add(start + i)
.write(ManuallyDrop::into_inner(ptr::read(e as *const _)));
}
}
}
If you want to more conveniently (i.e. without the need for the ptr::read(e as *const _)) get owned access to the values in tmp (and without losing the buffer as something like mem::take plus Vec::into_iter would do), consider using tmp.drain(..).
Edit: Acutally… perhaps even better would be to use copy_nonoverlapping in std::ptr - Rust insted of a loop for the whole “move tmp array to original array” operation.
And I just noticed, with the ManuallyDrop<Vec<T>>, you’ll drop the items with tmp.clear(), so you should use .set_len(0) instead to avoid the double-frees.
Yeah, I have to use tmp.clear() instead of tmp.set_len() to avoid double drop. Thanks for reminding me.
I didn't know copy_nonoverlapping existed. I think it's much better approach than for loop.
use std::mem::ManuallyDrop;
use std::ptr;
pub fn sort<T: Ord>(s: &mut [T]) {
if s.len() > 0 {
let mut tmp = ManuallyDrop::new(Vec::new());
mergesort_recursive(s, &mut tmp, 0, s.len() - 1);
unsafe {
//inner content of tmp should not be dropped.
tmp.set_len(0);
//drop tmp manually.
ManuallyDrop::drop(&mut tmp);
}
}
}
fn mergesort_recursive<T: Ord>(
s: &mut [T],
tmp: &mut ManuallyDrop<Vec<T>>,
start: usize,
end: usize,
) {
if end.saturating_sub(start) <= 0 {
return;
}
let mid = (start + end) / 2;
mergesort_recursive(s, tmp, start, mid);
mergesort_recursive(s, tmp, mid + 1, end);
merge(s, tmp, start, mid, end);
}
fn merge<T: Ord>(
s: &mut [T],
tmp: &mut ManuallyDrop<Vec<T>>,
start: usize,
mid: usize,
end: usize,
) {
assert!(start <= mid && mid + 1 <= end);
unsafe {
//content of tmp array should not be dropped.
tmp.set_len(0);
}
let mut i = start;
let mut j = mid + 1;
while i <= mid && j <= end {
//It may panic while comparing.
//It is safe to painc while comparing. Double drop should't happen since tmp vector is ManuallyDrop.
if s[i] < s[j] {
tmp.push(unsafe { ptr::read(&s[i] as *const T) });
i += 1;
} else {
tmp.push(unsafe { ptr::read(&s[j] as *const T) });
j += 1;
}
}
while i <= mid {
tmp.push(unsafe { ptr::read(&s[i] as *const T) });
i += 1;
}
while j <= end {
tmp.push(unsafe { ptr::read(&s[j] as *const T) });
j += 1;
}
//check if length of tmp is equal to input array for safety.
assert_eq!(tmp.len(), end - start + 1);
//copy tmp array to original array.
unsafe {
ptr::copy_nonoverlapping(tmp.as_mut_ptr(), s.as_mut_ptr().add(start), tmp.len());
}
}