Connection pool for an HTTP reverse proxy server

An HTTP/1.1 reverse proxy based on hyper v1.0-rc1 and this example.

I want to reuse TCP connections where possible. For this purpose, a `TcpStreamPool` instance is created for each backend. When a connection is no longer in use, its inner `TcpStream` is returned to the pool so that it can be reused for future requests.

The code works in the normal case. However, if the previous HTTP request/response does not complete on the wire (e.g. a partial request or response), subsequent requests that reuse this connection fail.

Question: how can I determine whether hyper has completed the request/response entirely before recycling the connection back into the pool?


lazy_static! {
  /// Process-wide registry of connection pools, one per backend socket address.
  static ref POOL_MAP: RwLock<HashMap<SocketAddr, Arc<TcpStreamPool>>> = RwLock::new(HashMap::new());
}

/// A bounded queue of idle TCP streams to a single backend.
struct TcpStreamPool {
  /// Idle streams waiting to be handed out again.
  queue : ArrayQueue<TcpStream>,
  /// Backend address that fresh streams are opened to when the queue is empty.
  address : SocketAddr,
}

/// A `TcpStream` checked out of a `TcpStreamPool`, which `recycle` can hand back.
struct ReusableTcpStream {
  /// The live stream; becomes `None` once it has been taken by `recycle`.
  inner : Option<TcpStream>,
  /// The pool this stream returns to. Held weakly so an outstanding stream
  /// does not keep its pool alive.
  pool : Weak<TcpStreamPool>,
}



use crate::cluster::BackendRoutine;

/// Accepts HTTP/1.1 connections on `0.0.0.0:port` and serves each one on its
/// own task with a `Svc` that routes requests through `cb`.
///
/// Accept errors and per-connection errors are logged, and the loop keeps going.
///
/// # Errors
/// Returns the error from binding the listener; after a successful bind this
/// function does not return.
pub async fn run<F>(port : u16, cb : F) -> Result<(), Box<dyn std::error::Error + Send + Sync>>
  where F : Fn(&Request<Incoming>) -> Option<BackendRoutine> + Send + Clone + 'static
{
    let addr: SocketAddr = ([0, 0, 0, 0], port).into();

    let listener = TcpListener::bind(addr).await?;
    info!("Listening on http://{}", addr);

    loop {
      match listener.accept().await {
        Ok((stream, peer)) => {
          let cloned_callback = cb.clone();

          tokio::task::spawn(async move {
              if let Err(err) = http1::Builder::new()
                  .serve_connection(stream, Svc::new(cloned_callback))
                  .await
              {
                  warn!("Failed to serve connection from {}: {:?}", peer, err);
              }
          });
        },
        Err(e) => {
          error!("Failed to accept new connection. {}", e);
          // Errors such as running out of file descriptors tend to repeat
          // immediately; pause briefly instead of spinning and flooding the log.
          tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        }
      }
    }
}




struct Svc<F> 
  where F : Fn(&Request<Incoming>) -> Option<BackendRoutine> + Send + Clone + 'static 
{
  response_header : HeaderValue,
  callback : F,
}

impl<F> Svc<F>
  where F : Fn(&Request<Incoming>) -> Option<BackendRoutine> + Send + Clone + 'static {
    /// Creates a service that routes with `cb` and tags responses with this
    /// machine's host name in `X-Powered-By`.
    ///
    /// A host name that is not valid UTF-8 is converted lossily. One that still
    /// isn't a legal header value falls back to `"unknown"` rather than
    /// panicking inside the per-connection task.
    fn new(cb : F) -> Self {
      let hostname = gethostname::gethostname();
      let response_header = HeaderValue::from_str(&hostname.to_string_lossy())
        .unwrap_or_else(|_| HeaderValue::from_static("unknown"));
      Self {
        response_header,
        callback : cb,
      }
    }
}

impl<F> Service<Request<Incoming>> for Svc<F> 
  where F : Fn(&Request<Incoming>) -> Option<BackendRoutine> + Send + Clone + 'static
{
  type Response = Response<GatewayBody>;
  type Error = hyper::Error;
  type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

  fn call(&mut self, req: Request<Incoming>) -> Self::Future {


      fn mk_response(status_code : http::StatusCode) -> Result<Response<GatewayBody>, hyper::Error> {
        let res = Response::builder()
          .status(status_code)
          .header("X-Powered-By", HeaderValue::from_str(gethostname::gethostname().to_str().unwrap()).unwrap())
          .body(GatewayBody::Empty)
          .unwrap();
        Ok(res)
      }

      async fn connect_backend(addr : SocketAddr) -> Option<SendRequest<Incoming>> {
        match TcpStreamPool::get(addr).await.connect().await {
          Ok(client_stream) => {
            match hyper::client::conn::http1::handshake(client_stream).await {
              Ok((sender, mut conn)) => {
                tokio::task::spawn(async move {
                  if let Err(err) = poll_fn(|cx| conn.poll_without_shutdown(cx)).await {
                      warn!("Connection failed: {:?}", err);
                  } else {
                    drop(conn.into_parts().io.recycle());
                  }
                });
                Some(sender)
              }
              Err(e) => {
                warn!("Handshake with {} failed. {}", addr, e);
                None
              }
            }
          }
          Err(e) => {
            warn!("Failed to establish connection to {}. {}", addr, e);
            None
          }
        }
      }

      let cloned_header = self.response_header.clone();

      let backend_routine_option = (self.callback)(&req);
      
      Box::pin(async move {
          
          if let Some(backend_routine) = backend_routine_option {

            // first try primary backend
            let primary_backend = backend_routine.primary();

            let option = match connect_backend(primary_backend.address()).await {
              Some(s) => Some(s),
              None => {
                warn!("Unable to connect to primary backend {}", primary_backend.node().hostname());
                if let Some(secondary_backend) = backend_routine.secondary() {
                  match connect_backend(secondary_backend.address()).await {
                    Some(s) => Some(s),
                    None => {
                      warn!("Unable to connect to secondary backend {}", secondary_backend.node().hostname());
                      None
                    }
                  }
                } else {
                  None
                }
              }
            };
 
            if let Some(mut sender) = option {
              sender.send_request(req).await.and_then(|upstream_resp| {

                let mut res = Response::builder().status(upstream_resp.status());
                {
                  let headers = res.headers_mut().unwrap();
                  headers.clone_from(upstream_resp.headers());
                  headers.append("X-Powered-By", cloned_header);
                  headers.append("Server", HeaderValue::from_str(primary_backend.node().hostname()).unwrap());
                }
                Ok ( res.body(GatewayBody::Incoming(upstream_resp.into_body())).unwrap() )
              })
            } else {
              mk_response(StatusCode::BAD_GATEWAY)
            }

            
          } else {

            mk_response(StatusCode::BAD_GATEWAY)
          }
      }) 
  }
}


