Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 30 additions & 18 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,28 +109,40 @@ enum Command {
Shutdown(Box<dyn FnOnce(Result<(), Error>) + Send>),
}

fn run_catching<F, T>(func: F) -> Result<T, Error>
fn run_catching<F, T>(conn: &mut Connection, func: F) -> Result<T, Error>
where
F: FnOnce() -> Result<T, rusqlite::Error>,
F: FnOnce(&mut Connection) -> Result<T, rusqlite::Error>,
{
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(func)) {
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| func(conn))) {
Ok(res) => res.map_err(Error::from),
Err(p) => Err(Error::Panic {
message: panic_message(&*p),
}),
Err(p) => {
rollback_if_needed(conn);
Err(Error::Panic {
message: panic_message(&*p),
})
}
}
}

fn run_catching_and_then<F, T, E>(func: F) -> Result<T, E>
fn run_catching_and_then<F, T, E>(conn: &mut Connection, func: F) -> Result<T, E>
where
F: FnOnce() -> Result<T, E>,
F: FnOnce(&mut Connection) -> Result<T, E>,
E: From<Error>,
{
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(func)) {
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| func(conn))) {
Ok(res) => res,
Err(p) => Err(E::from(Error::Panic {
message: panic_message(&*p),
})),
Err(p) => {
rollback_if_needed(conn);
Err(E::from(Error::Panic {
message: panic_message(&*p),
}))
}
}
}

fn rollback_if_needed(conn: &mut Connection) {
if !conn.is_autocommit() {
let _ = conn.execute_batch("ROLLBACK");
}
}

Expand Down Expand Up @@ -236,7 +248,7 @@ impl Client {
{
let (tx, rx) = oneshot::channel();
self.conn_tx.send(Command::Func(Box::new(move |conn| {
_ = tx.send(run_catching(|| func(conn)));
_ = tx.send(run_catching(conn, |conn| func(conn)));
})))?;
rx.await?
}
Expand All @@ -249,7 +261,7 @@ impl Client {
{
let (tx, rx) = oneshot::channel();
self.conn_tx.send(Command::Func(Box::new(move |conn| {
_ = tx.send(run_catching(|| func(conn)));
_ = tx.send(run_catching(conn, func));
})))?;
rx.await?
}
Expand All @@ -267,7 +279,7 @@ impl Client {
let (tx, rx) = oneshot::channel();
self.conn_tx
.send(Command::Func(Box::new(move |conn| {
_ = tx.send(run_catching_and_then(|| func(conn)));
_ = tx.send(run_catching_and_then(conn, |conn| func(conn)));
})))
.map_err(Error::from)?;
rx.await.map_err(Error::from)?
Expand All @@ -286,7 +298,7 @@ impl Client {
let (tx, rx) = oneshot::channel();
self.conn_tx
.send(Command::Func(Box::new(move |conn| {
_ = tx.send(run_catching_and_then(|| func(conn)));
_ = tx.send(run_catching_and_then(conn, func));
})))
.map_err(Error::from)?;
rx.await.map_err(Error::from)?
Expand Down Expand Up @@ -316,7 +328,7 @@ impl Client {
{
let (tx, rx) = bounded(1);
self.conn_tx.send(Command::Func(Box::new(move |conn| {
_ = tx.send(run_catching(|| func(conn)));
_ = tx.send(run_catching(conn, |conn| func(conn)));
})))?;
rx.recv()?
}
Expand All @@ -330,7 +342,7 @@ impl Client {
{
let (tx, rx) = bounded(1);
self.conn_tx.send(Command::Func(Box::new(move |conn| {
_ = tx.send(run_catching(|| func(conn)));
_ = tx.send(run_catching(conn, func));
})))?;
rx.recv()?
}
Expand Down
50 changes: 50 additions & 0 deletions tests/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ async_test!(test_pool_conn_for_each);
async_test!(test_pool_close_concurrent);
async_test!(test_pool_num_conns_zero_clamps);
async_test!(test_closure_panic_surfaces_error);
async_test!(test_panic_after_begin_immediate_rolls_back);

async fn test_journal_mode() {
for (journal_mode, expected) in journal_modes() {
Expand Down Expand Up @@ -565,6 +566,55 @@ async fn test_closure_panic_surfaces_error() {
client.close().await.expect("closing client");
}

async fn test_panic_after_begin_immediate_rolls_back() {
let tmp_dir = tempfile::tempdir().unwrap();
let db_path = tmp_dir.path().join("sqlite.db");
let client = ClientBuilder::new()
.path(&db_path)
.open()
.await
.expect("client unable to be opened");

client
.conn(|conn| {
conn.execute(
"CREATE TABLE testing (id INTEGER PRIMARY KEY, val TEXT NOT NULL)",
(),
)?;
Ok(())
})
.await
.expect("creating table");

let res: Result<(), Error> = client
.conn(|conn| {
conn.execute_batch("BEGIN IMMEDIATE")?;
conn.execute("INSERT INTO testing VALUES (1, ?)", ["panic"])?;
panic!("boom after BEGIN IMMEDIATE");
})
.await;
match res {
Err(Error::Panic { message }) => assert!(message.contains("boom"), "got {message}"),
other => panic!("expected Error::Panic, got {other:?}"),
}

let row_count: i64 = client
.conn(|conn| conn.query_row("SELECT COUNT(*) FROM testing", (), |row| row.get(0)))
.await
.expect("counting rows after rollback");
assert_eq!(row_count, 0);

let other = rusqlite::Connection::open(&db_path).expect("opening second connection");
other
.busy_timeout(std::time::Duration::from_millis(0))
.expect("setting busy timeout");
other
.execute("INSERT INTO testing VALUES (2, ?)", ["other"])
.expect("second connection can write after panic rollback");

client.close().await.expect("closing client");
}

async fn test_pool_num_conns_zero_clamps() {
let tmp_dir = tempfile::tempdir().unwrap();
let pool = PoolBuilder::new()
Expand Down
Loading