I want to save the response to the database after streaming the entire response to the client. If there's an error while saving the data, I want the server to return a 500 error. I understand that changing the HTTP response after sent is not possible, but I don't want the streaming to complete before inserting the message into the database. I'm looking for a simple way to achieve this, as I'm still learning and experimenting.
The current code saves to the database before sending the response. If I use actix_web::rt::spawn, it sends the first chunk but then waits until the database save operation is complete before continuing with the rest of the response.
pub async fn groq_chat<'a>(message: &str, db_conn: &DatabaseConnection) -> HttpResponse {
let groq_api = env::var("GROQ_API").expect("GROQ_API must be set");
let client = Client::new();
let start_time = Instant::now();
let mut time_taken = String::new();
match client
.post("https://api.groq.com/openai/v1/chat/completions")
.bearer_auth(groq_api)
.header("Content-Type", "application/json")
.json(&json!({
"messages": [
{"role": "user", "content": message}
],
"model": "mixtral-8x7b-32768",
"stream": true
}))
.send()
.await
{
Ok(res) if res.status().is_success() => {
let mut content_buffer = String::new();
let mut stream = res.bytes_stream();
let mut accumulator = String::new();
let response_stream = stream! {
while let Some(item) = stream.next().await {
match item {
Ok(bytes) => {
let s = match std::str::from_utf8(&bytes) {
Ok(v) => v,
Err(e) => {
eprintln!("Invalid UTF-8 sequence: {}", e);
continue;
}
};
accumulator.push_str(s);
while let Some(pos) = accumulator.find("\n\n") {
let test = accumulator.clone();
let (chunk, rest) = test.split_at(pos);
accumulator = rest.trim_start_matches("\n\n").to_string();
if let Some(json_str) = chunk.strip_prefix("data: ") {
if json_str == "[DONE]" {
break;
}
match serde_json::from_str::<GroqChatResponse>(json_str) {
Ok(chat_response) => {
for choice in chat_response.choices {
if let Some(content) = choice.delta.content {
content_buffer.push_str(&content);
}
}
yield Ok(Bytes::from(content_buffer.clone()));
content_buffer.clear();
}
Err(e) => {
eprintln!("Failed to parse JSON: {}", e);
}
}
}
}
}
Err(e) => {
eprintln!("Stream error: {:?}", e);
yield Err(actix_web::error::ErrorInternalServerError(e));
}
}
}
// Calculate total time taken
let elapsed_time = start_time.elapsed();
let seconds = elapsed_time.as_secs_f64();
let milliseconds = seconds * 1000.0;
// Choose which time info to send
let time_info = if seconds >= 1.0 {
format!("\n\nTotal Time Taken: {:.2} sec", seconds)
} else {
format!("\n\nTotal Time Taken: {:.2} ms", milliseconds)
};
time_taken = time_info.clone();
yield Ok(Bytes::from(time_info));
};
// actix_web::rt::spawn(async move {
// if let Err(e) = db_conn.transaction::<_, (), DbErr>(|txn| {
// Box::pin(async move {
// insert_messages(&txn, String::from("user_id"), String::from("Testing"), String::from("Test"), String::from("ttt")).await
// })
// }).await {
// eprintln!("Failed to complete transaction: {}", e);
// }
// });
let result = db_conn.transaction::<_, (), DbErr>(|txn| {
Box::pin(async move {
insert_messages(&txn, String::from("user_id"), String::from("Tesing"), String::from("Test"), String::from("T1123")).await
})
}).await;
match result {
Ok(_) => HttpResponse::Ok().content_type("text/plain").streaming(response_stream),
Err(e) => {
eprintln!("Failed to complete transaction: {}", e);
HttpResponse::InternalServerError().finish() // Return 500 if transaction fails
}
}
}
Ok(res) => HttpResponse::InternalServerError()
.body(format!("Request failed with status: {}", res.status())),
Err(e) => {
eprintln!("Request error: {}", e);
HttpResponse::InternalServerError().finish()
}
}
}