This is a follow-up to the question: Secure arbitrary code execution PyO3
Long story short: I have an app where a user can upload Python scripts containing trading algorithms that trade on the Binance API based on candlesticks. Initially I executed these Python scripts using PyO3 inside my main application. But because this is arbitrary code, I want to secure it at the OS level (e.g. restrict network access), which is not possible if my main application is the process executing it.
So I have made a separate binary for executing the Python scripts. I use IPC to send the data from my main app to the PyExecutor.
When the algorithm is started by the user, first some historic candlesticks are retrieved. I pass these to the PyExecutor using shared memory.
Next, my main app receives the most recent candlesticks from the Binance API over a websocket. Each time the websocket sends data, I forward that data to my PyExecutor over a Unix socket.
Code in main-app when algorithm is started:
pub async fn start(self, psql: Psql, datastream: Arc<Mutex<mpsc::Receiver<CandleStick>>>, api: Api, ws_send: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>) -> Result<std::process::Child, tradealgorithm::Error> {
// Variable to save initial prepended kline data. We serialize this later
// to share it through shared memory.
let mut data = std::vec::Vec::<CandleStick>::new();
// Prepend data if required.
if self.prepend_data > 0 {
let endtime = std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_millis();
let starttime = endtime - self.prepend_data as u128;
// Create request parameters.
let mut params = std::collections::HashMap::<String, String>::new();
params.insert("symbol".into(), "BTCUSDT".into());
params.insert("interval".into(), self.interval.to_string());
params.insert("startTime".into(), starttime.to_string());
params.insert("endTime".into(), endtime.to_string());
// Execute request.
match api.klines(&mut params).await {
Ok(d) => {
for kl in d {
data.push(CandleStick {
timestamp: serde_json::from_value::<u64>(kl[0].clone())?,
open: serde_json::from_value::<String>(kl[1].clone())?.parse::<f64>()?,
high: serde_json::from_value::<String>(kl[2].clone())?.parse::<f64>()?,
low: serde_json::from_value::<String>(kl[3].clone())?.parse::<f64>()?,
close: serde_json::from_value::<String>(kl[4].clone())?.parse::<f64>()?,
volume: serde_json::from_value::<String>(kl[5].clone())?.parse::<f64>()?,
});
}
},
Err(_) => {
return Err(tradealgorithm::Error::APIError("Could not prepend data".into()));
}
};
}
// Write prepended data to shared memory.
let shmem_path = format!("tmp/shmem/{}.bin", self.id);
std::fs::remove_file(shmem_path).unwrap_or_default();
let serialized_data = serde_json::to_string(&data).expect("Failed to serialize data.");
let data_size = serialized_data.len();
let memfile = std::fs::OpenOptions::new()
.read(true)
.write(true)
.create(true)
.open(format!("tmp/shmem/{}.bin", self.id))
.expect("Failed to open file");
memfile.set_len(data_size as u64)
.expect("Failed to set file size");
let mut mapped_data = unsafe {
MmapOptions::new()
.len(data_size)
.map_mut(&memfile)
.expect("Failed to map file into memory")
};
mapped_data[..data_size].copy_from_slice(serialized_data.as_bytes());
// Create a new process to execute algorithm.
let process_handle = std::process::Command::new("target/debug/python-executor")
.args(&[self.id.to_string()])
.spawn()?;
// Create a stream so we can write data to the PyExecutor and receive the
// result back.
let unix_socket_path = &*format!("tmp/sockets/{}.sock", self.id);
// Wait 3 seconds for socket connection.
for i in 0..3 {
if std::fs::metadata(unix_socket_path).is_ok() {
break;
}
if i >= 2 {
return Err(tradealgorithm::Error::StreamError("Could not connect to stream.".into()));
}
tokio::time::sleep(tokio::time::Duration::from_millis(1000)).await;
}
// Connect to UnixSocket and split into receiver and transmitter so we can
// send data from the API to PyExecutor and receive the result.
let unix_stream = UnixStream::connect(unix_socket_path).await?;
let (mut rx, mut tx) = unix_stream.into_split();
// Thread to receive data from API (websocket) and send it to UnixSocket.
tokio::spawn(async move {
while let Some(n) = datastream.lock().await.recv().await {
let ready = tx.ready(Interest::READABLE | Interest::WRITABLE).await.expect("Could not check if UnixSocket is ready to write.");
if ready.is_writable() {
match tx.write_all(serde_json::json!(n).to_string().as_bytes()).await {
Ok(_) => (),
Err(e) => {
if e.kind() != std::io::ErrorKind::BrokenPipe {
eprintln!("Failed to write: {}", e);
}
panic!("Error writing to stream");
}
}
}
}
});
// Thread to read result from PyExecutor.
tokio::spawn(async move {
loop {
let mut buffer = [0; 1024];
match rx.read(&mut buffer).await {
Ok(0) => {
break;
},
Ok(n) => {
// Data structure we expect receive.
#[derive(Deserialize)]
struct Data {
result: f64,
last_candlestick: CandleStick,
}
let received_data = String::from_utf8_lossy(&buffer[..n]);
let data : Data = serde_json::from_str(&received_data.to_string()).expect("Failed to deserialze");
// Process result and execute order if necessary.
match self.process(psql.clone(), (data.result, data.last_candlestick.close), api.clone(), ws_send.clone()).await {
Ok(_) => (),
Err(e) => {
eprintln!("Error processing {}", e);
}
}
},
Err(e) => {
eprintln!("Error {}", e);
}
}
}
});
Ok(process_handle)
}
PyExecutor (separate binary receiving data from main and sending result of Python script back):
use pyo3::prelude::*;
use serde::{Serialize, Deserialize};
use memmap2:: MmapOptions;
use tokio::net::UnixListener;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use pyo3::types::PyList;
// One candlestick (kline). It is deserialized from the JSON the main app
// provides (shared-memory file and Unix socket), serialized back as part of the
// result message, and handed to the Python script as a `#[pyclass]` object.
// Plain `//` comments are used on purpose: `///` doc comments on a `#[pyclass]`
// would be exported to Python as docstrings.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[pyclass]
pub struct CandleStick {
// Exposed to Python as the readable/writable attribute `t`.
// NOTE(review): presumably the kline open time in milliseconds since the Unix
// epoch — confirm against the main app.
#[pyo3(get, set, name = "t")]
pub timestamp: u64,
// Opening price; exposed to Python as `o`.
#[pyo3(get, set, name = "o")]
pub open: f64,
// Closing price; exposed to Python as `c`.
#[pyo3(get, set, name = "c")]
pub close: f64,
// Highest price; exposed to Python as `h`.
#[pyo3(get, set, name = "h")]
pub high: f64,
// Lowest price; exposed to Python as `l`.
#[pyo3(get, set, name = "l")]
pub low: f64,
// Traded volume; exposed to Python as `v`.
#[pyo3(get, set, name = "v")]
pub volume: f64,
}
#[tokio::main]
async fn main() -> Result<(), Error> {
let args: std::vec::Vec<String> = std::env::args().collect();
let algorithm_id = args[1].to_string();
// Retrieve data from shared memory.
let shmem_path = &*format!("tmp/shmem/{}.bin", algorithm_id);
let memfile = std::fs::File::open(shmem_path).expect("Failed to open memfile.");
let memfile_metadata = memfile.metadata().expect("Failed to get metadata memfile.");
let mapped_data = unsafe {
MmapOptions::new()
.len(memfile_metadata.len() as usize)
.map(&memfile)
.expect("Failed to map file to memory.")
};
let serialized_data = &mapped_data[..memfile_metadata.len() as usize];
// Create vec of klines from data in shared memory.
let data : std::vec::Vec<CandleStick> = serde_json::from_slice(serialized_data).expect("Failed to deserialize.");
// Create UnixSocket to receive data from algorithm and to send result back.
let unix_socket_path = &*format!("tmp/sockets/{}.sock", algorithm_id);
std::fs::remove_file(unix_socket_path).unwrap_or_default();
let listener = match UnixListener::bind(unix_socket_path) {
Ok(lis) => lis,
Err(e) => {
return Err(Error::StreamError(format!("Could not connect to UnixSocket: {}", e)));
}
};
// Accept connection.
while let Ok((mut stream, _)) = listener.accept().await {
let mut data_clone = data.clone();
let algorithm_id_clone = algorithm_id.clone();
// Don't close connection. Keep reading.
tokio::spawn(async move {
loop {
let mut buffer = [0; 1024];
match stream.read(&mut buffer).await {
Ok(0) => {
break;
},
Ok(n) => {
// Data structure we are sending back.
#[derive(Serialize)]
struct Data {
result: f64,
last_candlestick: CandleStick,
}
// kline received from UnixSocket.
let received_data = String::from_utf8_lossy(&buffer[..n]);
let candlestick : CandleStick = match serde_json::from_str(&received_data) {
Ok(c) => c,
Err(_) => {
continue;
}
};
// Add received kline to all data.
data_clone.push(candlestick.clone());
// Execute the Python code with the given data.
let result = match execute(data_clone.clone(), algorithm_id_clone.to_string()).await {
Ok(r) => r,
Err(e) => {
panic!("Could not execute PythonCode: {}", e);
}
};
// Write result of Python code back to UnixSocket.
stream.write_all(serde_json::json!(Data {
result: result,
last_candlestick: candlestick,
}).to_string().as_bytes()).await.expect("Failed to send");
},
Err(_) => {
break;
}
}
}
});
}
Ok(())
}
// Execute the Python code.
async fn execute(data: std::vec::Vec<CandleStick>, python_file: String) -> Result<f64, String> {
pyo3::prepare_freethreaded_python();
// Retrieve Python code from file.
let mut python_code = match std::fs::read_to_string(format!("trading_algos/{}.py", python_file)) {
Ok(code) => code,
Err(e) => {
return Err("Error".into());
}
};
// Check if Python code contains blacklisted keywords.
if !arbitrary_code_is_secure(&python_code) {
return Err("Error".into());
}
// Import allowed libraries into Python code.
import_allowed_libraries(&mut python_code);
// Call Python function func.
let result = Python::with_gil(|py| {
// Convert data of klines to a PyList so it can be passed to Python code.
let py_candlesticks: Vec<Py<CandleStick>> = data.into_iter().map(|candlestick| {
Py::new(py, candlestick.clone()).unwrap()
}).collect();
let py_list = PyList::new(py, py_candlesticks);
// Create PyModule.
let fun: Py<PyAny> = match PyModule::from_code(
py,
&*python_code,
"",
"",
) {
Ok(f) => {
match f.getattr("func") {
Ok(a) => a.into(),
Err(e) => {
return Err("Error".to_string());
}
}
},
Err(e) => {
return Err("Error".into());
}
};
// Call Python function.
let f = match fun.call1(py, (py_list,)) {
Ok(f) => f,
Err(e) => {
return Err("Error".into());
}
};
let result = match f.extract::<f64>(py) {
Ok(r) => r,
Err(e) => {
return Err("Error".into());
}
};
Ok(result)
});
let result = match result {
Ok(r) => r,
Err(e) => {
return Err("Error".into());
}
};
Ok(result)
}
/// Prepends an `import` line for every allowed library to `code`.
///
/// The lines end up in reverse list order (`pandas`, `numpy`, `math`), each
/// terminated by a newline, directly in front of the original script text.
fn import_allowed_libraries(code: &mut String) {
    const ALLOWED_LIBRARIES: [&str; 3] = ["math", "numpy", "pandas"];

    // Build the whole header first, then splice it in with a single insert.
    let header: String = ALLOWED_LIBRARIES
        .iter()
        .rev()
        .map(|library| format!("import {}\n", library))
        .collect();
    code.insert_str(0, &header);
}
/// Returns `true` when `code` contains none of the blacklisted keywords.
///
/// We execute arbitrary Python scripts. Even if this is supposed to be an
/// internal application we need to secure this somehow, so specific keywords
/// such as "import", "exec", "os", ... are rejected. Popular mathematical
/// libraries are imported separately so those can still be used.
///
/// Lines consisting only of a `#` comment are removed before the check, unless
/// the line contains a single quote, in which case it is kept and checked like
/// any other line.
///
/// NOTE(review): this is a plain substring match, so it also rejects harmless
/// identifiers that merely contain a keyword (e.g. "os" inside "position"), and
/// a textual filter should not be relied on as the only line of defence.
fn arbitrary_code_is_secure(code: &str) -> bool {
    const BLACKLIST: &[&str] = &[
        "import",
        "open",
        "read",
        "write",
        "file",
        "os",
        "exec",
        "eval",
        "socket",
        "http",
        "requests",
        "urllib",
        "sys",
        "traceback",
        "__",
    ];

    let re_find_commented_lines = regex::RegexBuilder::new(r"^\s*#.*\n?")
        .multi_line(true)
        .build()
        .expect("the comment-line pattern is a valid regex");

    let code_without_comments = re_find_commented_lines.replace_all(code, |caps: &regex::Captures| {
        let line = &caps[0];
        if line.contains('\'') {
            line.to_string()
        } else {
            String::new()
        }
    });

    BLACKLIST
        .iter()
        .all(|keyword| !code_without_comments.contains(*keyword))
}
/// Error type for PyExecutor.
///
/// Every variant carries a human-readable message; the `Display` impl prefixes
/// it with the variant name.
#[derive(Debug)]
pub enum Error {
/// General executor failure; `std::io::Error` values convert into this variant.
PyExecutorError(String),
/// Parsing failure; `serde_json::Error` and `std::num::ParseFloatError`
/// values convert into this variant.
ParseError(String),
/// Failure attributed to the Python code. Not constructed anywhere in the
/// code shown here.
PythonCodeError(String),
/// Socket failure, e.g. when the Unix socket cannot be bound in `main`.
StreamError(String),
}
// No methods overridden: the trait's defaults are used, on top of the `Debug`
// derive above and the `Display` impl.
impl std::error::Error for Error {}
impl From<serde_json::Error> for Error {
fn from(e: serde_json::Error) -> Self {
Error::ParseError(e.to_string())
}
}
impl From<std::io::Error> for Error {
fn from(e: std::io::Error) -> Self {
Error::PyExecutorError(e.to_string())
}
}
impl From<std::num::ParseFloatError> for Error {
fn from(e: std::num::ParseFloatError) -> Self {
Error::ParseError(format!("line: {} {}", line!(), e.to_string()))
}
}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
Error::PyExecutorError(error_msg) => write!(f, "\x1b[31m[Error] PyExecuto - PyExecutorError: {}\x1b[0m", error_msg),
Error::ParseError(error_msg) => write!(f, "\x1b[31m[Error] PyExecutor - ParseError: {}\x1b[0m", error_msg),
Error::PythonCodeError(error_msg) => write!(f, "\x1b[31m[Error] PyExecutor - PythonCodeError: {}\x1b[0m", error_msg),
Error::StreamError(error_msg) => write!(f, "\x1b[31m[Error] PyExecutor - StreamError: {}\x1b[0m", error_msg),
}
}
}
This code seems to work, but I was wondering whether there are any improvements I could make.