The code works now; I just want to optimize it. The point of the optimization is to remove the self-reference that the code uses to start its worker threads. Although I used `std::sync::Weak` to avoid a circular reference, the readability of the current code is horrible!
Here is the full code:
// datacollect.rs
use std::mem::MaybeUninit;
use std::os::raw::c_void;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, Weak};
use widestring::{WideCString, WideChar};
use crate::collect::{DataCollectorEvent, DataCollectorEventMessage};
use windows_sys::core::{HRESULT, PCWSTR};
use windows_sys::Win32::Foundation::{
CloseHandle, GetLastError, BOOL, ERROR_IO_PENDING, FALSE, HANDLE, INVALID_HANDLE_VALUE,
SEVERITY_ERROR, S_OK, TRUE,
};
use windows_sys::Win32::Storage::InstallableFileSystems::{
FilterConnectCommunicationPort, FilterGetMessage, FILTER_MESSAGE_HEADER,
};
use windows_sys::Win32::System::Diagnostics::Debug::FACILITY_WIN32;
use windows_sys::Win32::System::SystemInformation::{GetSystemInfo, SYSTEM_INFO};
use windows_sys::Win32::System::Threading::INFINITE;
use windows_sys::Win32::System::IO::{
CreateIoCompletionPort, GetQueuedCompletionStatus, OVERLAPPED,
};
use crate::common::{ActionType, DataCollectorStatus};
use tracing::{error, info, instrument, trace};
/// Rust port of the C++ `HRESULT_FROM_WIN32` macro.
///
/// A value that is zero or negative when reinterpreted as a signed 32-bit integer is
/// returned unchanged. Any other value keeps its low 16 bits and is combined with the
/// Win32 facility code and the error severity bit.
#[inline(always)]
fn hresult_from_win32(x: u32) -> HRESULT {
    if (x as i32) > 0 {
        let code = x & 0x0000FFFF;
        let facility = FACILITY_WIN32 << 16;
        let severity = SEVERITY_ERROR << 31;
        return (code | facility | severity) as HRESULT;
    }
    x as HRESULT
}
/// `DataCollector`
#[derive(Debug)]
pub struct DataCollector {
connected: AtomicBool,
port: HANDLE,
completion: HANDLE,
thread_count: u32,
work_threads: Vec<std::thread::JoinHandle<()>>,
self_ref: Option<Arc<Weak<Mutex<DataCollector>>>>,
}
impl DataCollector {
/// `new` return `Arc<Mutex<DataCollector>>`
pub fn new() -> Arc<Mutex<DataCollector>> {
let data_collector = Arc::new(Mutex::new(Self {
connected: AtomicBool::new(false),
port: INVALID_HANDLE_VALUE as HANDLE,
completion: INVALID_HANDLE_VALUE as HANDLE,
thread_count: 0,
work_threads: Vec::new(),
self_ref: None,
}));
let weak_data_collector = Arc::downgrade(&data_collector);
data_collector.lock().unwrap().self_ref = Some(weak_data_collector.into());
data_collector
}
/// `start`
pub fn start(&mut self) -> bool {
if self.is_connected() {
return true;
}
match self.connect() {
Some(DataCollectorStatus::DataCollectorConnectSuccess) => {
return self.listen();
}
Some(DataCollectorStatus::DataCollectorConnectFail) => {
error!("DataCollector start connect fail");
return false;
}
None => {
error!("DataCollector start connect fail: None");
return false;
}
}
}
/// `stop`
pub fn stop(&mut self) -> bool {
info!("DataCollector stop connect and port");
unsafe {
if self.port != INVALID_HANDLE_VALUE as HANDLE {
CloseHandle(self.port as HANDLE);
}
if self.completion != INVALID_HANDLE_VALUE as HANDLE {
CloseHandle(self.completion as HANDLE);
}
}
if self.work_threads.len() > 0 {
for handle in self.work_threads.drain(..) {
handle.join().unwrap();
}
trace!("DataCollector wait all work threads exit");
}
true
}
/// `is_connected`
pub fn is_connected(&self) -> bool {
return self.connected.load(Ordering::Acquire);
}
/// `connect`
#[instrument]
pub fn connect(&mut self) -> Option<DataCollectorStatus> {
let mut port: HANDLE = INVALID_HANDLE_VALUE as HANDLE;
let connection_context = WideCString::from_str("xxx")
.expect("Failed to convert string to wide null-terminated string");
let port_name = WideCString::from_str("\\xxx")
.expect("Failed to convert string to wide null-terminated string");
#[allow(unused_assignments)]
let mut h_result: HRESULT = S_OK;
unsafe {
h_result = FilterConnectCommunicationPort(
port_name.as_ptr() as PCWSTR,
0,
connection_context.as_ptr() as *const c_void,
(connection_context.len() * std::mem::size_of::<WideChar>()) as u16,
std::ptr::null(),
&mut port,
);
if h_result != S_OK {
error!(
"Connect FilterConnectCommunicationPort fail! error code = 0x{:x}",
h_result
);
return Some(DataCollectorStatus::DataCollectorConnectFail);
}
}
trace!("connect to mini filter port success");
self.port = port as HANDLE;
self.connected.store(true, Ordering::Release);
Some(DataCollectorStatus::DataCollectorConnectSuccess)
}
/// `listen`
#[instrument]
pub fn listen(&mut self) -> bool {
unsafe {
let mut sys_info_uninit = MaybeUninit::<SYSTEM_INFO>::uninit();
std::ptr::write(sys_info_uninit.as_mut_ptr(), std::mem::zeroed());
let mut sys_info = sys_info_uninit.assume_init();
GetSystemInfo(&mut sys_info);
self.thread_count = (sys_info.dwNumberOfProcessors * 2).min(16);
self.completion = CreateIoCompletionPort(self.port, 0, 0, self.thread_count);
if self.completion == INVALID_HANDLE_VALUE {
error!(
"CreateIoCompletionPort fail! error code = {}",
GetLastError()
);
return false;
}
}
unsafe {
self.run();
};
true
}
/// `run`
unsafe fn run(&mut self) {
for _ in 0..self.thread_count {
let self_weak = self.self_ref.clone().unwrap();
let handle = std::thread::spawn(move || {
if let Some(data_collect_arc) = self_weak.upgrade() {
let mut data_collect = data_collect_arc.lock().unwrap();
data_collect.work_thread();
}
});
let mut message_unit = MaybeUninit::<DataCollectorEventMessage>::uninit();
std::ptr::write(message_unit.as_mut_ptr(), std::mem::zeroed());
let mut message = message_unit.assume_init();
#[allow(unused)]
let mut hr: HRESULT = S_OK;
hr = FilterGetMessage(
self.port,
&mut message.message_header as *mut FILTER_MESSAGE_HEADER,
(std::mem::size_of::<FILTER_MESSAGE_HEADER>()
+ std::mem::size_of::<DataCollectorEvent>()) as u32,
&mut message.ovlp as *mut OVERLAPPED,
);
if hr != hresult_from_win32(ERROR_IO_PENDING) {
error!("DataCollector FilterGetMessage error,it is not ERROR_IO_PENDING, h_result:0x{:x}",hr);
return;
}
self.work_threads.push(handle);
}
}
/// `work_thread`
pub unsafe fn work_thread(&mut self) {
#[allow(unused_assignments)]
let mut hr: HRESULT = S_OK;
let mut message_unit = MaybeUninit::<DataCollectorEventMessage>::uninit();
std::ptr::write(message_unit.as_mut_ptr(), std::mem::zeroed());
let mut message = message_unit.assume_init();
let mut ovlp_inner_unit = MaybeUninit::<OVERLAPPED>::uninit();
std::ptr::write(ovlp_inner_unit.as_mut_ptr(), std::mem::zeroed());
let mut ovlp_inner = ovlp_inner_unit.assume_init();
let mut lp_ovlp_inner: *mut OVERLAPPED = &mut ovlp_inner as *mut OVERLAPPED;
let ovlp_addr = std::ptr::addr_of_mut!(lp_ovlp_inner);
#[allow(unused_assignments)]
let mut result: BOOL = FALSE;
let mut out_size = 0u32;
#[allow(unused_mut)]
let mut key: usize = 0;
info!("DataCollector work thread listen");
loop {
result = GetQueuedCompletionStatus(
self.completion,
std::ptr::addr_of_mut!(out_size),
std::ptr::addr_of_mut!(key),
ovlp_addr,
INFINITE,
);
if result != TRUE {
hr = hresult_from_win32(GetLastError());
error!(
"DataCollector iocp work thread exited with error code : 0x{0:x}",
hr as u32
);
return;
} else {
info!("GetQueuedCompletionStatus success");
}
let raw_message = lp_ovlp_inner as *const DataCollectorEventMessage;
let message_borrowed: &DataCollectorEventMessage = &*raw_message;
Self::dispatch(message_borrowed);
hr = FilterGetMessage(
self.port,
&mut message.message_header as *mut FILTER_MESSAGE_HEADER,
(std::mem::size_of::<FILTER_MESSAGE_HEADER>()
+ std::mem::size_of::<DataCollectorEvent>()) as u32,
&mut message.ovlp as *mut OVERLAPPED,
);
if hr != hresult_from_win32(ERROR_IO_PENDING) {
error!("DataCollector WorkThread FilterGetMessage error,is not ERROR_IO_PENDING, h_result:0x{:x}",hr);
return;
}
}
// maybe todo exit work thread
}
// `dispatch`
fn dispatch(data: &DataCollectorEventMessage) {
let act_type = &data.body.action_type;
if act_type == &ActionType::ActionRun {
unsafe { Self::receive_app_event(data) };
}
if act_type == &ActionType::ActionCreate
|| act_type == &ActionType::ActionWrite
|| act_type == &ActionType::ActionMove
|| act_type == &ActionType::ActionDelete
{
unsafe { Self::receive_file_io_event(data) };
}
}
/// `receive_app_event`
unsafe fn receive_app_event(data: &DataCollectorEventMessage) {
let process_name: WideCString = WideCString::from_vec_unchecked(
data.body
.info
.exe_info
.app_path
.iter()
.copied()
.take_while(|&c| c != 0)
.collect::<Vec<u16>>(),
);
let command_line: WideCString = WideCString::from_vec_unchecked(
data.body
.info
.exe_info
.cmd_line
.iter()
.copied()
.take_while(|&c| c != 0)
.collect::<Vec<u16>>(),
);
info!(
"app : {:?}, parent pid : {}, cmdline : {:?}",
process_name, data.body.parent_pid, command_line
);
// TODO datamanager
}
/// `receive_file_io_event`
unsafe fn receive_file_io_event(data: &DataCollectorEventMessage) {
let pid = data.body.pid;
let io_path: WideCString = WideCString::from_vec_unchecked(
data.body
.info
.io_info
.path
.iter()
.copied()
.take_while(|&c| c != 0)
.collect::<Vec<u16>>(),
);
info!(
"pid : {:?}, parent pid : {}, io_path : {:?}",
pid, data.body.parent_pid, io_path
);
}
}
// mod.rs
pub mod data_collect;
use crate::common::ActionType;
use crate::common::DataCollectorCmd;
use widestring::WideChar;
use windows_sys::Win32::Storage::InstallableFileSystems::FILTER_MESSAGE_HEADER;
use windows_sys::Win32::System::IO::OVERLAPPED;
/// Length, in wide characters, of every fixed-size buffer in `IoInfo` and `ExeInfo`.
const NOAH_DATA_DETECT_MAX_PATH: usize = 1024;
/// Request record with C layout: a version, a message id and a command.
// NOTE(review): nothing in the code visible here constructs or sends this type;
// presumably it is the user-mode -> driver request - confirm against the driver.
#[repr(C)]
pub struct DataCollectorRequest {
pub version: i32,
pub msg_id: i32,
pub cmd: DataCollectorCmd,
}
/// Event payload that follows the filter message header in `DataCollectorEventMessage`.
#[derive(Clone, Copy)]
#[repr(C)]
pub struct DataCollectorEvent {
// Decides how the collector dispatches the event and which member of `info` it reads.
pub action_type: ActionType,
// Logged by the collector for file I/O events.
pub pid: u32,
// Logged by the collector for both application and file I/O events.
pub parent_pid: u32,
// Action-specific details; see `InfoUnion`.
pub info: InfoUnion,
}
/// One receive buffer for `FilterGetMessage`, together with the `OVERLAPPED` that tracks
/// the asynchronous request.
///
/// The field order is relied upon by the collector:
/// - `ovlp` comes first because the collector casts the `OVERLAPPED` pointer returned by
///   `GetQueuedCompletionStatus` back to a pointer to this struct;
/// - `message_header` is passed as the start of the receive buffer with a length of
///   header + `DataCollectorEvent`, so `body` has to follow it directly.
#[derive(Clone, Copy)]
#[repr(C)]
pub struct DataCollectorEventMessage {
pub ovlp: OVERLAPPED,
pub message_header: FILTER_MESSAGE_HEADER,
pub body: DataCollectorEvent,
}
// NOTE(review): these impls assert that a message may be moved to and shared between
// threads; nothing visible here shows why that holds for the embedded `OVERLAPPED` - confirm.
unsafe impl Send for DataCollectorEventMessage {}
unsafe impl Sync for DataCollectorEventMessage {}
/// Action-specific part of a `DataCollectorEvent`.
///
/// The collector reads `exe_info` for `ActionRun` events and `io_info` for
/// create/write/move/delete events; the union itself carries no tag.
#[derive(Clone, Copy)]
#[repr(C)]
pub union InfoUnion {
pub io_info: IoInfo,
pub exe_info: ExeInfo,
}
/// File I/O details: two fixed-size wide-character buffers.
#[derive(Clone, Copy)]
#[repr(C)]
pub struct IoInfo {
// Read by the collector up to the first NUL and logged as the I/O path.
pub path: [WideChar; NOAH_DATA_DETECT_MAX_PATH],
// NOTE(review): not read by the code visible here; presumably the target of a move - confirm.
pub dest_path: [WideChar; NOAH_DATA_DETECT_MAX_PATH],
}
/// Application-start details: two fixed-size wide-character buffers.
#[derive(Clone, Copy)]
#[repr(C)]
pub struct ExeInfo {
// Read by the collector up to the first NUL and logged as the application path.
pub app_path: [WideChar; NOAH_DATA_DETECT_MAX_PATH],
// Read by the collector up to the first NUL and logged as the command line.
pub cmd_line: [WideChar; NOAH_DATA_DETECT_MAX_PATH],
}
I just want to remove the self-referencing.