/// Response body returned by the gateway: either the streamed body of the
/// upstream response, or nothing (used for locally generated responses).
enum GatewayBody {
  /// Body streamed straight through from the backend.
  Incoming(hyper::body::Incoming),
  /// No body at all.
  Empty,
}


impl hyper::body::Body for GatewayBody
{
    type Data = hyper::body::Bytes;
    type Error = hyper::Error;

    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
      match &mut *self.get_mut(){
        Self::Incoming(incoming) => {
          Pin::new(incoming).poll_frame(cx)
        },
        Self::Empty => {
          Poll::Ready(None)
        }
      }
      
    }
}


impl TcpStreamPool {
    /// Maximum number of idle streams kept per backend.
    const CAPACITY: usize = 100;

    /// Creates an empty pool for the backend listening at `addr`.
    fn new(addr : SocketAddr) -> Self {
        let queue = ArrayQueue::new(Self::CAPACITY);
        Self { address: addr, queue }
    }
}

impl TcpStreamPool {

  async fn get(addr : SocketAddr) -> Arc<TcpStreamPool> {
    
    loop {
      if let Some(queue) = POOL_MAP.read().await.get(&addr) {
        return queue.clone();
      }
      _ = POOL_MAP.write().await.insert(addr, Arc::new(TcpStreamPool::new(addr)));
    }
  }

    async fn connect(self : &Arc<Self>) -> Result<ReusableTcpStream, std::io::Error> {
      println!("{}", self.queue.len());
      loop {
        match self.queue.pop() {
          Some(inner) => {
            // determine if stream is not closed
            let mut empty_buf = [0u8;0];
            if inner.try_write(&mut empty_buf).is_ok() {
              return Ok(ReusableTcpStream {
                inner : Some(inner),
                pool : Arc::downgrade(self)
              });
            }
          },
          None => {
            let inner = TcpStream::connect(self.address).await?;
 
            return Ok(ReusableTcpStream {
              inner : Some(inner),
              pool : Arc::downgrade(self)
            });
          }
        }
      }
 
    }
}




impl ReusableTcpStream {
  /// Hands the wrapped stream back to the pool it came from.
  ///
  /// Returns `Err(stream)` when the pool has been dropped or its queue is full,
  /// giving the caller back ownership of the stream that could not be pooled.
  /// If the stream was already taken, this is a no-op returning `Ok(())`.
  fn recycle(&mut self) -> Result<(), TcpStream> {
      let stream = match self.inner.take() {
        Some(stream) => stream,
        None => return Ok(()),
      };
      match self.pool.upgrade() {
        Some(pool) => pool.queue.push(stream),
        None => Err(stream),
      }
  }
}

 
// ===== impl Read / Write =====

/// Reads are forwarded to the wrapped `TcpStream`.
///
/// Panics if the stream has already been taken out by `recycle`.
impl AsyncRead for ReusableTcpStream {
  fn poll_read(
      self: Pin<&mut Self>,
      cx: &mut Context<'_>,
      buf: &mut ReadBuf<'_>,
  ) -> Poll<io::Result<()>> {
      let stream = self.get_mut().inner.as_mut().unwrap();
      Pin::new(stream).poll_read(cx, buf)
  }
}

/// Writes are forwarded to the wrapped `TcpStream`; flush and shutdown complete
/// immediately without touching the socket.
///
/// The write methods panic if the stream has already been taken out by `recycle`.
impl AsyncWrite for ReusableTcpStream {
  fn poll_write(
      self: Pin<&mut Self>,
      cx: &mut Context<'_>,
      buf: &[u8],
  ) -> Poll<io::Result<usize>> {
    let stream = self.get_mut().inner.as_mut().unwrap();
    Pin::new(stream).poll_write(cx, buf)
  }

  fn poll_write_vectored(
      self: Pin<&mut Self>,
      cx: &mut Context<'_>,
      bufs: &[std::io::IoSlice<'_>],
  ) -> Poll<io::Result<usize>> {
    let stream = self.get_mut().inner.as_mut().unwrap();
    Pin::new(stream).poll_write_vectored(cx, bufs)
  }

  fn is_write_vectored(&self) -> bool {
    // The wrapped TcpStream always supports vectored writes.
    true
  }

  #[inline]
  fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
    // A TCP socket has no user-space buffer to flush.
    Poll::Ready(Ok(()))
  }

  #[inline]
  fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
    // Intentionally does not shut the socket down.
    // NOTE(review): presumably this keeps the stream usable for the pool —
    // confirm.
    Poll::Ready(Ok(()))
  }
}