diff --git a/AGENTS.md b/AGENTS.md index 868fb53..65ac10e 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -28,9 +28,9 @@ The library has three core types in `src/`: - **Client** (`client.rs`): Wraps a single SQLite connection. Spawns a background `std::thread` that receives commands (closures) via a `crossbeam_channel`. Results are returned through `futures_channel::oneshot`. This design makes it runtime-agnostic. Client is cheaply cloneable. -- **Pool** (`pool.rs`): Manages multiple `Client` instances with round-robin selection via an atomic counter. Provides the same API as Client plus `conn_for_each()` for executing on all connections. Defaults to CPU-count connections. +- **Pool** (`pool.rs`): Manages multiple `Client` instances with round-robin selection via an atomic counter. Provides the same API as Client plus `conn_for_each()` for executing on all connections. File-backed and named shared-memory pools default to CPU-count connections; anonymous in-memory pools default to one connection and reject explicit multi-connection configuration. -- **Error** (`error.rs`): Non-exhaustive enum wrapping `rusqlite::Error`, channel errors, and pragma failures. +- **Error** (`error.rs`): Non-exhaustive enum wrapping config errors, `rusqlite::Error`, channel errors, panics, and pragma failures. All database operations use a closure-based API (e.g., `conn(|conn| { ... })`) to avoid lifetime issues with the cross-thread boundary. Both blocking and async variants exist for all operations. diff --git a/README.md b/README.md index b329cfb..f2fcde8 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,13 @@ println!("Value is: {value}"); A `Pool` represents a collection of background sqlite3 connections that can be called concurrently from any thread in your program. +`PoolBuilder::new().open()` and `path(":memory:")` use a single anonymous +in-memory connection by default, since separate SQLite `:memory:` connections +do not share schema or data. File-backed pools default to the logical CPU +count. For multiple connections to a named in-memory database, use +`shared_memory("name")`; this uses SQLite shared-cache mode, which has caveats, +so prefer a file-backed database when possible. + To create a sqlite pool and run a query: ```rust diff --git a/src/error.rs b/src/error.rs index 3cc2ae9..6621a5d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -4,6 +4,8 @@ pub enum Error { /// Indicates that the connection to the sqlite database is closed. Closed, + /// Invalid builder configuration. + Config { message: &'static str }, /// Error updating PRAGMA. PragmaUpdate { name: &'static str, @@ -29,6 +31,7 @@ impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Error::Closed => write!(f, "connection to sqlite database closed"), + Error::Config { message } => write!(f, "invalid configuration: {message}"), Error::PragmaUpdate { exp, got, name } => { write!(f, "updating pragma {name}: expected '{exp}', got '{got}'") } diff --git a/src/lib.rs b/src/lib.rs index 4de59ae..12badf0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -35,6 +35,13 @@ //! A `Pool` represents a collection of background sqlite3 connections that can be //! called concurrently from any thread in your program. //! +//! `PoolBuilder::new().open()` and `path(":memory:")` use a single anonymous +//! in-memory connection by default, since separate SQLite `:memory:` +//! connections do not share schema or data. File-backed pools default to the +//! logical CPU count. For multiple connections to a named in-memory database, +//! use `shared_memory("name")`; this uses SQLite shared-cache mode, which has +//! caveats, so prefer a file-backed database when possible. +//! //! To create a sqlite pool and run a query: //! //! ```rust diff --git a/src/pool.rs b/src/pool.rs index 1d84238..306d9c8 100644 --- a/src/pool.rs +++ b/src/pool.rs @@ -34,6 +34,7 @@ use rusqlite::{Connection, OpenFlags}; #[derive(Clone, Debug, Default)] pub struct PoolBuilder { path: Option, + shared_memory_name: Option, flags: OpenFlags, journal_mode: Option, vfs: Option, @@ -51,6 +52,30 @@ impl PoolBuilder { /// By default, an in-memory database is used. pub fn path>(mut self, path: P) -> Self { self.path = Some(path.as_ref().into()); + self.shared_memory_name = None; + self + } + + /// Use a named shared in-memory sqlite database. + /// + /// This opens connections with a URI of the form + /// `file:?mode=memory&cache=shared` and enables + /// [`OpenFlags::SQLITE_OPEN_URI`] and + /// [`OpenFlags::SQLITE_OPEN_SHARED_CACHE`]. + /// + /// SQLite shared-cache mode has caveats and is discouraged by SQLite for + /// many workloads. Prefer a file-backed database when possible. The + /// in-memory database is deleted after the last connection using this name + /// is closed. + /// + /// ``` + /// use async_sqlite::PoolBuilder; + /// + /// let builder = PoolBuilder::new().shared_memory("my-pool").num_conns(2); + /// ``` + pub fn shared_memory>(mut self, name: N) -> Self { + self.path = None; + self.shared_memory_name = Some(name.as_ref().to_owned()); self } @@ -78,8 +103,11 @@ impl PoolBuilder { /// Specify the number of sqlite connections to open as part of the pool. /// - /// Defaults to the number of logical CPUs of the current system. Values - /// less than `1` are clamped to `1`. + /// File-backed and shared-memory pools default to the number of logical + /// CPUs of the current system. Anonymous in-memory pools, including + /// `path(":memory:")`, default to `1` connection because each sqlite + /// `:memory:` connection is a separate database. Values less than `1` are + /// clamped to `1`. /// /// ``` /// use async_sqlite::PoolBuilder; @@ -104,30 +132,16 @@ impl PoolBuilder { /// ``` pub async fn open(self) -> Result { let num_conns = self.get_num_conns(); + self.validate(num_conns)?; // Open the first connection with full config (including journal_mode). // This must complete before opening remaining connections to avoid // concurrent PRAGMA writes on a new database file. - let first = ClientBuilder { - path: self.path.clone(), - flags: self.flags, - journal_mode: self.journal_mode, - vfs: self.vfs.clone(), - } - .open() - .await?; + let first = self.client_builder().open().await?; // Open remaining connections with journal_mode too, so connection-local // modes are applied consistently across the pool. - let opens = (1..num_conns).map(|_| { - ClientBuilder { - path: self.path.clone(), - flags: self.flags, - journal_mode: self.journal_mode, - vfs: self.vfs.clone(), - } - .open() - }); + let opens = (1..num_conns).map(|_| self.client_builder().open()); let mut clients = vec![first]; clients.extend( join_all(opens) @@ -158,30 +172,17 @@ impl PoolBuilder { /// ``` pub fn open_blocking(self) -> Result { let num_conns = self.get_num_conns(); + self.validate(num_conns)?; // Open the first connection with full config (including journal_mode). - let first = ClientBuilder { - path: self.path.clone(), - flags: self.flags, - journal_mode: self.journal_mode, - vfs: self.vfs.clone(), - } - .open_blocking()?; + let first = self.client_builder().open_blocking()?; // Open remaining connections with journal_mode too, so connection-local // modes are applied consistently across the pool. let mut clients = vec![first]; clients.extend( (1..num_conns) - .map(|_| { - ClientBuilder { - path: self.path.clone(), - flags: self.flags, - journal_mode: self.journal_mode, - vfs: self.vfs.clone(), - } - .open_blocking() - }) + .map(|_| self.client_builder().open_blocking()) .collect::, Error>>()?, ); @@ -194,11 +195,95 @@ impl PoolBuilder { } fn get_num_conns(&self) -> usize { - self.num_conns.unwrap_or_else(|| { - available_parallelism() - .unwrap_or_else(|_| NonZeroUsize::new(1).unwrap()) - .into() - }) + if let Some(num_conns) = self.num_conns { + return num_conns; + } + + if self.is_anonymous_memory() { + return 1; + } + + available_parallelism() + .unwrap_or_else(|_| NonZeroUsize::new(1).unwrap()) + .into() + } + + fn validate(&self, num_conns: usize) -> Result<(), Error> { + if self + .shared_memory_name + .as_ref() + .is_some_and(|name| name.is_empty()) + { + return Err(Error::Config { + message: "shared memory database name must not be empty", + }); + } + + if self.is_anonymous_memory() && num_conns > 1 { + return Err(Error::Config { + message: "anonymous in-memory pools cannot use multiple connections; call path(...) for file-backed pools or shared_memory(...) for named shared in-memory pools", + }); + } + + Ok(()) + } + + fn client_builder(&self) -> ClientBuilder { + ClientBuilder { + path: self.connection_path(), + flags: self.connection_flags(), + journal_mode: self.journal_mode, + vfs: self.vfs.clone(), + } + } + + fn connection_path(&self) -> Option { + self.shared_memory_name + .as_deref() + .map(shared_memory_uri) + .or_else(|| self.path.clone()) + } + + fn connection_flags(&self) -> OpenFlags { + let mut flags = self.flags; + if self.shared_memory_name.is_some() { + flags.insert(OpenFlags::SQLITE_OPEN_URI); + flags.insert(OpenFlags::SQLITE_OPEN_SHARED_CACHE); + flags.remove(OpenFlags::SQLITE_OPEN_PRIVATE_CACHE); + } + flags + } + + fn is_anonymous_memory(&self) -> bool { + self.shared_memory_name.is_none() + && self + .path + .as_deref() + .is_none_or(|path| path == Path::new(":memory:")) + } +} + +fn shared_memory_uri(name: &str) -> PathBuf { + let mut uri = String::from("file:"); + push_uri_encoded(name, &mut uri); + uri.push_str("?mode=memory&cache=shared"); + uri.into() +} + +fn push_uri_encoded(input: &str, out: &mut String) { + const HEX: &[u8; 16] = b"0123456789ABCDEF"; + + for byte in input.bytes() { + match byte { + b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'.' | b'_' | b'~' => { + out.push(byte.into()); + } + _ => { + out.push('%'); + out.push(HEX[(byte >> 4) as usize].into()); + out.push(HEX[(byte & 0x0F) as usize].into()); + } + } } } diff --git a/tests/tests.rs b/tests/tests.rs index 2de34de..02cfb67 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -1,5 +1,21 @@ +use std::sync::atomic::{AtomicUsize, Ordering}; + use async_sqlite::{ClientBuilder, Error, JournalMode, PoolBuilder}; +static SHARED_MEMORY_ID: AtomicUsize = AtomicUsize::new(0); + +fn shared_memory_name(prefix: &str) -> String { + let id = SHARED_MEMORY_ID.fetch_add(1, Ordering::Relaxed); + format!("{prefix}-{}-{id}", std::process::id()) +} + +fn assert_config_message(err: Error, expected: &str) { + match err { + Error::Config { message } => assert_eq!(message, expected), + other => panic!("expected Error::Config, got {other:?}"), + } +} + fn journal_modes() -> [(JournalMode, &'static str); 6] { [ (JournalMode::Delete, "delete"), @@ -42,6 +58,43 @@ fn test_blocking_client() { client.close_blocking().expect("closing client conn"); } +#[test] +fn test_blocking_default_pool_in_memory_uses_one_connection() { + let pool = PoolBuilder::new() + .open_blocking() + .expect("pool unable to be opened"); + + pool.conn_blocking(|conn| { + conn.execute( + "CREATE TABLE testing (id INTEGER PRIMARY KEY, val TEXT NOT NULL)", + (), + )?; + conn.execute("INSERT INTO testing VALUES (1, ?)", ["value1"]) + }) + .expect("writing schema and seed data"); + + pool.conn_blocking(|conn| { + let val: String = + conn.query_row("SELECT val FROM testing WHERE id=?", [1], |row| row.get(0))?; + assert_eq!(val, "value1"); + Ok(()) + }) + .expect("querying for result"); + + let results = pool.conn_for_each_blocking(|_| Ok(())); + assert_eq!(results.len(), 1); + + pool.close_blocking().expect("closing pool"); + + let pool = PoolBuilder::new() + .path(":memory:") + .open_blocking() + .expect("pool unable to be opened"); + let results = pool.conn_for_each_blocking(|_| Ok(())); + assert_eq!(results.len(), 1); + pool.close_blocking().expect("closing pool"); +} + #[test] fn test_blocking_pool() { let tmp_dir = tempfile::tempdir().unwrap(); @@ -71,6 +124,39 @@ fn test_blocking_pool() { pool.close_blocking().expect("closing client conn"); } +#[test] +fn test_blocking_pool_rejects_multi_connection_anonymous_memory() { + let err = match PoolBuilder::new().num_conns(2).open_blocking() { + Ok(pool) => { + pool.close_blocking().expect("closing unexpected pool"); + panic!("expected pool open to fail"); + } + Err(err) => err, + }; + + assert_config_message( + err, + "anonymous in-memory pools cannot use multiple connections; call path(...) for file-backed pools or shared_memory(...) for named shared in-memory pools", + ); + + let err = match PoolBuilder::new() + .path(":memory:") + .num_conns(2) + .open_blocking() + { + Ok(pool) => { + pool.close_blocking().expect("closing unexpected pool"); + panic!("expected pool open to fail"); + } + Err(err) => err, + }; + + assert_config_message( + err, + "anonymous in-memory pools cannot use multiple connections; call path(...) for file-backed pools or shared_memory(...) for named shared in-memory pools", + ); +} + #[test] fn test_blocking_pool_journal_mode() { for (journal_mode, expected) in journal_modes() { @@ -119,7 +205,11 @@ macro_rules! async_test { async_test!(test_journal_mode); async_test!(test_concurrency); +async_test!(test_default_pool_in_memory_uses_one_connection); async_test!(test_pool); +async_test!(test_pool_rejects_multi_connection_anonymous_memory); +async_test!(test_shared_memory_pool); +async_test!(test_shared_memory_rejects_empty_name); async_test!(test_pool_journal_mode); async_test!(test_pool_conn_for_each); async_test!(test_pool_close_concurrent); @@ -178,6 +268,46 @@ async fn test_concurrency() { .expect("collecting query results"); } +async fn test_default_pool_in_memory_uses_one_connection() { + let pool = PoolBuilder::new() + .open() + .await + .expect("pool unable to be opened"); + + pool.conn(|conn| { + conn.execute( + "CREATE TABLE testing (id INTEGER PRIMARY KEY, val TEXT NOT NULL)", + (), + )?; + conn.execute("INSERT INTO testing VALUES (1, ?)", ["value1"]) + }) + .await + .expect("writing schema and seed data"); + + pool.conn(|conn| { + let val: String = + conn.query_row("SELECT val FROM testing WHERE id=?", [1], |row| row.get(0))?; + assert_eq!(val, "value1"); + Ok(()) + }) + .await + .expect("querying for result"); + + let results = pool.conn_for_each(|_| Ok(())).await; + assert_eq!(results.len(), 1); + + pool.close().await.expect("closing pool"); + + let pool = PoolBuilder::new() + .path(":memory:") + .open() + .await + .expect("pool unable to be opened"); + let results = pool.conn_for_each(|_| Ok(())).await; + assert_eq!(results.len(), 1); + pool.close().await.expect("closing pool"); +} + async fn test_pool() { let tmp_dir = tempfile::tempdir().unwrap(); let pool = PoolBuilder::new() @@ -212,6 +342,88 @@ async fn test_pool() { .expect("collecting query results"); } +async fn test_pool_rejects_multi_connection_anonymous_memory() { + let err = match PoolBuilder::new().num_conns(2).open().await { + Ok(pool) => { + pool.close().await.expect("closing unexpected pool"); + panic!("expected pool open to fail"); + } + Err(err) => err, + }; + + assert_config_message( + err, + "anonymous in-memory pools cannot use multiple connections; call path(...) for file-backed pools or shared_memory(...) for named shared in-memory pools", + ); + + let err = match PoolBuilder::new() + .path(":memory:") + .num_conns(2) + .open() + .await + { + Ok(pool) => { + pool.close().await.expect("closing unexpected pool"); + panic!("expected pool open to fail"); + } + Err(err) => err, + }; + + assert_config_message( + err, + "anonymous in-memory pools cannot use multiple connections; call path(...) for file-backed pools or shared_memory(...) for named shared in-memory pools", + ); +} + +async fn test_shared_memory_pool() { + let name = shared_memory_name("shared-pool"); + let pool = PoolBuilder::new() + .shared_memory(&name) + .num_conns(2) + .open() + .await + .expect("pool unable to be opened"); + + let results = pool.conn_for_each(|_| Ok(())).await; + assert_eq!(results.len(), 2); + + pool.conn(|conn| { + conn.execute( + "CREATE TABLE testing (id INTEGER PRIMARY KEY, val TEXT NOT NULL)", + (), + )?; + conn.execute("INSERT INTO testing VALUES (1, ?)", ["value1"]) + }) + .await + .expect("writing schema and seed data"); + + let results = pool + .conn_for_each(|conn| { + conn.query_row("SELECT val FROM testing WHERE id=?", [1], |row| { + row.get::<_, String>(0) + }) + }) + .await; + + for result in results { + assert_eq!(result.unwrap(), "value1"); + } + + pool.close().await.expect("closing pool"); +} + +async fn test_shared_memory_rejects_empty_name() { + let err = match PoolBuilder::new().shared_memory("").open().await { + Ok(pool) => { + pool.close().await.expect("closing unexpected pool"); + panic!("expected pool open to fail"); + } + Err(err) => err, + }; + + assert_config_message(err, "shared memory database name must not be empty"); +} + async fn test_pool_journal_mode() { for (journal_mode, expected) in journal_modes() { let tmp_dir = tempfile::tempdir().unwrap();