Sqlx transaction abstraction

Hi, y'all.

So I've recently been working with sqlx, and I'm trying to figure out how to abstract queries over connections and transactions. What I wanted was a service that manages transactions and a DAO that runs the SQL statements. I got it working, but I don't particularly like my abstractions.

What I have is a TransactionProvider (implemented for PgPool) that the service uses to execute commands:

use sqlx::{Acquire, PgPool, Postgres, Transaction};

pub struct SomeService<Provider> {
    transactionManager: Provider,
}

impl<Provider: TransactionProvider> SomeService<Provider> {
    pub async fn run_raw(&self) {
        run_query(self.transactionManager.connection()).await;
    }

    pub async fn run_in_transaction(&self) {
        let mut transaction = self.transactionManager.transaction().await;
        run_query(transaction.get()).await;
        run_query(transaction.get()).await;
        transaction.commit().await.unwrap();
    }
}

async fn run_query<'a, A>(a: A)
where
    A: Acquire<'a, Database = Postgres> + Send + Sync,
{
    let mut executor = a.acquire().await.unwrap();
    sqlx::query("select 'hello world'")
        .execute(&mut *executor)
        .await
        .unwrap();
}

This is a simplified version; in practice I'd have a DAO instead of free functions, but the idea is the same. I either use a plain connection through self.transactionManager.connection() or a transaction through self.transactionManager.transaction().

To make this work, I use TransactionProvider, a thin wrapper over PgPool's basic operations:

pub trait TransactionProvider {
    type BasicConnection<'a>: Acquire<'a, Database = Postgres> + Send + Sync
    where
        Self: 'a;
    type Transaction: GenericTransaction + Send + Sync;

    fn connection<'a>(&'a self) -> Self::BasicConnection<'a>;

    async fn transaction<'a>(&'a self) -> Self::Transaction;
}

impl TransactionProvider for PgPool {
    type BasicConnection<'a> = &'a PgPool;
    type Transaction = Transaction<'static, Postgres>;

    fn connection<'a>(&'a self) -> Self::BasicConnection<'a> {
        self
    }

    async fn transaction<'a>(&'a self) -> Self::Transaction {
        self.begin().await.unwrap()
    }
}

pub trait GenericTransaction {
    type Connection<'a>: Acquire<'a, Database = Postgres> + Send + Sync
    where
        Self: 'a;

    fn get<'a>(&'a mut self) -> Self::Connection<'a>;

    async fn commit(self) -> Result<(), sqlx::Error>;

    async fn rollback(self) -> Result<(), sqlx::Error>;
}

impl GenericTransaction for Transaction<'static, Postgres> {
    type Connection<'a>
        = &'a mut Self
    where
        Self: 'a;

    fn get<'a>(&'a mut self) -> Self::Connection<'a> {
        self
    }

    async fn commit(self) -> Result<(), sqlx::Error> {
        self.commit().await
    }

    async fn rollback(self) -> Result<(), sqlx::Error> {
        self.rollback().await
    }
}

The problem is that it's a lot of boilerplate, and the developer experience (DX) isn't great.

Has anyone found a better way? Or a better separation of logic that works with sqlx?

What I've done, and what I've been pretty happy with in practice, is to define my data access layer as a set of structs that wrap a transaction, and to make demarcating that transaction the caller's responsibility. For ergonomics, I've also set up extension traits on Transaction<…>, so that I can write code that looks like this:

let mut tx = self.db.begin().await?;
let created = tx.sequence().next(created_at).await?;
let channel = tx
    .channels()
    .create(name, &created)
    .await?;
tx.commit().await?;

The corresponding implementation of tx.channels().create(…) looks like:

pub trait Provider {
    fn channels(&mut self) -> Channels<'_>;
}

impl Provider for Transaction<'_, Sqlite> {
    fn channels(&mut self) -> Channels<'_> {
        Channels(self)
    }
}

pub struct Channels<'t>(&'t mut SqliteConnection);

impl Channels<'_> {
    pub async fn create(&mut self, name: &Name, created: &Instant) -> Result<History, sqlx::Error> {
        let id = Id::generate();
        let name = name.clone();
        let display_name = name.display();
        let canonical_name = name.canonical();
        let created = *created;

        sqlx::query!(
            r#"-- query omitted for brevity"#,
            id,
            created.at,
            created.sequence,
            created.sequence,
        )
        .execute(&mut *self.0)
        .await?;

        sqlx::query!(
            r#"-- query omitted for brevity"#,
            id,
            display_name,
            canonical_name,
        )
        .execute(&mut *self.0)
        .await?;

        let channel = todo!();
        Ok(channel)
    }
}
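To see the extension-trait pattern in isolation, here is a dependency-free sketch. FakeConn and FakeTx are hypothetical stand-ins for SqliteConnection and sqlx's Transaction (which derefs to its connection); "executing" a statement just records it. The point is that `channels()` hands out a short-lived repository struct that borrows the connection mutably.

```rust
use std::ops::{Deref, DerefMut};

/// Hypothetical stand-in for a database connection: it only records SQL.
#[derive(Default)]
pub struct FakeConn {
    pub log: Vec<String>,
}

impl FakeConn {
    fn execute(&mut self, sql: &str) {
        self.log.push(sql.to_string());
    }
}

/// Hypothetical stand-in for sqlx's Transaction, which derefs to its connection.
#[derive(Default)]
pub struct FakeTx {
    conn: FakeConn,
}

impl Deref for FakeTx {
    type Target = FakeConn;
    fn deref(&self) -> &FakeConn {
        &self.conn
    }
}

impl DerefMut for FakeTx {
    fn deref_mut(&mut self) -> &mut FakeConn {
        &mut self.conn
    }
}

/// Extension trait: adds `.channels()` to the transaction type.
pub trait Provider {
    fn channels(&mut self) -> Channels<'_>;
}

impl Provider for FakeTx {
    fn channels(&mut self) -> Channels<'_> {
        // `&mut FakeTx` deref-coerces to `&mut FakeConn` here.
        Channels(self)
    }
}

/// Repository that borrows one connection for as long as it is used.
pub struct Channels<'t>(&'t mut FakeConn);

impl Channels<'_> {
    pub fn create(&mut self, name: &str) -> String {
        // String-built SQL is for illustration only; real code binds parameters.
        self.0
            .execute(&format!("insert into channel (name) values ('{name}')"));
        name.to_string()
    }
}
```

Usage mirrors the post: `let created = tx.channels().create("general");`. Because each `Channels` borrow ends at the end of the statement, several repositories can be used one after another on the same transaction.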

sqlx's own Transaction type gives you fairly good control over the transaction lifecycle, and I haven't found any pressing need in my own projects to wrap it in a type of my own. You might, but I'd encourage you to identify the actual need first and design the wrapper around it, rather than designing GenericTransaction up front when all it does is forward calls to Transaction.
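Part of that lifecycle control is that sqlx's Transaction rolls back automatically if it is dropped without commit or rollback being called, so an early return or `?` never leaves a transaction half-applied. A minimal std-only sketch of that guard behavior follows; TxGuard and Log are hypothetical names, and the real sqlx type issues the rollback against the database rather than writing to a log.

```rust
use std::cell::RefCell;
use std::rc::Rc;

/// Hypothetical shared log standing in for what the database would see.
pub type Log = Rc<RefCell<Vec<&'static str>>>;

/// Guard mimicking sqlx's documented behavior: if neither `commit` nor
/// `rollback` is called, the transaction is rolled back on drop.
pub struct TxGuard {
    log: Log,
    finished: bool,
}

impl TxGuard {
    pub fn begin(log: Log) -> Self {
        log.borrow_mut().push("BEGIN");
        TxGuard { log, finished: false }
    }

    /// Consumes the guard, so the transaction cannot be used after commit.
    pub fn commit(mut self) {
        self.log.borrow_mut().push("COMMIT");
        self.finished = true; // Drop below will now do nothing.
    }

    pub fn rollback(mut self) {
        self.log.borrow_mut().push("ROLLBACK");
        self.finished = true;
    }
}

impl Drop for TxGuard {
    fn drop(&mut self) {
        if !self.finished {
            // Dropped while still open (e.g. an early `?` return): roll back.
            self.log.borrow_mut().push("ROLLBACK");
        }
    }
}
```

Taking `self` by value in `commit` and `rollback` is what lets the type system rule out using a finished transaction, which is the same shape sqlx's Transaction uses.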

Interesting. The reason I wanted an abstraction is to separate the logic from the database connection so that I can write unit tests.
I also considered passing in a DaoBuilder that would consume a connection, but decided it would be too hard to do generically.


That's entirely valid. It's a goal I've consciously given up on in this project: as you can see, it uses SQLite rather than Postgres, so testing against an in-memory database is a cheap, low-risk alternative in a way it wouldn't be when the database sits on the far side of an inter-process communication channel.

So, I've been refining the approach. @derspiny's use of the underlying connection type (SqliteConnection in that example) gave me the idea that transactions and plain connections can both be unified through that type, PgConnection in my case. So I updated the code to make it nicer to work with.

use sqlx::{PgConnection, PgPool, Postgres, Transaction};
use std::future::Future;

pub trait TransactionProvider: Send + Sync + 'static {
    type Connection;
    type Transaction: GenericTransaction<Connection = Self::Connection> + Send + Sync + 'static;

    fn connection(
        &self,
    ) -> impl Future<Output = sqlx::Result<impl AsMut<Self::Connection> + Send + Sync + 'static>> + Send;

    fn transaction(&self) -> impl Future<Output = sqlx::Result<Self::Transaction>> + Send;
}

impl TransactionProvider for PgPool {
    type Connection = PgConnection;
    type Transaction = Transaction<'static, Postgres>;

    async fn connection(
        &self,
    ) -> sqlx::Result<impl AsMut<Self::Connection> + Send + Sync + 'static> {
        Ok(self.acquire().await?)
    }

    async fn transaction(&self) -> sqlx::Result<Self::Transaction> {
        self.begin().await
    }
}

pub trait GenericTransaction: AsMut<Self::Connection> {
    type Connection;

    fn commit(self) -> impl std::future::Future<Output = Result<(), sqlx::Error>> + Send;

    fn rollback(self) -> impl std::future::Future<Output = Result<(), sqlx::Error>> + Send;
}

impl GenericTransaction for Transaction<'static, Postgres> {
    type Connection = PgConnection;

    async fn commit(self) -> Result<(), sqlx::Error> {
        self.commit().await
    }

    async fn rollback(self) -> Result<(), sqlx::Error> {
        self.rollback().await
    }
}
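The key idea is that a pooled connection and a transaction both expose the same `&mut Connection` via `AsMut`, so DAO code only ever sees one type; the refined traits rely on sqlx's PoolConnection and Transaction both implementing `AsMut<PgConnection>`. Here is a std-only sketch of that unification, where Conn, Pooled and Tx are hypothetical stand-ins for PgConnection, PoolConnection and Transaction.

```rust
/// Hypothetical connection type standing in for PgConnection.
#[derive(Default)]
pub struct Conn {
    pub executed: Vec<String>,
}

/// Stand-in for a pooled connection (like sqlx's PoolConnection).
#[derive(Default)]
pub struct Pooled(Conn);

/// Stand-in for a transaction wrapping a connection (like sqlx's Transaction).
#[derive(Default)]
pub struct Tx(Conn);

impl AsMut<Conn> for Pooled {
    fn as_mut(&mut self) -> &mut Conn {
        &mut self.0
    }
}

impl AsMut<Conn> for Tx {
    fn as_mut(&mut self) -> &mut Conn {
        &mut self.0
    }
}

/// DAO-style function: it only knows about `&mut Conn`, not where it came from.
pub fn run_query(conn: &mut Conn, sql: &str) {
    conn.executed.push(sql.to_string());
}
```

The same `run_query` accepts `pooled.as_mut()` and `tx.as_mut()`, which is why no lifetime-generic `Acquire<'a>` bound is needed on the DAO side.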

This refined abstraction also gets rid of the "cannot be sent between threads safely" errors I was running into.
Relying on a single Connection type makes defining DAOs and services significantly easier. In DAOs I simply specify Connection = PgConnection and take &mut Self::Connection as a method parameter. In services I only need generic bounds for the TransactionProvider and the DAO.

use async_trait::async_trait;

// Define the service and DAO traits without any generics other than Connection.
#[async_trait]
trait FooService {
    async fn run_raw(&self);

    async fn run_in_transaction(&self);
}

#[async_trait]
trait FooDao: Send + Sync + 'static {
    type Connection: Send + Sync + 'static;
    async fn run_query(&self, connection: &mut Self::Connection);
}

pub struct FooServiceImpl<Provider, Dao> {
    transaction_manager: Provider,
    test_dao: Dao,
}

#[async_trait]
impl<Provider, Dao> FooService for FooServiceImpl<Provider, Dao>
where
    // We need to ensure that Provider::Connection is the same type that
    // Dao::Connection accepts.
    Provider: TransactionProvider,
    Dao: FooDao<Connection = Provider::Connection>,
    Provider::Connection: Send + Sync + 'static,
{
    async fn run_raw(&self) {
        let mut connection = self.transaction_manager.connection().await.unwrap();
        self.test_dao.run_query(connection.as_mut()).await;
        self.test_dao.run_query(connection.as_mut()).await;
    }

    async fn run_in_transaction(&self) {
        let mut transaction = self.transaction_manager.transaction().await.unwrap();
        self.test_dao.run_query(transaction.as_mut()).await;
        self.test_dao.run_query(transaction.as_mut()).await;
        transaction.commit().await.unwrap();
    }
}

struct FooDaoImpl;

#[async_trait]
impl FooDao for FooDaoImpl {
    type Connection = PgConnection;

    async fn run_query(&self, connection: &mut Self::Connection) {
        sqlx::query("select 'hello world'")
            .execute(connection)
            .await
            .unwrap();
    }
}
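Because the service only requires `Dao: FooDao<Connection = Provider::Connection>`, a test can plug in a fake provider and a fake DAO that agree on some cheap connection type, which addresses the unit-testing goal from earlier in the thread. The sketch below is synchronous and std-only; Provider, FakeProvider and FakeDao are simplified hypothetical traits and types, not the exact ones above.

```rust
/// Simplified, synchronous stand-in for TransactionProvider.
pub trait Provider {
    type Connection;
    fn connection(&self) -> Self::Connection;
}

/// Simplified, synchronous stand-in for FooDao.
pub trait FooDao {
    type Connection;
    fn run_query(&self, conn: &mut Self::Connection);
}

pub struct FooServiceImpl<P, D> {
    provider: P,
    dao: D,
}

impl<P, D> FooServiceImpl<P, D>
where
    P: Provider,
    // Same equality constraint as above: the DAO accepts the provider's connection type.
    D: FooDao<Connection = P::Connection>,
{
    pub fn run_raw(&self) -> P::Connection {
        let mut conn = self.provider.connection();
        self.dao.run_query(&mut conn);
        self.dao.run_query(&mut conn);
        conn // returned only so a test can inspect it
    }
}

/// Test doubles: the "connection" is just a list of executed statements.
pub struct FakeProvider;

impl Provider for FakeProvider {
    type Connection = Vec<&'static str>;
    fn connection(&self) -> Vec<&'static str> {
        Vec::new()
    }
}

pub struct FakeDao;

impl FooDao for FakeDao {
    type Connection = Vec<&'static str>;
    fn run_query(&self, conn: &mut Vec<&'static str>) {
        conn.push("select 'hello world'");
    }
}
```

In the real async version the fake provider would also need a fake transaction type implementing GenericTransaction, but the wiring is the same.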

The advantages of this approach are:

  1. Less boilerplate: no <'a, A> ... where A: Acquire....
  2. Ability to make dyn-compatible DAOs.
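The dyn-compatibility comes from `#[async_trait]` rewriting each `async fn` into a method that returns a boxed future. The sketch below writes that shape by hand so it runs without any crates; BoxFuture, Conn, HelloDao and block_on are hypothetical names, and async_trait's real expansion carries more lifetime parameters than shown here.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

/// Roughly the return type `#[async_trait]` gives an async method.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;

/// Hypothetical connection stand-in.
#[derive(Default)]
pub struct Conn {
    pub executed: Vec<&'static str>,
}

/// No `async fn`, no generic type parameters: usable as `dyn FooDao<Connection = Conn>`.
pub trait FooDao: Send + Sync {
    type Connection: Send;
    fn run_query<'a>(&'a self, conn: &'a mut Self::Connection) -> BoxFuture<'a, ()>;
}

pub struct HelloDao;

impl FooDao for HelloDao {
    type Connection = Conn;
    fn run_query<'a>(&'a self, conn: &'a mut Conn) -> BoxFuture<'a, ()> {
        Box::pin(async move { conn.executed.push("select 'hello world'") })
    }
}

/// Waker that does nothing; enough for futures that never wait on I/O.
struct NoopWaker;

impl Wake for NoopWaker {
    fn wake(self: Arc<Self>) {}
}

/// Minimal executor that busy-polls; only suitable for futures that never actually wait.
pub fn block_on<F: Future>(fut: F) -> F::Output {
    let waker = Waker::from(Arc::new(NoopWaker));
    let mut cx = Context::from_waker(&waker);
    let mut fut = Box::pin(fut);
    loop {
        if let Poll::Ready(out) = fut.as_mut().poll(&mut cx) {
            return out;
        }
    }
}
```

With this shape a service can hold `Box<dyn FooDao<Connection = Conn>>` instead of a generic parameter, at the cost of one heap allocation per call.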