Logging Axum's Response and Request info?

I'm starting to grok Axum's custom extractors, but I haven't found any examples that produce the following bits of information, which I'd like to log to our observability platform:

  1. The latency
  2. The fully qualified URL
  3. The status code

It seems that the TraceLayer middleware only gives its on_response callback the Response object, which does not carry the original URL:

TraceLayer::new_for_http()
    .on_response(|_response: &Response, _latency: Duration, _span: &Span| {
        // ...
    })
    .on_failure(
        |_error: ServerErrorsFailureClass, _latency: Duration, _span: &Span| {
            // ...
        },
    )

I'm sure I'm missing something simple here, but any help is appreciated!

You could get inspiration from other Tower services, such as tower-http's trace module.

Two options off the top of my head:

  1. Add your own middleware that stashes the request URI in the response extensions (you'll have to be careful with the middleware ordering, though)
  2. Just use TraceLayer's on_request handler to stick the URI in an Arc<Mutex<Option<...>>> shared between it and the on_response handler

That sounds like what I need, but do you have an example of that anywhere?

Option 1

use std::time::Duration;

use axum::{
    http::{Request, Uri},
    middleware::Next,
    response::Response,
    routing::get,
    Router,
};
use tower_http::{classify::ServerErrorsFailureClass, trace::TraceLayer};
use tracing::Span;

#[tokio::main]
async fn main() {
    // build our application with a single route
    let app = Router::new()
        .route("/", get(|| async { "Hello, World!" }))
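        // Layers added later wrap the ones added earlier, so this middleware runs
        // *inside* the TraceLayer below and inserts `RequestUri` into the
        // response before `on_response` sees it.
        .layer(axum::middleware::from_fn(uri_middleware))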
        .layer(
            TraceLayer::new_for_http()
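                .on_response(|response: &Response, latency: Duration, _span: &Span| {
                    // Read back the URI that `uri_middleware` stashed in the extensions.
                    let uri = response.extensions().get::<RequestUri>().map(|r| &r.0);
                    println!("{:?} {} {:?}", uri, response.status(), latency);
                })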
                .on_failure(
                    |_error: ServerErrorsFailureClass, _latency: Duration, _span: &Span| {
                        // ...
                    },
                ),
        );

    // run it with hyper on localhost:3000
    axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
        .serve(app.into_make_service())
        .await
        .unwrap();
}

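/// Newtype wrapper so the URI can be stored in, and looked up from, the
/// response extensions by its type.
#[derive(Clone)]
struct RequestUri(Uri);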

async fn uri_middleware<B>(request: Request<B>, next: Next<B>) -> Response {
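    // Clone the URI now, since `next.run` takes ownership of the request.
    let uri = request.uri().clone();

    let mut response = next.run(request).await;

    // Attach it to the response so outer layers (here, TraceLayer) can read it.
    response.extensions_mut().insert(RequestUri(uri));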

    response
}
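If it helps to see why the ordering matters, here is a std-only model of Option 1. There is no axum or tower-http in it: `Request`, `Response`, `Extensions`, `uri_middleware` and `trace_layer` are simplified, hypothetical stand-ins, not the real types. The inner layer copies the URI before handing the request on and stashes it in a type-keyed map on the response (roughly how `http::Extensions` behaves); the outer layer times the call and reads the URI, status and latency back out.

```rust
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::time::{Duration, Instant};

/// Simplified stand-in for `http::Extensions`: at most one value per type.
#[derive(Default)]
struct Extensions(HashMap<TypeId, Box<dyn Any>>);

impl Extensions {
    fn insert<T: 'static>(&mut self, value: T) {
        self.0.insert(TypeId::of::<T>(), Box::new(value));
    }

    fn get<T: 'static>(&self) -> Option<&T> {
        self.0.get(&TypeId::of::<T>())?.downcast_ref::<T>()
    }
}

struct Request {
    uri: String,
}

struct Response {
    status: u16,
    extensions: Extensions,
}

/// Newtype key, like `RequestUri` in the axum example.
struct RequestUri(String);

/// Inner layer: remembers the URI, runs the handler, attaches the URI.
fn uri_middleware(request: Request, next: impl FnOnce(Request) -> Response) -> Response {
    let uri = request.uri.clone(); // `next` takes ownership of the request
    let mut response = next(request);
    response.extensions.insert(RequestUri(uri));
    response
}

/// Outer layer: measures latency, then builds a log line from the response.
fn trace_layer(request: Request, inner: impl FnOnce(Request) -> Response) -> (Response, String) {
    let start = Instant::now();
    let response = inner(request);
    let latency: Duration = start.elapsed();
    let uri = response
        .extensions
        .get::<RequestUri>()
        .map_or("<unknown>", |r| r.0.as_str());
    let line = format!("uri={uri} status={} latency={latency:?}", response.status);
    (response, line)
}

/// Stand-in for the route handler.
fn handler(_request: Request) -> Response {
    Response { status: 200, extensions: Extensions::default() }
}
```

Calling `trace_layer(request, |req| uri_middleware(req, handler))` yields a line starting with `uri=/hello?name=world status=200`, while calling `trace_layer(request, handler)` without the inner middleware yields `uri=<unknown>`: the outer layer can only see what an inner layer put on the response before it returned.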

I realized my second option doesn't really make sense with the way TraceLayer is set up: there's no easy way to create a mutex per request there.
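To illustrate the problem: whatever the callback closures capture is created once, when the layer is built, so every request ends up sharing the same slot. Below is a std-only simulation of two overlapping requests (`overlapping_requests` and the closures are hypothetical; no tower-http involved). Cloning a closure clones the `Arc` handle, not the data behind it.

```rust
use std::sync::{Arc, Mutex};

/// Shared slot that an `on_request`-style callback would write the URI into.
type UriSlot = Arc<Mutex<Option<String>>>;

/// Simulates requests A and B overlapping, and returns the URI that
/// request A's response callback reads back.
fn overlapping_requests() -> Option<String> {
    // Created once, the way a closure handed to a layer builder captures it.
    let slot: UriSlot = Arc::new(Mutex::new(None));

    let on_request = {
        let slot = Arc::clone(&slot);
        move |uri: &str| *slot.lock().unwrap() = Some(uri.to_owned())
    };
    let on_response = {
        let slot = Arc::clone(&slot);
        move || slot.lock().unwrap().clone()
    };

    // Each "request" gets its own copy of the callbacks, but all copies
    // point at the same Mutex.
    let (on_request_a, on_response_a) = (on_request.clone(), on_response.clone());
    let on_request_b = on_request.clone();

    on_request_a("/a"); // request A arrives
    on_request_b("/b"); // request B arrives before A has finished
    on_response_a() // A's response callback now sees B's URI
}
```

Here `overlapping_requests()` returns `Some("/b")` even though it is A's response being logged. A correct version needs state that is created per request, such as the request's span or its extensions, which is what the other approaches in this thread rely on.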

I think you might be able to just record the request URI on the Span in the appropriate callback (for example, where the span is created in make_span_with), and then it would apply to all of the messages logged in the layer.
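For reference, in tower-http the per-request span comes from the closure passed to `TraceLayer::make_span_with`, and the default `DefaultMakeSpan` already records `method`, `uri` and `version` on it, while `DefaultOnResponse` logs `latency` and `status` (at DEBUG level by default, as far as I know). So with a tracing subscriber enabled at that level, the defaults may already cover all three items. The std-only model below only shows the shape of the idea; `RequestSpan`, `for_request`, `event` and `on_response` are hypothetical stand-ins, not tracing's API.

```rust
use std::time::{Duration, Instant};

/// Hypothetical stand-in for a `tracing::Span`: its fields are recorded once,
/// when the request arrives, and every event rendered through it carries them.
struct RequestSpan {
    fields: Vec<(&'static str, String)>,
    started: Instant,
}

impl RequestSpan {
    /// Plays the role of a `make_span_with` closure, which still sees the request.
    fn for_request(method: &str, uri: &str) -> Self {
        RequestSpan {
            fields: vec![("method", method.to_owned()), ("uri", uri.to_owned())],
            started: Instant::now(),
        }
    }

    /// Renders an event message prefixed with the span's fields.
    fn event(&self, message: &str) -> String {
        let fields: Vec<String> = self.fields.iter().map(|(k, v)| format!("{k}={v}")).collect();
        format!("request{{{}}}: {message}", fields.join(" "))
    }
}

/// Plays the role of an `on_response` callback: it never sees the request,
/// yet its log line still includes the URI, via the span.
fn on_response(status: u16, latency: Duration, span: &RequestSpan) -> String {
    span.event(&format!("finished processing request status={status} latency={latency:?}"))
}
```

With `RequestSpan::for_request("GET", "/hello?name=world")`, the line from `on_response(200, ...)` starts with `request{method=GET uri=/hello?name=world}: finished processing request status=200`. In real code, the equivalent would be building the span from the request in `make_span_with` (for example with `tracing::info_span!`) and logging from `on_response`, which receives that span as its third argument.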


I got a little sidetracked implementing another way to get the URI from all of TraceLayer's callbacks, via a hand-written future and a thread-local. It works, but it may have bugs; either way, it's probably overkill.

use std::{cell::RefCell, marker::PhantomData, time::Duration};

use axum::{
    http::{Request, Uri},
    middleware::Next,
    response::Response,
    routing::get,
    Router,
};
use futures::Future;
use tower_http::{classify::ServerErrorsFailureClass, trace::TraceLayer};
use tracing::Span;

#[tokio::main]
async fn main() {
    // build our application with a single route
    let app = Router::new()
        .route("/", get(|| async { "Hello, World!" }))
        .layer(
            TraceLayer::new_for_http()
                .on_response(|_response: &Response, _latency: Duration, _span: &Span| {
                    println!("{:?}", get_request_uri())
                })
                .on_failure(
                    |_error: ServerErrorsFailureClass, _latency: Duration, _span: &Span| {
                        println!("{:?}", get_request_uri())
                    },
                ),
        )
        .layer(axum::middleware::from_fn(uri_middleware));

    // run it with hyper on localhost:3000
    axum::Server::bind(&"0.0.0.0:3000".parse().unwrap())
        .serve(app.into_make_service())
        .await
        .unwrap();
}

fn uri_middleware<B>(request: Request<B>, next: Next<B>) -> impl Future<Output = Response> {
    let uri = request.uri().clone();

    UriFuture {
        inner: next.run(request),
        uri: Some(uri),
    }
}

pin_project_lite::pin_project! {
    /// A future that sets the `REQUEST_URI` thread local when polled, before calling the inner future's poll
    struct UriFuture<F> {
        #[pin]
        inner: F,
        uri: Option<Uri>,
    }
}

impl<F> Future for UriFuture<F>
where
    F: Future,
{
    type Output = F::Output;

    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        let this = self.project();
        // The guard will reset the thread local to its previous state when it gets dropped.
        let _guard = UriGuard::enter(this.uri);
        this.inner.poll(cx)
    }
}

thread_local! {
    static REQUEST_URI: RefCell<Option<Uri>> = RefCell::new(None);
}

fn get_request_uri() -> Option<Uri> {
    REQUEST_URI.with(|uri| uri.borrow().clone())
}

struct UriGuard<'a> {
    src: &'a mut Option<Uri>,
    old: Option<Uri>,
    /// Don't want the type to be `Send` or `Sync` since it's manipulating a thread local.
    _phantom: PhantomData<*const ()>,
}

impl<'a> UriGuard<'a> {
    fn enter(src: &'a mut Option<Uri>) -> Self {
        REQUEST_URI.with(|uri| {
            let mut uri = uri.borrow_mut();

            let old = uri.take();

            *uri = src.take();

            Self {
                src,
                old,
                _phantom: PhantomData,
            }
        })
    }
}

impl Drop for UriGuard<'_> {
    fn drop(&mut self) {
        REQUEST_URI.with(|uri| {
            let mut uri = uri.borrow_mut();

            *self.src = uri.take();

            *uri = self.old.take();
        })
    }
}
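One caveat that applies to all of the approaches above, since they all capture request.uri(): for an HTTP/1.1 request the request target is usually in origin-form (just path and query), so on the server request.uri() typically won't contain the scheme or host (HTTP/2 requests can differ). If you need the fully qualified URL in your logs, you'd have to rebuild it, e.g. from the Host header. A std-only sketch follows; `absolute_url` is a hypothetical helper, and where the scheme comes from (configuration, or a trusted proxy header such as X-Forwarded-Proto) is an assumption you'd fill in yourself.

```rust
/// Hypothetical helper: rebuilds an absolute URL for logging.
///
/// * `uri`: the request target as the server received it.
/// * `host`: the value of the `Host` header, if there was one.
/// * `scheme`: supplied by the caller (assumed to come from config or a proxy).
fn absolute_url(uri: &str, host: Option<&str>, scheme: &str) -> String {
    match host {
        // Origin-form target ("/path?query"): prefix scheme and host.
        Some(host) if uri.starts_with('/') => format!("{scheme}://{host}{uri}"),
        // Already absolute, asterisk-form ("*"), or no Host header: keep as-is.
        _ => uri.to_owned(),
    }
}
```

Inside uri_middleware, the inputs could come from `request.uri().to_string()` and `request.headers().get(http::header::HOST).and_then(|h| h.to_str().ok())`. Checking for a leading `/` rather than searching for `://` avoids misclassifying paths whose query string happens to contain a URL.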

Option 1 works perfectly for my needs. Thanks!

And sorry for being late to the party!
