diff --git a/src/client.rs b/src/client.rs index 26cd76b..90c0632 100644 --- a/src/client.rs +++ b/src/client.rs @@ -476,6 +476,40 @@ impl Client { rx.recv()? } + /// Invokes the provided function with a [`rusqlite::Connection`], + /// blocking the current thread until completion. + /// + /// Maps the result error type to a custom error; designed to be + /// used in conjunction with [`query_and_then`](https://docs.rs/rusqlite/latest/rusqlite/struct.CachedStatement.html#method.query_and_then). + pub fn conn_and_then_blocking(&self, func: F) -> Result + where + F: FnOnce(&Connection) -> Result + Send + 'static, + T: Send + 'static, + E: From + From + Send + 'static, + { + let rx = self + .enqueue_blocking(move |conn| run_catching_and_then(conn, |conn| func(conn))) + .map_err(Error::from)?; + rx.recv().map_err(Error::from)? + } + + /// Invokes the provided function with a mutable [`rusqlite::Connection`], + /// blocking the current thread until completion. + /// + /// Maps the result error type to a custom error; designed to be + /// used in conjunction with [`query_and_then`](https://docs.rs/rusqlite/latest/rusqlite/struct.CachedStatement.html#method.query_and_then). + pub fn conn_mut_and_then_blocking(&self, func: F) -> Result + where + F: FnOnce(&mut Connection) -> Result + Send + 'static, + T: Send + 'static, + E: From + From + Send + 'static, + { + let rx = self + .enqueue_blocking(move |conn| run_catching_and_then(conn, func)) + .map_err(Error::from)?; + rx.recv().map_err(Error::from)? + } + /// Closes the underlying sqlite connection, blocking the current thread /// until complete. /// diff --git a/src/pool.rs b/src/pool.rs index dde8a28..32de5a7 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -334,6 +334,32 @@ impl Pool { self.get().conn_mut(func).await } + /// Invokes the provided function with a [`rusqlite::Connection`]. + /// + /// Maps the result error type to a custom error; designed to be + /// used in conjunction with [`query_and_then`](https://docs.rs/rusqlite/latest/rusqlite/struct.CachedStatement.html#method.query_and_then). + pub async fn conn_and_then(&self, func: F) -> Result + where + F: FnOnce(&Connection) -> Result + Send + 'static, + T: Send + 'static, + E: From + From + Send + 'static, + { + self.get().conn_and_then(func).await + } + + /// Invokes the provided function with a mutable [`rusqlite::Connection`]. + /// + /// Maps the result error type to a custom error; designed to be + /// used in conjunction with [`query_and_then`](https://docs.rs/rusqlite/latest/rusqlite/struct.CachedStatement.html#method.query_and_then). + pub async fn conn_mut_and_then(&self, func: F) -> Result + where + F: FnOnce(&mut Connection) -> Result + Send + 'static, + T: Send + 'static, + E: From + From + Send + 'static, + { + self.get().conn_mut_and_then(func).await + } + /// Closes the underlying sqlite connections. /// /// After this method returns, all calls to `self::conn()` or @@ -365,6 +391,34 @@ impl Pool { self.get().conn_mut_blocking(func) } + /// Invokes the provided function with a [`rusqlite::Connection`], blocking + /// the current thread. + /// + /// Maps the result error type to a custom error; designed to be + /// used in conjunction with [`query_and_then`](https://docs.rs/rusqlite/latest/rusqlite/struct.CachedStatement.html#method.query_and_then). + pub fn conn_and_then_blocking(&self, func: F) -> Result + where + F: FnOnce(&Connection) -> Result + Send + 'static, + T: Send + 'static, + E: From + From + Send + 'static, + { + self.get().conn_and_then_blocking(func) + } + + /// Invokes the provided function with a mutable [`rusqlite::Connection`], + /// blocking the current thread. + /// + /// Maps the result error type to a custom error; designed to be + /// used in conjunction with [`query_and_then`](https://docs.rs/rusqlite/latest/rusqlite/struct.CachedStatement.html#method.query_and_then). + pub fn conn_mut_and_then_blocking(&self, func: F) -> Result + where + F: FnOnce(&mut Connection) -> Result + Send + 'static, + T: Send + 'static, + E: From + From + Send + 'static, + { + self.get().conn_mut_and_then_blocking(func) + } + /// Closes the underlying sqlite connections, blocking the current thread. /// /// After this method returns, all calls to `self::conn_blocking()` or diff --git a/tests/tests.rs b/tests/tests.rs index 4014ed9..8201e04 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -28,6 +28,32 @@ fn journal_modes() -> [(JournalMode, &'static str); 6] { ] } +#[derive(Debug)] +enum CustomError { + AsyncSqlite, + Rusqlite, + User(&'static str), +} + +impl From for CustomError { + fn from(_value: Error) -> Self { + Self::AsyncSqlite + } +} + +impl From for CustomError { + fn from(_value: rusqlite::Error) -> Self { + Self::Rusqlite + } +} + +fn assert_user_error(result: Result<(), CustomError>, expected: &'static str) { + match result { + Err(CustomError::User(actual)) => assert_eq!(actual, expected), + other => panic!("expected CustomError::User({expected:?}), got {other:?}"), + } +} + #[test] fn test_blocking_client() { let tmp_dir = tempfile::tempdir().unwrap(); @@ -59,6 +85,39 @@ fn test_blocking_client() { client.close_blocking().expect("closing client conn"); } +#[test] +fn test_blocking_client_and_then_api() { + let client = ClientBuilder::new() + .open_blocking() + .expect("client unable to be opened"); + + client + .conn_and_then_blocking(|conn| { + conn.execute( + "CREATE TABLE testing (id INTEGER PRIMARY KEY, val INTEGER NOT NULL)", + (), + )?; + conn.execute("INSERT INTO testing VALUES (1, ?)", [42])?; + Ok::<(), CustomError>(()) + }) + .expect("writing schema and seed data"); + + let val: i64 = client + .conn_mut_and_then_blocking(|conn| { + conn.query_row("SELECT val FROM testing WHERE id=?", [1], |row| row.get(0)) + .map_err(CustomError::from) + }) + .expect("querying for result"); + assert_eq!(val, 42); + + assert_user_error( + client.conn_and_then_blocking(|_| Err(CustomError::User("client"))), + "client", + ); + + client.close_blocking().expect("closing client conn"); +} + #[test] fn test_blocking_default_pool_in_memory_uses_one_connection() { let pool = PoolBuilder::new() @@ -125,6 +184,38 @@ fn test_blocking_pool() { pool.close_blocking().expect("closing client conn"); } +#[test] +fn test_blocking_pool_and_then_api() { + let pool = PoolBuilder::new() + .open_blocking() + .expect("pool unable to be opened"); + + pool.conn_and_then_blocking(|conn| { + conn.execute( + "CREATE TABLE testing (id INTEGER PRIMARY KEY, val INTEGER NOT NULL)", + (), + )?; + conn.execute("INSERT INTO testing VALUES (1, ?)", [42])?; + Ok::<(), CustomError>(()) + }) + .expect("writing schema and seed data"); + + let val: i64 = pool + .conn_mut_and_then_blocking(|conn| { + conn.query_row("SELECT val FROM testing WHERE id=?", [1], |row| row.get(0)) + .map_err(CustomError::from) + }) + .expect("querying for result"); + assert_eq!(val, 42); + + assert_user_error( + pool.conn_and_then_blocking(|_| Err(CustomError::User("pool"))), + "pool", + ); + + pool.close_blocking().expect("closing pool"); +} + #[test] fn test_blocking_pool_rejects_multi_connection_anonymous_memory() { let err = match PoolBuilder::new().num_conns(2).open_blocking() { @@ -208,6 +299,7 @@ async_test!(test_journal_mode); async_test!(test_concurrency); async_test!(test_default_pool_in_memory_uses_one_connection); async_test!(test_pool); +async_test!(test_pool_and_then_api); async_test!(test_pool_rejects_multi_connection_anonymous_memory); async_test!(test_shared_memory_pool); async_test!(test_shared_memory_rejects_empty_name); @@ -346,6 +438,41 @@ async fn test_pool() { .expect("collecting query results"); } +async fn test_pool_and_then_api() { + let pool = PoolBuilder::new() + .open() + .await + .expect("pool unable to be opened"); + + pool.conn_and_then(|conn| { + conn.execute( + "CREATE TABLE testing (id INTEGER PRIMARY KEY, val INTEGER NOT NULL)", + (), + )?; + conn.execute("INSERT INTO testing VALUES (1, ?)", [42])?; + Ok::<(), CustomError>(()) + }) + .await + .expect("writing schema and seed data"); + + let val: i64 = pool + .conn_mut_and_then(|conn| { + conn.query_row("SELECT val FROM testing WHERE id=?", [1], |row| row.get(0)) + .map_err(CustomError::from) + }) + .await + .expect("querying for result"); + assert_eq!(val, 42); + + assert_user_error( + pool.conn_and_then(|_| Err(CustomError::User("pool async"))) + .await, + "pool async", + ); + + pool.close().await.expect("closing pool"); +} + async fn test_pool_rejects_multi_connection_anonymous_memory() { let err = match PoolBuilder::new().num_conns(2).open().await { Ok(pool) => {