I want to save the response to the database after streaming the entire response to the client

I want to save the response to the database after streaming the entire response to the client. If there's an error while saving the data, I want the server to return a 500 error. I understand that changing the HTTP response after it has been sent is not possible, but I don't want the streaming to complete before inserting the message into the database. I'm looking for a simple way to achieve this, as I'm still learning and experimenting.

The current code saves to the database before sending the response. If I use actix_web::rt::spawn, it sends the first chunk but then waits until the database save operation is complete before continuing with the rest of the response.


pub async fn groq_chat(message: &str, db_conn: &DatabaseConnection) -> HttpResponse {
    let groq_api = env::var("GROQ_API").expect("GROQ_API must be set");
    let client = Client::new();
    let start_time = Instant::now();
    let mut time_taken = String::new();

    

    match client
        .post("https://api.groq.com/openai/v1/chat/completions")
        .bearer_auth(groq_api)
        .header("Content-Type", "application/json")
        .json(&json!({
            "messages": [
                {"role": "user", "content": message}
            ],
            "model": "mixtral-8x7b-32768",
            "stream": true
        }))
        .send()
        .await
    {
        Ok(res) if res.status().is_success() => {
            let mut content_buffer = String::new();
            let mut stream = res.bytes_stream();
            let mut accumulator = String::new();

            let response_stream = stream! {
                while let Some(item) = stream.next().await {
                    match item {
                        Ok(bytes) => {
                            let s = match std::str::from_utf8(&bytes) {
                                Ok(v) => v,
                                Err(e) => {
                                    eprintln!("Invalid UTF-8 sequence: {}", e);
                                    continue;
                                }
                            };

                            accumulator.push_str(s);

                            while let Some(pos) = accumulator.find("\n\n") {
                                let chunk = accumulator[..pos].to_string();
                                accumulator = accumulator[pos + 2..].to_string();

                                if let Some(json_str) = chunk.strip_prefix("data: ") {
                                    if json_str == "[DONE]" {
                                        break;
                                    }

                                    match serde_json::from_str::<GroqChatResponse>(json_str) {
                                        Ok(chat_response) => {
                                            for choice in chat_response.choices {
                                                if let Some(content) = choice.delta.content {
                                                    content_buffer.push_str(&content);
                                                }
                                            }
                                            yield Ok(Bytes::from(content_buffer.clone()));
                                            content_buffer.clear();
                                        }
                                        Err(e) => {
                                            eprintln!("Failed to parse JSON: {}", e);
                                        }
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            eprintln!("Stream error: {:?}", e);
                            yield Err(actix_web::error::ErrorInternalServerError(e));
                        }
                    }
                }
                // Calculate total time taken
                let elapsed_time = start_time.elapsed();
                let seconds = elapsed_time.as_secs_f64();
                let milliseconds = seconds * 1000.0;

                // Choose which time info to send
                let time_info = if seconds >= 1.0 {
                    format!("\n\nTotal Time Taken: {:.2} sec", seconds)
                } else {
                    format!("\n\nTotal Time Taken: {:.2} ms", milliseconds)
                };

                time_taken = time_info.clone();

                yield Ok(Bytes::from(time_info));
            };
          
            // actix_web::rt::spawn(async move {
            //     if let Err(e) = db_conn.transaction::<_, (), DbErr>(|txn| {
            //         Box::pin(async move {
            //             insert_messages(&txn, String::from("user_id"), String::from("Testing"), String::from("Test"), String::from("ttt")).await
            //         })
            //     }).await {
            //         eprintln!("Failed to complete transaction: {}", e);
            //     }
            // });

            let result = db_conn.transaction::<_, (), DbErr>(|txn| {
                Box::pin(async move {
                    insert_messages(&txn, String::from("user_id"), String::from("Testing"), String::from("Test"), String::from("T1123")).await
                })
            }).await;


            match result {
                Ok(_) => HttpResponse::Ok().content_type("text/plain").streaming(response_stream),
                Err(e) => {
                    eprintln!("Failed to complete transaction: {}", e);
                    HttpResponse::InternalServerError().finish() // Return 500 if transaction fails
                }
            }

        }
        Ok(res) => HttpResponse::InternalServerError()
            .body(format!("Request failed with status: {}", res.status())),
        Err(e) => {
            eprintln!("Request error: {}", e);
            HttpResponse::InternalServerError().finish()
        }
    }
}

I don't want the streaming to complete before inserting the message into the database.

Consider inserting the message into your database up front, and then streaming it to the client (either from memory or back out of the database) afterwards. If you need the client to receive an error status when saving the response fails, this is pretty much the only way to do it, as you need to know whether saving the response has succeeded or not before you send the status line, which in turn must be sent before the first byte of the response.

This is a protocol issue; there's nothing clever you can do on the server side to avoid it, so keep it simple.
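
A minimal sketch of that ordering, reusing the question's actix-web and sea-orm setup. Note that `fetch_full_completion` is a hypothetical helper (it would make the Groq request with `"stream": false` and return the whole reply as a `String`); everything else mirrors the question's code:

```rust
// Sketch only: `fetch_full_completion` is a hypothetical helper, not a real API.
pub async fn groq_chat(message: &str, db_conn: &DatabaseConnection) -> HttpResponse {
    // 1. Buffer the entire completion in memory; nothing is sent to the client yet.
    let full_response = match fetch_full_completion(message).await {
        Ok(text) => text,
        Err(e) => {
            eprintln!("Request error: {}", e);
            return HttpResponse::InternalServerError().finish();
        }
    };

    // 2. Save it. A failure here can still become a 500,
    //    because the status line has not been sent.
    let to_save = full_response.clone();
    let result = db_conn
        .transaction::<_, (), DbErr>(|txn| {
            Box::pin(async move {
                insert_messages(&txn, String::from("user_id"), String::from("Testing"), String::from("Test"), to_save).await
            })
        })
        .await;

    match result {
        // 3. Only now start the response, streaming the saved text from memory.
        Ok(_) => HttpResponse::Ok().content_type("text/plain").streaming(
            futures_util::stream::once(async move {
                Ok::<_, actix_web::Error>(Bytes::from(full_response))
            }),
        ),
        Err(e) => {
            eprintln!("Failed to complete transaction: {}", e);
            HttpResponse::InternalServerError().finish()
        }
    }
}
```

You lose the token-by-token streaming from Groq itself with this approach, but in exchange the client gets an honest status code.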

I've done something similar before, but I used axum's Sse and a tokio ReceiverStream (from the tokio_stream::wrappers module).

let open_ai_stream = openai.chat_stream(chat_id, &query).await?;
let (tx, rx) = mpsc::unbounded_channel::<ChatData>();

let receiver_stream = tokio_stream::wrappers::UnboundedReceiverStream::new(rx);

let out_sse = Sse::new(receiver_stream).keep_alive(
    KeepAlive::default()
        .text("keep-alive")
        .interval(Duration::from_millis(100)),
);

Ok(out_sse)

----
// a bit pseudo codey down here
fn process_ai_stream(ai_stream, tx) {
    tokio::spawn(async move {
        // process the ai stream on a separate task
        while let Some(stream_event) = ai_stream.next().await {
            tx.send(/* some data from the ai stream event */)
        }

        // out of the loop here, there must be no more data in the stream,
        // so let's save to the DB
    })
}

Alternatively, the other way I solved this before was to wrap the stream in a struct and then impl Stream for that struct. The Stream impl for your wrapper should be fairly simple, since you're just going to poll the underlying AI stream.

// inside your wrapper's poll_next
let ai_stream = &mut self.ai_stream;
match Pin::new(ai_stream).poll_next(cx) {
    Poll::Ready(None) => {
        // the stream is done here, so save to the DB
        Poll::Ready(None)
    }
    // forward items and Pending unchanged
    other => other,
}

I hope this helps.

