How can I avoid self-reference in my code?

I want to spawn a thread in a struct.
The following Rust code:

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, Weak};
#[derive(Debug)]
pub struct DataCollector {
    connected: AtomicBool,
    port: HANDLE,
    completion: HANDLE,
    thread_count: u32,
    work_threads: Vec<std::thread::JoinHandle<()>>,
    self_ref: Option<Arc<Weak<Mutex<DataCollector>>>>,
}
impl DataCollector {
    /// `new` return `Arc<Mutex<DataCollector>>`
    pub fn new() -> Arc<Mutex<DataCollector>> {
        let data_collector = Arc::new(Mutex::new(Self {
            connected: AtomicBool::new(false),
            port: INVALID_HANDLE_VALUE as HANDLE,
            completion: INVALID_HANDLE_VALUE as HANDLE,
            thread_count: 0,
            work_threads: Vec::new(),
            self_ref: None,
        }));

        let weak_data_collector = Arc::downgrade(&data_collector);
        data_collector.lock().unwrap().self_ref = Some(weak_data_collector.into());
        data_collector
    }

    /// `run`  workthread created
    unsafe fn run(&mut self) {
        for _ in 0..self.thread_count {
            let self_weak = self.self_ref.clone().unwrap();
            let handle = std::thread::spawn(move || {
                if let Some(data_collect_arc) = self_weak.upgrade() {
                    let mut data_collect = data_collect_arc.lock().unwrap();
                    data_collect.work_thread();
                }
            });
            // skip...
        }
    }

    pub unsafe fn work_thread(&mut self){
        // skip...
    }
}

How can I start a thread without using a self-reference and without taking ownership?

Post the full error message from running cargo build in the terminal.

There is no error; I am just trying to find a way to avoid using self-ref.

Using self-ref decreases the readability of the code.

You have two functions:

unsafe fn run(&mut self) {...}
pub unsafe fn work_thread(&mut self) {...}

that are not in an impl block. Functions that have a self parameter must be in an impl block for the Self type.

If you don't want them to be in an impl block, change the self parameter to a parameter of the type you want.

These functions are also declared unsafe which looks like a different mistake.

That's all I can tell you with the information you've given. I don't think you're using the terms "self-reference" and "ownership" correctly, but it is very difficult to tell.

Here is documentation about how to define methods, with and without a Self parameter.

Sorry about that, I made a mistake. I have updated it now.

I don't know what question you're asking. But if you post code that can compile, then perhaps I'll see from the errors what it is you need. The code you posted (latest version) doesn't compile because some types are missing.

I am just trying to avoid using self_ref: Option<Arc<Weak<Mutex<DataCollector>>>>; it is too complex.
And when I spawn the thread, I do not want to take the ownership. It will decrease the readability of the code.

Oh, self_ref is a field. You misspelled it in some cases and used self-reference in other cases, so I didn't know what you meant. I will try to take a look tomorrow. Your question is a little clearer now.

But also please describe what you mean by "when I spawn the thread, I do not want to take the ownership". What does "take ownership" mean to you? I don't know if you mean ownership in the way it is defined in Rust. And ownership of what, exactly?

"When I spawn the thread, I do not want to take the ownership" means that I do not want to create the new thread by cloning self_ref like this

let self_weak = self.self_ref.clone().unwrap();

and then moving self_weak into the closure:

let handle = std::thread::spawn(move || {
    if let Some(data_collect_arc) = self_weak.upgrade() {
        let mut data_collect = data_collect_arc.lock().unwrap();
        data_collect.work_thread();
    }
});

Because you are using a Mutex to lock the DataCollector before doing any work in each thread, each thread will not run at the same time as (in parallel with) other threads. Each thread will run to completion before running the other threads. So the threads are not providing any benefit.
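
To make that concrete, here is a minimal sketch of the same locking pattern. `Collector` and `max_overlap_with_lock` are hypothetical names and the sleep only stands in for real work; none of this is taken from the posted code. It counts how many threads are ever inside the locked region at once:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;

/// Hypothetical stand-in for `DataCollector`; not the type from the thread.
struct Collector {
    processed: u32,
}

/// Spawns `threads` workers that each hold the lock for their whole body,
/// and returns the largest number of workers that were ever inside the
/// locked region at the same time.
fn max_overlap_with_lock(threads: usize) -> usize {
    let shared = Arc::new(Mutex::new(Collector { processed: 0 }));
    let active = Arc::new(AtomicUsize::new(0));
    let max_seen = Arc::new(AtomicUsize::new(0));

    let handles: Vec<_> = (0..threads)
        .map(|_| {
            let shared = Arc::clone(&shared);
            let active = Arc::clone(&active);
            let max_seen = Arc::clone(&max_seen);
            thread::spawn(move || {
                // The guard lives until the end of the closure, just like
                // `data_collect` in the posted `run` method.
                let mut guard = shared.lock().unwrap();
                let now = active.fetch_add(1, Ordering::SeqCst) + 1;
                max_seen.fetch_max(now, Ordering::SeqCst);
                thread::sleep(Duration::from_millis(5)); // pretend to work
                guard.processed += 1;
                active.fetch_sub(1, Ordering::SeqCst);
            })
        })
        .collect();

    for handle in handles {
        handle.join().unwrap();
    }
    max_seen.load(Ordering::SeqCst)
}
```

However many threads are spawned, the returned maximum stays at 1, because a second thread cannot get past `lock()` until the first one drops its guard.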

When using threads to divide work, you should also divide the mutable data used by those threads. You haven't shown what work the threads are doing, since the work_thread method is empty. I assume they will collect data. If that is true, each thread can return the results that are collected, and the results can be combined by the parent thread. But all this is something you have not shown.

Because of the code you wrote, I assume you want to share all the fields of the DataCollector with all of the threads. But this doesn't make sense, since it seems many of its fields should only be set once (connected, port and completion).

Perhaps you need to make a separate struct for the worker threads, perhaps called ThreadWorker, and create an instance of that struct for each thread. That way you could have separate mutable data for each thread. Then you don't need to lock the DataCollector in each thread. Instead you give each thread its own data, contained in ThreadWorker. And when the threads are finished, they can return their data to the parent thread. Or if they return their work incrementally, a little at a time, you can use channels to send the data back to the parent thread.
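
A rough sketch of that idea, assuming the workers can finish and hand back everything at once. The fields of `ThreadWorker`, the fake events, and the `collect_all` helper are placeholders I made up, since I don't know what your threads really collect:

```rust
use std::thread;

/// Hypothetical per-thread state; the fields are placeholders.
struct ThreadWorker {
    id: usize,
    events: Vec<String>,
}

impl ThreadWorker {
    fn new(id: usize) -> Self {
        Self { id, events: Vec::new() }
    }

    /// Stand-in for the real work: records three fake events and
    /// hands the collected data back by value.
    fn run(mut self) -> Vec<String> {
        for n in 0..3 {
            self.events.push(format!("worker {} event {}", self.id, n));
        }
        self.events
    }
}

/// Spawns one worker per thread and merges the results in the parent.
fn collect_all(thread_count: usize) -> Vec<String> {
    let handles: Vec<thread::JoinHandle<Vec<String>>> = (0..thread_count)
        .map(|id| {
            // Each thread owns its worker outright, so no Mutex is needed.
            let worker = ThreadWorker::new(id);
            thread::spawn(move || worker.run())
        })
        .collect();

    let mut all = Vec::new();
    for handle in handles {
        // `join` returns whatever the thread's closure returned.
        all.extend(handle.join().expect("worker thread panicked"));
    }
    all
}
```

Because every thread owns its `ThreadWorker`, the threads never wait on each other, and the parent only touches the data after `join`.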

I can only guess. To be honest I am afraid your written English communication is not good enough for someone to help you here. If you can find a coworker or friend to help, that may be better. Or if you want to continue to discuss this here, you will need to provide much more information about what you're trying to do. And when you post code, you need to make sure it compiles -- otherwise people will not want to help you. Even the latest version of your code above does not compile. Right now you're only communicating very small bits of information, which makes it impossible to understand how to help you.


Unfortunately, I am the only one in my team who uses Rust.

Can you briefly describe the work that each thread should do, and how the result of that work will be used? If so, I may be able to help a little.

It is OK to use a language translator, if that makes it easier. I hear that DeepL Translate is a good one.


The code works now; I just want to optimize it. The point of the optimization is to remove the self-reference that the code uses to start a worker thread. Although I used std::sync::Weak to avoid circular references, the readability of the current code is horrible!
Here is the full code:

// datacollect.rs
use std::mem::MaybeUninit;
use std::os::raw::c_void;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, Weak};

use widestring::{WideCString, WideChar};

use crate::collect::{DataCollectorEvent, DataCollectorEventMessage};
use windows_sys::core::{HRESULT, PCWSTR};
use windows_sys::Win32::Foundation::{
    CloseHandle, GetLastError, BOOL, ERROR_IO_PENDING, FALSE, HANDLE, INVALID_HANDLE_VALUE,
    SEVERITY_ERROR, S_OK, TRUE,
};
use windows_sys::Win32::Storage::InstallableFileSystems::{
    FilterConnectCommunicationPort, FilterGetMessage, FILTER_MESSAGE_HEADER,
};
use windows_sys::Win32::System::Diagnostics::Debug::FACILITY_WIN32;
use windows_sys::Win32::System::SystemInformation::{GetSystemInfo, SYSTEM_INFO};
use windows_sys::Win32::System::Threading::INFINITE;
use windows_sys::Win32::System::IO::{
    CreateIoCompletionPort, GetQueuedCompletionStatus, OVERLAPPED,
};

use crate::common::{ActionType, DataCollectorStatus};

use tracing::{error, info, instrument, trace};

/// `hresult_from_win32` C++ macro: HRESULT_FROM_WIN32
#[inline(always)]
fn hresult_from_win32(x: u32) -> HRESULT {
    if x as i32 <= 0 {
        x as HRESULT
    } else {
        ((x & 0x0000FFFF) | (FACILITY_WIN32 << 16) | SEVERITY_ERROR << 31) as HRESULT
    }
}

/// `DataCollector` 
#[derive(Debug)]
pub struct DataCollector {
    connected: AtomicBool,
    port: HANDLE,
    completion: HANDLE,
    thread_count: u32,
    work_threads: Vec<std::thread::JoinHandle<()>>,
    self_ref: Option<Arc<Weak<Mutex<DataCollector>>>>,
}

impl DataCollector {
    /// `new`  return  `Arc<Mutex<DataCollector>>`
    pub fn new() -> Arc<Mutex<DataCollector>> {
        let data_collector = Arc::new(Mutex::new(Self {
            connected: AtomicBool::new(false),
            port: INVALID_HANDLE_VALUE as HANDLE,
            completion: INVALID_HANDLE_VALUE as HANDLE,
            thread_count: 0,
            work_threads: Vec::new(),
            self_ref: None,
        }));

        let weak_data_collector = Arc::downgrade(&data_collector);
        data_collector.lock().unwrap().self_ref = Some(weak_data_collector.into());
        data_collector
    }

    /// `start` 
    pub fn start(&mut self) -> bool {
        if self.is_connected() {
            return true;
        }
        match self.connect() {
            Some(DataCollectorStatus::DataCollectorConnectSuccess) => {
                return self.listen();
            }
            Some(DataCollectorStatus::DataCollectorConnectFail) => {
                error!("DataCollector start connect fail");
                return false;
            }
            None => {
                error!("DataCollector start connect fail: None");
                return false;
            }
        }
    }

    /// `stop`
    pub fn stop(&mut self) -> bool {
        info!("DataCollector stop connect and port");
        unsafe {
            if self.port != INVALID_HANDLE_VALUE as HANDLE {
                CloseHandle(self.port as HANDLE);
            }
            if self.completion != INVALID_HANDLE_VALUE as HANDLE {
                CloseHandle(self.completion as HANDLE);
            }
        }
        if self.work_threads.len() > 0 {
            for handle in self.work_threads.drain(..) {
                handle.join().unwrap();
            }
            trace!("DataCollector wait all work threads exit");
        }
        true
    }

    /// `is_connected` 
    pub fn is_connected(&self) -> bool {
        return self.connected.load(Ordering::Acquire);
    }

    /// `connect` 
    #[instrument]
    pub fn connect(&mut self) -> Option<DataCollectorStatus> {
        let mut port: HANDLE = INVALID_HANDLE_VALUE as HANDLE;
       
        let connection_context = WideCString::from_str("xxx")
            .expect("Failed to convert string to wide null-terminated string");
        let port_name = WideCString::from_str("\\xxx")
            .expect("Failed to convert string to wide null-terminated string");
        
        #[allow(unused_assignments)]
        let mut h_result: HRESULT = S_OK;
        unsafe {
            h_result = FilterConnectCommunicationPort(
                port_name.as_ptr() as PCWSTR,
                0,
                connection_context.as_ptr() as *const c_void,
                (connection_context.len() * std::mem::size_of::<WideChar>()) as u16,
                std::ptr::null(),
                &mut port,
            );

            if h_result != S_OK {
                error!(
                    "Connect FilterConnectCommunicationPort fail! error code = 0x{:x}",
                    h_result
                );
                return Some(DataCollectorStatus::DataCollectorConnectFail);
            }
        }
        trace!("connect to mini filter port success");
        self.port = port as HANDLE;
        self.connected.store(true, Ordering::Release);

        Some(DataCollectorStatus::DataCollectorConnectSuccess)
    }

    /// `listen`
    #[instrument]
    pub fn listen(&mut self) -> bool {
        unsafe {
            let mut sys_info_uninit = MaybeUninit::<SYSTEM_INFO>::uninit();
            std::ptr::write(sys_info_uninit.as_mut_ptr(), std::mem::zeroed());
            let mut sys_info = sys_info_uninit.assume_init();

            GetSystemInfo(&mut sys_info);
            self.thread_count = (sys_info.dwNumberOfProcessors * 2).min(16);
            self.completion = CreateIoCompletionPort(self.port, 0, 0, self.thread_count);

            // CreateIoCompletionPort returns NULL (0), not INVALID_HANDLE_VALUE, on failure.
            if self.completion == 0 {
                error!(
                    "CreateIoCompletionPort fail! error code = {}",
                    GetLastError()
                );
                return false;
            }
        }
        unsafe {
            self.run();
        };
        true
    }

    /// `run` 
    unsafe fn run(&mut self) {
        for _ in 0..self.thread_count {
            let self_weak = self.self_ref.clone().unwrap();
            let handle = std::thread::spawn(move || {
                if let Some(data_collect_arc) = self_weak.upgrade() {
                    let mut data_collect = data_collect_arc.lock().unwrap();
                    data_collect.work_thread();
                }
            });

            let mut message_unit = MaybeUninit::<DataCollectorEventMessage>::uninit();
            std::ptr::write(message_unit.as_mut_ptr(), std::mem::zeroed());
            let mut message = message_unit.assume_init();
            #[allow(unused)]
            let mut hr: HRESULT = S_OK;
            hr = FilterGetMessage(
                self.port,
                &mut message.message_header as *mut FILTER_MESSAGE_HEADER,
                (std::mem::size_of::<FILTER_MESSAGE_HEADER>()
                    + std::mem::size_of::<DataCollectorEvent>()) as u32,
                &mut message.ovlp as *mut OVERLAPPED,
            );
            if hr != hresult_from_win32(ERROR_IO_PENDING) {
                error!("DataCollector FilterGetMessage error,it is not ERROR_IO_PENDING, h_result:0x{:x}",hr);
                return;
            }
            self.work_threads.push(handle);
        }
    }

    /// `work_thread` 
    pub unsafe fn work_thread(&mut self) {
        #[allow(unused_assignments)]
        let mut hr: HRESULT = S_OK;
        let mut message_unit = MaybeUninit::<DataCollectorEventMessage>::uninit();
        std::ptr::write(message_unit.as_mut_ptr(), std::mem::zeroed());
        let mut message = message_unit.assume_init();

        let mut ovlp_inner_unit = MaybeUninit::<OVERLAPPED>::uninit();
        std::ptr::write(ovlp_inner_unit.as_mut_ptr(), std::mem::zeroed());
        let mut ovlp_inner = ovlp_inner_unit.assume_init();
        let mut lp_ovlp_inner: *mut OVERLAPPED = &mut ovlp_inner as *mut OVERLAPPED;
        let ovlp_addr = std::ptr::addr_of_mut!(lp_ovlp_inner);

        #[allow(unused_assignments)]
        let mut result: BOOL = FALSE;
        let mut out_size = 0u32;
        #[allow(unused_mut)]
        let mut key: usize = 0;

        info!("DataCollector work thread listen");

        loop {
            result = GetQueuedCompletionStatus(
                self.completion,
                std::ptr::addr_of_mut!(out_size),
                std::ptr::addr_of_mut!(key),
                ovlp_addr,
                INFINITE,
            );
            if result != TRUE {
                hr = hresult_from_win32(GetLastError());
                error!(
                    "DataCollector iocp work thread exited with error code : 0x{0:x}",
                    hr as u32
                );
                return;
            } else {
                info!("GetQueuedCompletionStatus success");
            }
            let raw_message = lp_ovlp_inner as *const DataCollectorEventMessage;

            let message_borrowed: &DataCollectorEventMessage = &*raw_message;
            Self::dispatch(message_borrowed);
            hr = FilterGetMessage(
                self.port,
                &mut message.message_header as *mut FILTER_MESSAGE_HEADER,
                (std::mem::size_of::<FILTER_MESSAGE_HEADER>()
                    + std::mem::size_of::<DataCollectorEvent>()) as u32,
                &mut message.ovlp as *mut OVERLAPPED,
            );
            if hr != hresult_from_win32(ERROR_IO_PENDING) {
                error!("DataCollector WorkThread FilterGetMessage error,is not ERROR_IO_PENDING, h_result:0x{:x}",hr);
                return;
            }
        }
        // maybe todo exit work thread
    }

    // `dispatch` 
    fn dispatch(data: &DataCollectorEventMessage) {
        let act_type = &data.body.action_type;
        if act_type == &ActionType::ActionRun {
            unsafe { Self::receive_app_event(data) };
        }

        if act_type == &ActionType::ActionCreate
            || act_type == &ActionType::ActionWrite
            || act_type == &ActionType::ActionMove
            || act_type == &ActionType::ActionDelete
        {
            unsafe { Self::receive_file_io_event(data) };
        }
    }

    /// `receive_app_event` 
    unsafe fn receive_app_event(data: &DataCollectorEventMessage) {
        let process_name: WideCString = WideCString::from_vec_unchecked(
            data.body
                .info
                .exe_info
                .app_path
                .iter()
                .copied()
                .take_while(|&c| c != 0)
                .collect::<Vec<u16>>(),
        );
        let command_line: WideCString = WideCString::from_vec_unchecked(
            data.body
                .info
                .exe_info
                .cmd_line
                .iter()
                .copied()
                .take_while(|&c| c != 0)
                .collect::<Vec<u16>>(),
        );
        info!(
            "app : {:?}, parent pid : {}, cmdline : {:?}",
            process_name, data.body.parent_pid, command_line
        );
        // TODO datamanager
    }

    /// `receive_file_io_event`
    unsafe fn receive_file_io_event(data: &DataCollectorEventMessage) {
        let pid = data.body.pid;
        let io_path: WideCString = WideCString::from_vec_unchecked(
            data.body
                .info
                .io_info
                .path
                .iter()
                .copied()
                .take_while(|&c| c != 0)
                .collect::<Vec<u16>>(),
        );
        info!(
            "pid : {:?}, parent pid : {}, io_path : {:?}",
            pid, data.body.parent_pid, io_path
        );
    }
}

// mod.rs
pub mod data_collect;
use crate::common::ActionType;
use crate::common::DataCollectorCmd;
use widestring::WideChar;

use windows_sys::Win32::Storage::InstallableFileSystems::FILTER_MESSAGE_HEADER;
use windows_sys::Win32::System::IO::OVERLAPPED;

const NOAH_DATA_DETECT_MAX_PATH: usize = 1024;

#[repr(C)]
pub struct DataCollectorRequest {
    pub version: i32,
    pub msg_id: i32,
    pub cmd: DataCollectorCmd,
}

#[derive(Clone, Copy)]
#[repr(C)]
pub struct DataCollectorEvent {
    pub action_type: ActionType,
    pub pid: u32,
    pub parent_pid: u32,
    pub info: InfoUnion,
}

#[derive(Clone, Copy)]
#[repr(C)]
pub struct DataCollectorEventMessage {
    pub ovlp: OVERLAPPED,
    pub message_header: FILTER_MESSAGE_HEADER,
    pub body: DataCollectorEvent,
}

unsafe impl Send for DataCollectorEventMessage {}
unsafe impl Sync for DataCollectorEventMessage {}

#[derive(Clone, Copy)]
#[repr(C)]
pub union InfoUnion {
    pub io_info: IoInfo,
    pub exe_info: ExeInfo,
}
#[derive(Clone, Copy)]
#[repr(C)]
pub struct IoInfo {
    pub path: [WideChar; NOAH_DATA_DETECT_MAX_PATH],
    pub dest_path: [WideChar; NOAH_DATA_DETECT_MAX_PATH],
}

#[derive(Clone, Copy)]
#[repr(C)]
pub struct ExeInfo {
    pub app_path: [WideChar; NOAH_DATA_DETECT_MAX_PATH],
    pub cmd_line: [WideChar; NOAH_DATA_DETECT_MAX_PATH],
}

I just want to remove the self-referencing.

You can make the code a bit prettier by wrapping up the unsafe in nicer, safe interfaces, using Drop to automatically clean up resources, etc.; and using windows instead of windows-sys to get nicer ergonomics on arguments and results. Otherwise, that's basically just what FFI looks like!
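
As a minimal sketch of the Drop part: `RawHandle`, `INVALID_RAW_HANDLE` and `close_raw_handle` below are stand-ins I invented so the example builds on any platform. In your code they would correspond to `HANDLE`, `INVALID_HANDLE_VALUE` and `CloseHandle`, and `OwnedHandle` is just a name I picked.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

/// Stand-ins so the sketch builds anywhere; with `windows-sys` these would be
/// `HANDLE`, `INVALID_HANDLE_VALUE` and `CloseHandle`.
type RawHandle = isize;
const INVALID_RAW_HANDLE: RawHandle = -1;
static CLOSE_CALLS: AtomicUsize = AtomicUsize::new(0);

/// Hypothetical replacement for the real `CloseHandle` call; it only counts.
fn close_raw_handle(_handle: RawHandle) {
    CLOSE_CALLS.fetch_add(1, Ordering::SeqCst);
}

/// Owns one raw handle and closes it exactly once, when dropped.
struct OwnedHandle(RawHandle);

impl OwnedHandle {
    /// Returns `None` for the invalid sentinel so callers never hold a bad handle.
    fn new(raw: RawHandle) -> Option<Self> {
        if raw == INVALID_RAW_HANDLE {
            None
        } else {
            Some(Self(raw))
        }
    }

    /// Borrows the raw value for passing to FFI calls.
    fn raw(&self) -> RawHandle {
        self.0
    }
}

impl Drop for OwnedHandle {
    fn drop(&mut self) {
        close_raw_handle(self.0);
    }
}

/// Returns how many times the stub close function has run so far.
fn close_calls() -> usize {
    CLOSE_CALLS.load(Ordering::SeqCst)
}
```

With a wrapper like this, `stop` no longer needs the manual `INVALID_HANDLE_VALUE` checks: dropping the field (or the whole struct) closes the handle.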

Direct access to IOCP isn't well supported right now in the library ecosystem (there are well-established libraries that use it internally, like mio, and there are early libraries like tokio-iocp that might be cool in the future).

As for how to self-reference: the short answer is "don't"!

In your current code, you essentially fire and forget the worker threads with only the IOCP handle and the port id being accessed, and no response being fed back to the DataCollector (yet).

Instead of needing to pass in an Arc to the whole thing, create a worker type with just what the thread needs, which you can cheaply clone and move into a thread. (Possibly using Arc, but I don't think you need it here.)
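
Something along these lines, as a sketch only. `Handle` stands in for the integer `HANDLE` from windows-sys, `Worker` and `join_all` are names I made up, and `work_thread` just returns its handles where the real one would run the `GetQueuedCompletionStatus` loop:

```rust
use std::thread::{self, JoinHandle};

/// Stand-in for `windows_sys::Win32::Foundation::HANDLE`, which is used as a
/// plain integer in the posted code.
type Handle = isize;

/// Only what a worker thread needs. Both fields are `Copy`, so cloning is free.
#[derive(Clone, Copy, Debug)]
struct Worker {
    port: Handle,
    completion: Handle,
}

impl Worker {
    /// Placeholder for the real completion-port loop.
    fn work_thread(self) -> (Handle, Handle) {
        (self.port, self.completion)
    }
}

struct DataCollector {
    port: Handle,
    completion: Handle,
    thread_count: u32,
    work_threads: Vec<JoinHandle<(Handle, Handle)>>,
}

impl DataCollector {
    /// No `Arc`, `Weak`, `Mutex` or `self_ref`: each thread gets its own copy
    /// of the small `Worker` value.
    fn run(&mut self) {
        let worker = Worker { port: self.port, completion: self.completion };
        for _ in 0..self.thread_count {
            self.work_threads.push(thread::spawn(move || worker.work_thread()));
        }
    }

    /// Waits for every worker and returns what each one reported.
    fn join_all(&mut self) -> Vec<(Handle, Handle)> {
        self.work_threads
            .drain(..)
            .map(|handle| handle.join().expect("worker thread panicked"))
            .collect()
    }
}
```

With this shape `DataCollector::new` can return a plain `Self`, and the `self_ref` field disappears, because the threads never need to reach back into the collector.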

When you need to get the results back from the worker, the simple options from the standard library are

  • return it from the thread, which you can then get out of the join handle,
  • include a std::sync::mpsc Sender in the worker data, which each worker can send data to, and keep the Receiver in the main thread, which you can read in a loop.
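
For the channel option, a small sketch might look like the following. `Event`, `Worker` and `gather` are hypothetical names and the two events per worker are fake; in the real code each worker would send whatever `dispatch` decodes:

```rust
use std::sync::mpsc::{self, Sender};
use std::thread;

/// Hypothetical event type; the real code would send whatever `dispatch` decodes.
#[derive(Debug, PartialEq)]
struct Event {
    worker_id: usize,
    seq: u32,
}

/// Per-thread data: an id plus this worker's own clone of the `Sender`.
struct Worker {
    id: usize,
    tx: Sender<Event>,
}

impl Worker {
    fn run(self) {
        for seq in 0..2 {
            // `send` fails only if the receiver was dropped; stop in that case.
            if self.tx.send(Event { worker_id: self.id, seq }).is_err() {
                return;
            }
        }
        // `self.tx` is dropped here, which lets the receiver loop end.
    }
}

/// Spawns the workers and drains the channel on the calling thread.
fn gather(thread_count: usize) -> Vec<Event> {
    let (tx, rx) = mpsc::channel();
    let handles: Vec<_> = (0..thread_count)
        .map(|id| {
            let worker = Worker { id, tx: tx.clone() };
            thread::spawn(move || worker.run())
        })
        .collect();
    drop(tx); // keep only the workers' clones alive

    // Iteration ends once every `Sender` has been dropped.
    let events: Vec<Event> = rx.iter().collect();
    for handle in handles {
        handle.join().expect("worker thread panicked");
    }
    events
}
```

Events from different workers arrive interleaved in no fixed order, so sort or tag them if the order matters. For long-running workers you would read `rx` in a loop while they are still running instead of collecting everything at the end.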

There are also a bunch of more efficient libraries out there if you are handling higher volumes of data, so make sure to make good use of lib.rs!


I understood that, but I had the feeling you were doing something wrong based on the code you posted, so I needed more info. Thanks for posting the full code.

One of the things I mentioned was that you should probably be using a separate struct for the worker threads, as @simonbuchan also mentioned. This seems clear now from the code you posted, and I think the answer from @simonbuchan is correct.