diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 2815802..0debcb0 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -1205,7 +1205,7 @@ impl Conn { } /// Requires that `self.inner.tx_status != TxStatus::None` - async fn rollback_transaction(&mut self) -> Result<()> { + pub(crate) async fn rollback_transaction(&mut self) -> Result<()> { debug_assert_ne!(self.inner.tx_status, TxStatus::None); self.inner.tx_status = TxStatus::None; self.query_drop("ROLLBACK").await diff --git a/src/queryable/mod.rs b/src/queryable/mod.rs index ebef366..870abe4 100644 --- a/src/queryable/mod.rs +++ b/src/queryable/mod.rs @@ -96,8 +96,7 @@ impl Conn { pub(crate) async fn clean_dirty(&mut self) -> Result<()> { self.drop_result().await?; if self.get_tx_status() == TxStatus::RequiresRollback { - self.set_tx_status(TxStatus::None); - self.exec_drop("ROLLBACK", ()).await?; + self.rollback_transaction().await?; } Ok(()) } diff --git a/src/queryable/transaction.rs b/src/queryable/transaction.rs index f4956db..7346d39 100644 --- a/src/queryable/transaction.rs +++ b/src/queryable/transaction.rs @@ -143,6 +143,8 @@ impl<'a> Transaction<'a> { let mut conn = conn.into(); + conn.clean_dirty().await?; + if conn.get_tx_status() != TxStatus::None { return Err(DriverError::NestedTransaction.into()); } @@ -188,8 +190,7 @@ impl<'a> Transaction<'a> { match self.try_commit().await { Ok(..) => Ok(()), Err(e) => { - self.0.query_drop("ROLLBACK").await.unwrap_or(()); - self.0.set_tx_status(TxStatus::None); + self.0.rollback_transaction().await.unwrap_or(()); Err(e) } } @@ -197,10 +198,7 @@ impl<'a> Transaction<'a> { /// Performs `ROLLBACK` query. pub async fn rollback(mut self) -> Result<()> { - let result = self.0.query_iter("ROLLBACK").await?; - result.drop_result().await?; - self.0.set_tx_status(TxStatus::None); - Ok(()) + self.0.rollback_transaction().await } }