diff --git a/src/database.rs b/src/database.rs index c3ca057..6900088 100644 --- a/src/database.rs +++ b/src/database.rs @@ -976,7 +976,9 @@ impl Database { i.tick().await; for table in db.read().await._db.tables.read().values() { - table.prune_dead_watchers(); + if let Some(table) = table.upgrade() { + table.prune_dead_watchers(); + } } } }); diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index 6433059..6717013 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -8,7 +8,7 @@ use std::{ future::Future, path::{Path, PathBuf}, pin::Pin, - sync::Arc, + sync::{Arc, Weak}, }; use thread_local::ThreadLocal; use tokio::sync::watch; @@ -47,7 +47,7 @@ pub struct Engine { path: PathBuf, cache_size_per_thread: u32, - pub(in crate::database) tables: RwLock>>, + pub(in crate::database) tables: RwLock>>, } impl Engine { @@ -111,18 +111,35 @@ impl DatabaseEngine for Engine { } fn open_tree(self: &Arc, name: &str) -> Result> { + fn create_new(engine: &Arc, name: &str) -> Result { + engine.write_lock().execute(&format!("CREATE TABLE IF NOT EXISTS {} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )", name), [])?; + + SqliteTable { + engine: Arc::clone(engine), + name: name.to_owned(), + watchers: RwLock::new(HashMap::new()), + } + } + + // Table mappings are `Weak` to prevent reference cycles, that creates this additional correctness logic. Ok(match self.tables.write().entry(name.to_string()) { - hash_map::Entry::Occupied(o) => o.get().clone(), - hash_map::Entry::Vacant(v) => { - self.write_lock().execute(&format!("CREATE TABLE IF NOT EXISTS {} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )", name), [])?; + hash_map::Entry::Occupied(o) => { + if let Some(table) = o.get().upgrade() { + table + } else { + // On the off-chance that a table was dropped somewhere... - let table = Arc::new(SqliteTable { - engine: Arc::clone(self), - name: name.to_owned(), - watchers: RwLock::new(HashMap::new()), - }); + let table = Arc::new(create_new(self, name)?); + + o.insert(table.downgrade()); + + table + } + } + hash_map::Entry::Vacant(v) => { + let table = Arc::new(create_new(self, name)?); - v.insert(table.clone()); + v.insert(table.downgrade()); table }