From 7a078cee210678a31fa8a390525b31db797010c1 Mon Sep 17 00:00:00 2001 From: Ryan Fowler Date: Tue, 26 May 2026 22:27:09 -0700 Subject: [PATCH] Rollback open transactions after caught panics --- src/client.rs | 48 ++++++++++++++++++++++++++++++------------------ tests/tests.rs | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 18 deletions(-) diff --git a/src/client.rs b/src/client.rs index fcab04c..db44771 100644 --- a/src/client.rs +++ b/src/client.rs @@ -109,28 +109,40 @@ enum Command { Shutdown(Box) + Send>), } -fn run_catching(func: F) -> Result +fn run_catching(conn: &mut Connection, func: F) -> Result where - F: FnOnce() -> Result, + F: FnOnce(&mut Connection) -> Result, { - match std::panic::catch_unwind(std::panic::AssertUnwindSafe(func)) { + match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| func(conn))) { Ok(res) => res.map_err(Error::from), - Err(p) => Err(Error::Panic { - message: panic_message(&*p), - }), + Err(p) => { + rollback_if_needed(conn); + Err(Error::Panic { + message: panic_message(&*p), + }) + } } } -fn run_catching_and_then(func: F) -> Result +fn run_catching_and_then(conn: &mut Connection, func: F) -> Result where - F: FnOnce() -> Result, + F: FnOnce(&mut Connection) -> Result, E: From, { - match std::panic::catch_unwind(std::panic::AssertUnwindSafe(func)) { + match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| func(conn))) { Ok(res) => res, - Err(p) => Err(E::from(Error::Panic { - message: panic_message(&*p), - })), + Err(p) => { + rollback_if_needed(conn); + Err(E::from(Error::Panic { + message: panic_message(&*p), + })) + } + } +} + +fn rollback_if_needed(conn: &mut Connection) { + if !conn.is_autocommit() { + let _ = conn.execute_batch("ROLLBACK"); } } @@ -236,7 +248,7 @@ impl Client { { let (tx, rx) = oneshot::channel(); self.conn_tx.send(Command::Func(Box::new(move |conn| { - _ = tx.send(run_catching(|| func(conn))); + _ = tx.send(run_catching(conn, |conn| func(conn))); })))?; rx.await? } @@ -249,7 +261,7 @@ impl Client { { let (tx, rx) = oneshot::channel(); self.conn_tx.send(Command::Func(Box::new(move |conn| { - _ = tx.send(run_catching(|| func(conn))); + _ = tx.send(run_catching(conn, func)); })))?; rx.await? } @@ -267,7 +279,7 @@ impl Client { let (tx, rx) = oneshot::channel(); self.conn_tx .send(Command::Func(Box::new(move |conn| { - _ = tx.send(run_catching_and_then(|| func(conn))); + _ = tx.send(run_catching_and_then(conn, |conn| func(conn))); }))) .map_err(Error::from)?; rx.await.map_err(Error::from)? @@ -286,7 +298,7 @@ impl Client { let (tx, rx) = oneshot::channel(); self.conn_tx .send(Command::Func(Box::new(move |conn| { - _ = tx.send(run_catching_and_then(|| func(conn))); + _ = tx.send(run_catching_and_then(conn, func)); }))) .map_err(Error::from)?; rx.await.map_err(Error::from)? @@ -316,7 +328,7 @@ impl Client { { let (tx, rx) = bounded(1); self.conn_tx.send(Command::Func(Box::new(move |conn| { - _ = tx.send(run_catching(|| func(conn))); + _ = tx.send(run_catching(conn, |conn| func(conn))); })))?; rx.recv()? } @@ -330,7 +342,7 @@ impl Client { { let (tx, rx) = bounded(1); self.conn_tx.send(Command::Func(Box::new(move |conn| { - _ = tx.send(run_catching(|| func(conn))); + _ = tx.send(run_catching(conn, func)); })))?; rx.recv()? } diff --git a/tests/tests.rs b/tests/tests.rs index 02cfb67..31c1186 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -215,6 +215,7 @@ async_test!(test_pool_conn_for_each); async_test!(test_pool_close_concurrent); async_test!(test_pool_num_conns_zero_clamps); async_test!(test_closure_panic_surfaces_error); +async_test!(test_panic_after_begin_immediate_rolls_back); async fn test_journal_mode() { for (journal_mode, expected) in journal_modes() { @@ -565,6 +566,55 @@ async fn test_closure_panic_surfaces_error() { client.close().await.expect("closing client"); } +async fn test_panic_after_begin_immediate_rolls_back() { + let tmp_dir = tempfile::tempdir().unwrap(); + let db_path = tmp_dir.path().join("sqlite.db"); + let client = ClientBuilder::new() + .path(&db_path) + .open() + .await + .expect("client unable to be opened"); + + client + .conn(|conn| { + conn.execute( + "CREATE TABLE testing (id INTEGER PRIMARY KEY, val TEXT NOT NULL)", + (), + )?; + Ok(()) + }) + .await + .expect("creating table"); + + let res: Result<(), Error> = client + .conn(|conn| { + conn.execute_batch("BEGIN IMMEDIATE")?; + conn.execute("INSERT INTO testing VALUES (1, ?)", ["panic"])?; + panic!("boom after BEGIN IMMEDIATE"); + }) + .await; + match res { + Err(Error::Panic { message }) => assert!(message.contains("boom"), "got {message}"), + other => panic!("expected Error::Panic, got {other:?}"), + } + + let row_count: i64 = client + .conn(|conn| conn.query_row("SELECT COUNT(*) FROM testing", (), |row| row.get(0))) + .await + .expect("counting rows after rollback"); + assert_eq!(row_count, 0); + + let other = rusqlite::Connection::open(&db_path).expect("opening second connection"); + other + .busy_timeout(std::time::Duration::from_millis(0)) + .expect("setting busy timeout"); + other + .execute("INSERT INTO testing VALUES (2, ?)", ["other"]) + .expect("second connection can write after panic rollback"); + + client.close().await.expect("closing client"); +} + async fn test_pool_num_conns_zero_clamps() { let tmp_dir = tempfile::tempdir().unwrap(); let pool = PoolBuilder::new()