From 22e3416745b20f85a93b9e5e5fefb39fce1ceca8 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Sat, 3 Jul 2021 01:13:50 +0200 Subject: [PATCH 01/28] YEET --- Cargo.lock | 137 ++++++++++++- Cargo.toml | 8 +- docker-compose.yml | 8 +- rjbench_testing/docker-compose.yml | 32 +++ src/database.rs | 7 +- src/database/abstraction.rs | 296 +--------------------------- src/database/abstraction/rocksdb.rs | 176 +++++++++++++++++ src/database/abstraction/sled.rs | 115 +++++++++++ src/database/abstraction/sqlite.rs | 273 +++++++++++++++++++++++++ src/database/account_data.rs | 2 +- src/database/appservice.rs | 4 +- src/database/pusher.rs | 2 +- src/database/rooms.rs | 18 +- src/database/sending.rs | 2 +- src/error.rs | 6 + 15 files changed, 772 insertions(+), 314 deletions(-) create mode 100644 rjbench_testing/docker-compose.yml create mode 100644 src/database/abstraction/rocksdb.rs create mode 100644 src/database/abstraction/sled.rs create mode 100644 src/database/abstraction/sqlite.rs diff --git a/Cargo.lock b/Cargo.lock index 76e727e..9266c13 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6,6 +6,17 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234" +[[package]] +name = "ahash" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43bb833f0bf979d8475d38fbf09ed3b8a55e1885fe93ad3f93239fc6a4f17b98" +dependencies = [ + "getrandom 0.2.3", + "once_cell", + "version_check", +] + [[package]] name = "aho-corasick" version = "0.7.15" @@ -238,6 +249,7 @@ version = "0.1.0" dependencies = [ "base64 0.13.0", "bytes", + "crossbeam", "directories", "http", "image", @@ -246,6 +258,7 @@ dependencies = [ "lru-cache", "opentelemetry", "opentelemetry-jaeger", + "parking_lot", "pretty_env_logger", "rand 0.8.3", "regex", @@ -254,6 +267,7 @@ dependencies = [ "rocket", "rocksdb", "ruma", + "rusqlite", "rust-argon2", "rustls", "rustls-native-certs", @@ -339,11 +353,46 @@ dependencies = [ "cfg-if 1.0.0", ] +[[package]] +name = "crossbeam" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ae5588f6b3c3cb05239e90bd110f257254aecd01e4635400391aeae07497845" +dependencies = [ + "cfg-if 1.0.0", + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-epoch", + "crossbeam-queue", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-channel" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ed27e177f16d65f0f0c22a213e17c696ace5dd64b14258b52f9417ccb52db4" +dependencies = [ + "cfg-if 1.0.0", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94af6efb46fef72616855b036a624cf27ba656ffc9be1b9a3c931cfc7749a9a9" +dependencies = [ + "cfg-if 1.0.0", + "crossbeam-epoch", + "crossbeam-utils", +] + [[package]] name = "crossbeam-epoch" -version = "0.9.3" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2584f639eb95fea8c798496315b297cf81b9b58b6d30ab066a75455333cf4b12" +checksum = "4ec02e091aa634e2c3ada4a392989e7c3116673ef0ac5b72232439094d73b7fd" dependencies = [ "cfg-if 1.0.0", "crossbeam-utils", @@ -352,13 +401,22 @@ dependencies = [ "scopeguard", ] +[[package]] +name = "crossbeam-queue" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b10ddc024425c88c2ad148c1b0fd53f4c6d38db9697c9f1588381212fa657c9" +dependencies = [ + "cfg-if 1.0.0", + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" -version = "0.8.3" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7e9d99fa91428effe99c5c6d4634cdeba32b8cf784fc428a2a687f61a952c49" +checksum = "d82cfc11ce7f2c3faef78d8a684447b40d503d9681acebed6cb728d45940c4db" dependencies = [ - "autocfg", "cfg-if 1.0.0", "lazy_static", ] @@ -547,6 +605,18 @@ dependencies = [ "termcolor", ] +[[package]] +name = "fallible-iterator" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" + +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + [[package]] name = "figment" version = "0.10.5" @@ -774,6 +844,24 @@ version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7afe4a420e3fe79967a00898cc1f4db7c8a49a9333a29f8a4bd76a253d5cd04" +[[package]] +name = "hashbrown" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" +dependencies = [ + "ahash", +] + +[[package]] +name = "hashlink" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7249a3129cbc1ffccd74857f81464a323a152173cdb134e0fd81bc803b29facf" +dependencies = [ + "hashbrown 0.11.2", +] + [[package]] name = "heck" version = "0.3.2" @@ -920,7 +1008,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "824845a0bf897a9042383849b02c1bc219c2383772efcd5c6f9766fa4b81aef3" dependencies = [ "autocfg", - "hashbrown", + "hashbrown 0.9.1", "serde", ] @@ -1083,6 +1171,16 @@ dependencies = [ "libc", ] +[[package]] +name = "libsqlite3-sys" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290b64917f8b0cb885d9de0f9959fe1f775d7fa12f1da2db9001c1c8ab60f89d" +dependencies = [ + "pkg-config", + "vcpkg", +] + [[package]] name = "linked-hash-map" version = "0.5.4" @@ -1484,6 +1582,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = "pkg-config" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3831453b3449ceb48b6d9c7ad7c96d5ea673e9b470a1dc578c2ce6521230884c" + [[package]] name = "png" version = "0.16.8" @@ -2136,6 +2240,21 @@ dependencies = [ "tracing", ] +[[package]] +name = "rusqlite" +version = "0.25.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57adcf67c8faaf96f3248c2a7b419a0dbc52ebe36ba83dd57fe83827c1ea4eb3" +dependencies = [ + "bitflags", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink", + "libsqlite3-sys", + "memchr", + "smallvec", +] + [[package]] name = "rust-argon2" version = "0.8.3" @@ -3007,6 +3126,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "vcpkg" +version = "0.2.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "025ce40a007e1907e58d5bc1a594def78e5573bb0b1160bc389634e8f12e4faa" + [[package]] name = "version_check" version = "0.9.3" diff --git a/Cargo.toml b/Cargo.toml index 426d242..17b1841 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ tokio = "1.2.0" # Used for storing data permanently sled = { version = "0.34.6", features = ["compression", "no_metrics"], optional = true } rocksdb = { version = "0.16.0", features = ["multi-threaded-cf"], optional = true } +# sqlx = { version = "0.5.5", features = ["sqlite", "runtime-tokio-rustls"], optional = true } #sled = { git = "https://github.com/spacejam/sled.git", rev = "e4640e0773595229f398438886f19bca6f7326a2", features = ["compression"] } # Used for the http request / response body type for Ruma endpoints used with reqwest @@ -73,11 +74,16 @@ tracing-opentelemetry = "0.11.0" opentelemetry-jaeger = "0.11.0" pretty_env_logger = "0.4.0" lru-cache = "0.1.2" +rusqlite = { version = "0.25.3", optional = true } +parking_lot = { version = "0.11.1", optional = true } +crossbeam = { version = "0.8.1", optional = true } [features] -default = ["conduit_bin", "backend_sled"] +default = ["conduit_bin", "backend_sqlite"] backend_sled = ["sled"] backend_rocksdb = ["rocksdb"] +backend_sqlite = ["sqlite"] +sqlite = ["rusqlite", "parking_lot", "crossbeam"] conduit_bin = [] # TODO: add rocket to this when it is optional [[bin]] diff --git a/docker-compose.yml b/docker-compose.yml index cfc2462..fe13fdc 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -14,13 +14,13 @@ services: args: CREATED: '2021-03-16T08:18:27Z' VERSION: '0.1.0' - LOCAL: 'false' + LOCAL: 'true' GIT_REF: origin/master restart: unless-stopped ports: - 8448:8000 volumes: - - db:/srv/conduit/.local/share/conduit + - ./target/db:/srv/conduit/.local/share/conduit ### Uncomment if you want to use conduit.toml to configure Conduit ### Note: Set env vars will override conduit.toml values # - ./conduit.toml:/srv/conduit/conduit.toml @@ -55,5 +55,5 @@ services: # depends_on: # - homeserver -volumes: - db: +# volumes: +# db: diff --git a/rjbench_testing/docker-compose.yml b/rjbench_testing/docker-compose.yml new file mode 100644 index 0000000..44da5b5 --- /dev/null +++ b/rjbench_testing/docker-compose.yml @@ -0,0 +1,32 @@ +# Conduit +version: '3' + +services: + homeserver: + image: ubuntu:21.04 + restart: unless-stopped + working_dir: "/srv/conduit" + entrypoint: /srv/conduit/conduit + ports: + - 8448:8000 + volumes: + - ../target/db:/srv/conduit/.local/share/conduit + - ../target/debug/conduit:/srv/conduit/conduit + - ./conduit.toml:/srv/conduit/conduit.toml:ro + environment: + # CONDUIT_SERVER_NAME: localhost:8000 # replace with your own name + # CONDUIT_TRUSTED_SERVERS: '["matrix.org"]' + ### Uncomment and change values as desired + # CONDUIT_ADDRESS: 127.0.0.1 + # CONDUIT_PORT: 8000 + CONDUIT_CONFIG: '/srv/conduit/conduit.toml' + # Available levels are: error, warn, info, debug, trace - more info at: https://docs.rs/env_logger/*/env_logger/#enabling-logging + # CONDUIT_LOG: debug # default is: "info,rocket=off,_=off,sled=off" + # CONDUIT_ALLOW_JAEGER: 'false' + # CONDUIT_ALLOW_REGISTRATION : 'false' + # CONDUIT_ALLOW_ENCRYPTION: 'false' + # CONDUIT_ALLOW_FEDERATION: 'false' + # CONDUIT_DATABASE_PATH: /srv/conduit/.local/share/conduit + # CONDUIT_WORKERS: 10 + # CONDUIT_MAX_REQUEST_SIZE: 20_000_000 # in bytes, ~20 MB + diff --git a/src/database.rs b/src/database.rs index ec4052c..76e4ed0 100644 --- a/src/database.rs +++ b/src/database.rs @@ -84,10 +84,13 @@ fn default_log() -> String { } #[cfg(feature = "sled")] -pub type Engine = abstraction::SledEngine; +pub type Engine = abstraction::sled::SledEngine; #[cfg(feature = "rocksdb")] -pub type Engine = abstraction::RocksDbEngine; +pub type Engine = abstraction::rocksdb::RocksDbEngine; + +#[cfg(feature = "sqlite")] +pub type Engine = abstraction::sqlite::SqliteEngine; pub struct Database { pub globals: globals::Globals, diff --git a/src/database/abstraction.rs b/src/database/abstraction.rs index f81c9de..1ab5dea 100644 --- a/src/database/abstraction.rs +++ b/src/database/abstraction.rs @@ -1,24 +1,16 @@ use super::Config; -use crate::{utils, Result}; -use log::warn; +use crate::Result; + use std::{future::Future, pin::Pin, sync::Arc}; #[cfg(feature = "rocksdb")] -use std::{collections::BTreeMap, sync::RwLock}; +pub mod rocksdb; #[cfg(feature = "sled")] -pub struct SledEngine(sled::Db); -#[cfg(feature = "sled")] -pub struct SledEngineTree(sled::Tree); +pub mod sled; -#[cfg(feature = "rocksdb")] -pub struct RocksDbEngine(rocksdb::DBWithThreadMode); -#[cfg(feature = "rocksdb")] -pub struct RocksDbEngineTree<'a> { - db: Arc, - name: &'a str, - watchers: RwLock, Vec>>>, -} +#[cfg(feature = "sqlite")] +pub mod sqlite; pub trait DatabaseEngine: Sized { fn open(config: &Config) -> Result>; @@ -32,20 +24,20 @@ pub trait Tree: Send + Sync { fn remove(&self, key: &[u8]) -> Result<()>; - fn iter<'a>(&'a self) -> Box, Box<[u8]>)> + Send + Sync + 'a>; + fn iter<'a>(&'a self) -> Box, Vec)> + Send + 'a>; fn iter_from<'a>( &'a self, from: &[u8], backwards: bool, - ) -> Box, Box<[u8]>)> + 'a>; + ) -> Box, Vec)> + Send + 'a>; fn increment(&self, key: &[u8]) -> Result>; fn scan_prefix<'a>( &'a self, prefix: Vec, - ) -> Box, Box<[u8]>)> + Send + 'a>; + ) -> Box, Vec)> + Send + 'a>; fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>>; @@ -57,273 +49,3 @@ pub trait Tree: Send + Sync { Ok(()) } } - -#[cfg(feature = "sled")] -impl DatabaseEngine for SledEngine { - fn open(config: &Config) -> Result> { - Ok(Arc::new(SledEngine( - sled::Config::default() - .path(&config.database_path) - .cache_capacity(config.cache_capacity as u64) - .use_compression(true) - .open()?, - ))) - } - - fn open_tree(self: &Arc, name: &'static str) -> Result> { - Ok(Arc::new(SledEngineTree(self.0.open_tree(name)?))) - } -} - -#[cfg(feature = "sled")] -impl Tree for SledEngineTree { - fn get(&self, key: &[u8]) -> Result>> { - Ok(self.0.get(key)?.map(|v| v.to_vec())) - } - - fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { - self.0.insert(key, value)?; - Ok(()) - } - - fn remove(&self, key: &[u8]) -> Result<()> { - self.0.remove(key)?; - Ok(()) - } - - fn iter<'a>(&'a self) -> Box, Box<[u8]>)> + Send + Sync + 'a> { - Box::new( - self.0 - .iter() - .filter_map(|r| { - if let Err(e) = &r { - warn!("Error: {}", e); - } - r.ok() - }) - .map(|(k, v)| (k.to_vec().into(), v.to_vec().into())), - ) - } - - fn iter_from( - &self, - from: &[u8], - backwards: bool, - ) -> Box, Box<[u8]>)>> { - let iter = if backwards { - self.0.range(..from) - } else { - self.0.range(from..) - }; - - let iter = iter - .filter_map(|r| { - if let Err(e) = &r { - warn!("Error: {}", e); - } - r.ok() - }) - .map(|(k, v)| (k.to_vec().into(), v.to_vec().into())); - - if backwards { - Box::new(iter.rev()) - } else { - Box::new(iter) - } - } - - fn increment(&self, key: &[u8]) -> Result> { - Ok(self - .0 - .update_and_fetch(key, utils::increment) - .map(|o| o.expect("increment always sets a value").to_vec())?) - } - - fn scan_prefix<'a>( - &'a self, - prefix: Vec, - ) -> Box, Box<[u8]>)> + Send + 'a> { - let iter = self - .0 - .scan_prefix(prefix) - .filter_map(|r| { - if let Err(e) = &r { - warn!("Error: {}", e); - } - r.ok() - }) - .map(|(k, v)| (k.to_vec().into(), v.to_vec().into())); - - Box::new(iter) - } - - fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>> { - let prefix = prefix.to_vec(); - Box::pin(async move { - self.0.watch_prefix(prefix).await; - }) - } -} - -#[cfg(feature = "rocksdb")] -impl DatabaseEngine for RocksDbEngine { - fn open(config: &Config) -> Result> { - let mut db_opts = rocksdb::Options::default(); - db_opts.create_if_missing(true); - db_opts.set_max_open_files(16); - db_opts.set_compaction_style(rocksdb::DBCompactionStyle::Level); - db_opts.set_compression_type(rocksdb::DBCompressionType::Snappy); - db_opts.set_target_file_size_base(256 << 20); - db_opts.set_write_buffer_size(256 << 20); - - let mut block_based_options = rocksdb::BlockBasedOptions::default(); - block_based_options.set_block_size(512 << 10); - db_opts.set_block_based_table_factory(&block_based_options); - - let cfs = rocksdb::DBWithThreadMode::::list_cf( - &db_opts, - &config.database_path, - ) - .unwrap_or_default(); - - let mut options = rocksdb::Options::default(); - options.set_merge_operator_associative("increment", utils::increment_rocksdb); - - let db = rocksdb::DBWithThreadMode::::open_cf_descriptors( - &db_opts, - &config.database_path, - cfs.iter() - .map(|name| rocksdb::ColumnFamilyDescriptor::new(name, options.clone())), - )?; - - Ok(Arc::new(RocksDbEngine(db))) - } - - fn open_tree(self: &Arc, name: &'static str) -> Result> { - let mut options = rocksdb::Options::default(); - options.set_merge_operator_associative("increment", utils::increment_rocksdb); - - // Create if it doesn't exist - let _ = self.0.create_cf(name, &options); - - Ok(Arc::new(RocksDbEngineTree { - name, - db: Arc::clone(self), - watchers: RwLock::new(BTreeMap::new()), - })) - } -} - -#[cfg(feature = "rocksdb")] -impl RocksDbEngineTree<'_> { - fn cf(&self) -> rocksdb::BoundColumnFamily<'_> { - self.db.0.cf_handle(self.name).unwrap() - } -} - -#[cfg(feature = "rocksdb")] -impl Tree for RocksDbEngineTree<'_> { - fn get(&self, key: &[u8]) -> Result>> { - Ok(self.db.0.get_cf(self.cf(), key)?) - } - - fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { - let watchers = self.watchers.read().unwrap(); - let mut triggered = Vec::new(); - - for length in 0..=key.len() { - if watchers.contains_key(&key[..length]) { - triggered.push(&key[..length]); - } - } - - drop(watchers); - - if !triggered.is_empty() { - let mut watchers = self.watchers.write().unwrap(); - for prefix in triggered { - if let Some(txs) = watchers.remove(prefix) { - for tx in txs { - let _ = tx.send(()); - } - } - } - } - - Ok(self.db.0.put_cf(self.cf(), key, value)?) - } - - fn remove(&self, key: &[u8]) -> Result<()> { - Ok(self.db.0.delete_cf(self.cf(), key)?) - } - - fn iter<'a>(&'a self) -> Box, Box<[u8]>)> + Send + Sync + 'a> { - Box::new( - self.db - .0 - .iterator_cf(self.cf(), rocksdb::IteratorMode::Start), - ) - } - - fn iter_from<'a>( - &'a self, - from: &[u8], - backwards: bool, - ) -> Box, Box<[u8]>)> + 'a> { - Box::new(self.db.0.iterator_cf( - self.cf(), - rocksdb::IteratorMode::From( - from, - if backwards { - rocksdb::Direction::Reverse - } else { - rocksdb::Direction::Forward - }, - ), - )) - } - - fn increment(&self, key: &[u8]) -> Result> { - let stats = rocksdb::perf::get_memory_usage_stats(Some(&[&self.db.0]), None).unwrap(); - dbg!(stats.mem_table_total); - dbg!(stats.mem_table_unflushed); - dbg!(stats.mem_table_readers_total); - dbg!(stats.cache_total); - // TODO: atomic? - let old = self.get(key)?; - let new = utils::increment(old.as_deref()).unwrap(); - self.insert(key, &new)?; - Ok(new) - } - - fn scan_prefix<'a>( - &'a self, - prefix: Vec, - ) -> Box, Box<[u8]>)> + Send + 'a> { - Box::new( - self.db - .0 - .iterator_cf( - self.cf(), - rocksdb::IteratorMode::From(&prefix, rocksdb::Direction::Forward), - ) - .take_while(move |(k, _)| k.starts_with(&prefix)), - ) - } - - fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>> { - let (tx, rx) = tokio::sync::oneshot::channel(); - - self.watchers - .write() - .unwrap() - .entry(prefix.to_vec()) - .or_default() - .push(tx); - - Box::pin(async move { - // Tx is never destroyed - rx.await.unwrap(); - }) - } -} diff --git a/src/database/abstraction/rocksdb.rs b/src/database/abstraction/rocksdb.rs new file mode 100644 index 0000000..88b6297 --- /dev/null +++ b/src/database/abstraction/rocksdb.rs @@ -0,0 +1,176 @@ +use super::super::Config; +use crate::{utils, Result}; + +use std::{future::Future, pin::Pin, sync::Arc}; + +use super::{DatabaseEngine, Tree}; + +use std::{collections::BTreeMap, sync::RwLock}; + +pub struct RocksDbEngine(rocksdb::DBWithThreadMode); + +pub struct RocksDbEngineTree<'a> { + db: Arc, + name: &'a str, + watchers: RwLock, Vec>>>, +} + +impl DatabaseEngine for RocksDbEngine { + fn open(config: &Config) -> Result> { + let mut db_opts = rocksdb::Options::default(); + db_opts.create_if_missing(true); + db_opts.set_max_open_files(16); + db_opts.set_compaction_style(rocksdb::DBCompactionStyle::Level); + db_opts.set_compression_type(rocksdb::DBCompressionType::Snappy); + db_opts.set_target_file_size_base(256 << 20); + db_opts.set_write_buffer_size(256 << 20); + + let mut block_based_options = rocksdb::BlockBasedOptions::default(); + block_based_options.set_block_size(512 << 10); + db_opts.set_block_based_table_factory(&block_based_options); + + let cfs = rocksdb::DBWithThreadMode::::list_cf( + &db_opts, + &config.database_path, + ) + .unwrap_or_default(); + + let mut options = rocksdb::Options::default(); + options.set_merge_operator_associative("increment", utils::increment_rocksdb); + + let db = rocksdb::DBWithThreadMode::::open_cf_descriptors( + &db_opts, + &config.database_path, + cfs.iter() + .map(|name| rocksdb::ColumnFamilyDescriptor::new(name, options.clone())), + )?; + + Ok(Arc::new(RocksDbEngine(db))) + } + + fn open_tree(self: &Arc, name: &'static str) -> Result> { + let mut options = rocksdb::Options::default(); + options.set_merge_operator_associative("increment", utils::increment_rocksdb); + + // Create if it doesn't exist + let _ = self.0.create_cf(name, &options); + + Ok(Arc::new(RocksDbEngineTree { + name, + db: Arc::clone(self), + watchers: RwLock::new(BTreeMap::new()), + })) + } +} + +impl RocksDbEngineTree<'_> { + fn cf(&self) -> rocksdb::BoundColumnFamily<'_> { + self.db.0.cf_handle(self.name).unwrap() + } +} + +impl Tree for RocksDbEngineTree<'_> { + fn get(&self, key: &[u8]) -> Result>> { + Ok(self.db.0.get_cf(self.cf(), key)?) + } + + fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { + let watchers = self.watchers.read().unwrap(); + let mut triggered = Vec::new(); + + for length in 0..=key.len() { + if watchers.contains_key(&key[..length]) { + triggered.push(&key[..length]); + } + } + + drop(watchers); + + if !triggered.is_empty() { + let mut watchers = self.watchers.write().unwrap(); + for prefix in triggered { + if let Some(txs) = watchers.remove(prefix) { + for tx in txs { + let _ = tx.send(()); + } + } + } + } + + Ok(self.db.0.put_cf(self.cf(), key, value)?) + } + + fn remove(&self, key: &[u8]) -> Result<()> { + Ok(self.db.0.delete_cf(self.cf(), key)?) + } + + fn iter<'a>(&'a self) -> Box, Vec)> + Send + Sync + 'a> { + Box::new( + self.db + .0 + .iterator_cf(self.cf(), rocksdb::IteratorMode::Start), + ) + } + + fn iter_from<'a>( + &'a self, + from: &[u8], + backwards: bool, + ) -> Box, Vec)> + 'a> { + Box::new(self.db.0.iterator_cf( + self.cf(), + rocksdb::IteratorMode::From( + from, + if backwards { + rocksdb::Direction::Reverse + } else { + rocksdb::Direction::Forward + }, + ), + )) + } + + fn increment(&self, key: &[u8]) -> Result> { + let stats = rocksdb::perf::get_memory_usage_stats(Some(&[&self.db.0]), None).unwrap(); + dbg!(stats.mem_table_total); + dbg!(stats.mem_table_unflushed); + dbg!(stats.mem_table_readers_total); + dbg!(stats.cache_total); + // TODO: atomic? + let old = self.get(key)?; + let new = utils::increment(old.as_deref()).unwrap(); + self.insert(key, &new)?; + Ok(new) + } + + fn scan_prefix<'a>( + &'a self, + prefix: Vec, + ) -> Box, Vec)> + Send + 'a> { + Box::new( + self.db + .0 + .iterator_cf( + self.cf(), + rocksdb::IteratorMode::From(&prefix, rocksdb::Direction::Forward), + ) + .take_while(move |(k, _)| k.starts_with(&prefix)), + ) + } + + fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>> { + let (tx, rx) = tokio::sync::oneshot::channel(); + + self.watchers + .write() + .unwrap() + .entry(prefix.to_vec()) + .or_default() + .push(tx); + + Box::pin(async move { + // Tx is never destroyed + rx.await.unwrap(); + }) + } +} diff --git a/src/database/abstraction/sled.rs b/src/database/abstraction/sled.rs new file mode 100644 index 0000000..2f3fb34 --- /dev/null +++ b/src/database/abstraction/sled.rs @@ -0,0 +1,115 @@ +use super::super::Config; +use crate::{utils, Result}; +use log::warn; +use std::{future::Future, pin::Pin, sync::Arc}; + +use super::{DatabaseEngine, Tree}; + +pub struct SledEngine(sled::Db); + +pub struct SledEngineTree(sled::Tree); + +impl DatabaseEngine for SledEngine { + fn open(config: &Config) -> Result> { + Ok(Arc::new(SledEngine( + sled::Config::default() + .path(&config.database_path) + .cache_capacity(config.cache_capacity as u64) + .use_compression(true) + .open()?, + ))) + } + + fn open_tree(self: &Arc, name: &'static str) -> Result> { + Ok(Arc::new(SledEngineTree(self.0.open_tree(name)?))) + } +} + +impl Tree for SledEngineTree { + fn get(&self, key: &[u8]) -> Result>> { + Ok(self.0.get(key)?.map(|v| v.to_vec())) + } + + fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { + self.0.insert(key, value)?; + Ok(()) + } + + fn remove(&self, key: &[u8]) -> Result<()> { + self.0.remove(key)?; + Ok(()) + } + + fn iter<'a>(&'a self) -> Box, Vec)> + Send + Sync + 'a> { + Box::new( + self.0 + .iter() + .filter_map(|r| { + if let Err(e) = &r { + warn!("Error: {}", e); + } + r.ok() + }) + .map(|(k, v)| (k.to_vec().into(), v.to_vec().into())), + ) + } + + fn iter_from( + &self, + from: &[u8], + backwards: bool, + ) -> Box, Vec)>> { + let iter = if backwards { + self.0.range(..from) + } else { + self.0.range(from..) + }; + + let iter = iter + .filter_map(|r| { + if let Err(e) = &r { + warn!("Error: {}", e); + } + r.ok() + }) + .map(|(k, v)| (k.to_vec().into(), v.to_vec().into())); + + if backwards { + Box::new(iter.rev()) + } else { + Box::new(iter) + } + } + + fn increment(&self, key: &[u8]) -> Result> { + Ok(self + .0 + .update_and_fetch(key, utils::increment) + .map(|o| o.expect("increment always sets a value").to_vec())?) + } + + fn scan_prefix<'a>( + &'a self, + prefix: Vec, + ) -> Box, Vec)> + Send + 'a> { + let iter = self + .0 + .scan_prefix(prefix) + .filter_map(|r| { + if let Err(e) = &r { + warn!("Error: {}", e); + } + r.ok() + }) + .map(|(k, v)| (k.to_vec().into(), v.to_vec().into())); + + Box::new(iter) + } + + fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>> { + let prefix = prefix.to_vec(); + Box::pin(async move { + self.0.watch_prefix(prefix).await; + }) + } +} diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs new file mode 100644 index 0000000..35078f9 --- /dev/null +++ b/src/database/abstraction/sqlite.rs @@ -0,0 +1,273 @@ +use std::{future::Future, pin::Pin, sync::Arc, thread}; + +use crate::{database::Config, Result}; + +use super::{DatabaseEngine, Tree}; + +use std::{collections::BTreeMap, sync::RwLock}; + +use crossbeam::channel::{bounded, Sender as ChannelSender}; +use parking_lot::{Mutex, MutexGuard}; +use rusqlite::{params, Connection, OptionalExtension}; + +use tokio::sync::oneshot::Sender; + +type SqliteHandle = Arc>; + +// const SQL_CREATE_TABLE: &str = +// "CREATE TABLE IF NOT EXISTS {} {{ \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL }}"; +// const SQL_SELECT: &str = "SELECT value FROM {} WHERE key = ?"; +// const SQL_INSERT: &str = "INSERT OR REPLACE INTO {} (key, value) VALUES (?, ?)"; +// const SQL_DELETE: &str = "DELETE FROM {} WHERE key = ?"; +// const SQL_SELECT_ITER: &str = "SELECT key, value FROM {}"; +// const SQL_SELECT_PREFIX: &str = "SELECT key, value FROM {} WHERE key LIKE ?||'%' ORDER BY key ASC"; +// const SQL_SELECT_ITER_FROM_FORWARDS: &str = "SELECT key, value FROM {} WHERE key >= ? ORDER BY ASC"; +// const SQL_SELECT_ITER_FROM_BACKWARDS: &str = +// "SELECT key, value FROM {} WHERE key <= ? ORDER BY DESC"; + +pub struct SqliteEngine { + handle: SqliteHandle, +} + +impl DatabaseEngine for SqliteEngine { + fn open(config: &Config) -> Result> { + let conn = Connection::open(format!("{}/conduit.db", &config.database_path))?; + + conn.pragma_update(None, "journal_mode", &"WAL".to_owned())?; + + let handle = Arc::new(Mutex::new(conn)); + + Ok(Arc::new(SqliteEngine { handle })) + } + + fn open_tree(self: &Arc, name: &str) -> Result> { + self.handle.lock().execute(format!("CREATE TABLE IF NOT EXISTS {} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )", name).as_str(), [])?; + + Ok(Arc::new(SqliteTable { + engine: Arc::clone(self), + name: name.to_owned(), + watchers: RwLock::new(BTreeMap::new()), + })) + } +} + +pub struct SqliteTable { + engine: Arc, + name: String, + watchers: RwLock, Vec>>>, +} + +type TupleOfBytes = (Vec, Vec); + +impl SqliteTable { + fn get_with_guard( + &self, + guard: &MutexGuard<'_, Connection>, + key: &[u8], + ) -> Result>> { + Ok(guard + .prepare(format!("SELECT value FROM {} WHERE key = ?", self.name).as_str())? + .query_row([key], |row| row.get(0)) + .optional()?) + } + + fn insert_with_guard( + &self, + guard: &MutexGuard<'_, Connection>, + key: &[u8], + value: &[u8], + ) -> Result<()> { + guard.execute( + format!( + "INSERT OR REPLACE INTO {} (key, value) VALUES (?, ?)", + self.name + ) + .as_str(), + [key, value], + )?; + Ok(()) + } + + fn _iter_from_thread( + &self, + mutex: Arc>, + f: F, + ) -> Box + Send> + where + F: (FnOnce(MutexGuard<'_, Connection>, ChannelSender)) + Send + 'static, + { + let (s, r) = bounded::(5); + + thread::spawn(move || { + let _ = f(mutex.lock(), s); + }); + + Box::new(r.into_iter()) + } +} + +macro_rules! iter_from_thread { + ($self:expr, $sql:expr, $param:expr) => { + $self._iter_from_thread($self.engine.handle.clone(), move |guard, s| { + let _ = guard + .prepare($sql) + .unwrap() + .query_map($param, |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) + .unwrap() + .map(|r| r.unwrap()) + .try_for_each(|bob| s.send(bob)); + }) + }; +} + +impl Tree for SqliteTable { + fn get(&self, key: &[u8]) -> Result>> { + self.get_with_guard(&self.engine.handle.lock(), key) + } + + fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { + self.insert_with_guard(&self.engine.handle.lock(), key, value)?; + + let watchers = self.watchers.read().unwrap(); + let mut triggered = Vec::new(); + + for length in 0..=key.len() { + if watchers.contains_key(&key[..length]) { + triggered.push(&key[..length]); + } + } + + drop(watchers); + + if !triggered.is_empty() { + let mut watchers = self.watchers.write().unwrap(); + for prefix in triggered { + if let Some(txs) = watchers.remove(prefix) { + for tx in txs { + let _ = tx.send(()); + } + } + } + }; + + Ok(()) + } + + fn remove(&self, key: &[u8]) -> Result<()> { + self.engine.handle.lock().execute( + format!("DELETE FROM {} WHERE key = ?", self.name).as_str(), + [key], + )?; + Ok(()) + } + + fn iter<'a>(&'a self) -> Box + Send + 'a> { + let name = self.name.clone(); + iter_from_thread!( + self, + format!("SELECT key, value FROM {}", name).as_str(), + params![] + ) + } + + fn iter_from<'a>( + &'a self, + from: &[u8], + backwards: bool, + ) -> Box + Send + 'a> { + let name = self.name.clone(); + let from = from.to_vec(); // TODO change interface? + if backwards { + iter_from_thread!( + self, + format!( // TODO change to <= on rebase + "SELECT key, value FROM {} WHERE key < ? ORDER BY key DESC", + name + ) + .as_str(), + [from] + ) + } else { + iter_from_thread!( + self, + format!( + "SELECT key, value FROM {} WHERE key >= ? ORDER BY key ASC", + name + ) + .as_str(), + [from] + ) + } + } + + fn increment(&self, key: &[u8]) -> Result> { + let guard = self.engine.handle.lock(); + + let old = self.get_with_guard(&guard, key)?; + + let new = + crate::utils::increment(old.as_deref()).expect("utils::increment always returns Some"); + + self.insert_with_guard(&guard, key, &new)?; + + Ok(new) + } + + // TODO: make this use take_while + + fn scan_prefix<'a>( + &'a self, + prefix: Vec, + ) -> Box + Send + 'a> { + // let name = self.name.clone(); + // iter_from_thread!( + // self, + // format!( + // "SELECT key, value FROM {} WHERE key BETWEEN ?1 AND ?1 || X'FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF' ORDER BY key ASC", + // name + // ) + // .as_str(), + // [prefix] + // ) + Box::new(self.iter_from(&prefix, false).take_while(move |(key, _)| key.starts_with(&prefix))) + } + + fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>> { + let (tx, rx) = tokio::sync::oneshot::channel(); + + self.watchers + .write() + .unwrap() + .entry(prefix.to_vec()) + .or_default() + .push(tx); + + Box::pin(async move { + // Tx is never destroyed + rx.await.unwrap(); + }) + } + + fn clear(&self) -> Result<()> { + self.engine.handle.lock().execute( + format!("DELETE FROM {}", self.name).as_str(), + [], + )?; + Ok(()) + } +} + +// TODO +// struct Pool { +// writer: Mutex, +// readers: [Mutex; NUM_READERS], +// } + +// // then, to pick a reader: +// for r in &pool.readers { +// if let Ok(reader) = r.try_lock() { +// // use reader +// } +// } +// // none unlocked, pick the next reader +// pool.readers[pool.counter.fetch_add(1, Relaxed) % NUM_READERS].lock() diff --git a/src/database/account_data.rs b/src/database/account_data.rs index 2ba7bc3..b1d5b6b 100644 --- a/src/database/account_data.rs +++ b/src/database/account_data.rs @@ -127,7 +127,7 @@ impl AccountData { room_id: Option<&RoomId>, user_id: &UserId, kind: &EventType, - ) -> Result, Box<[u8]>)>> { + ) -> Result, Vec)>> { let mut prefix = room_id .map(|r| r.to_string()) .unwrap_or_default() diff --git a/src/database/appservice.rs b/src/database/appservice.rs index 4bf3a21..f39520c 100644 --- a/src/database/appservice.rs +++ b/src/database/appservice.rs @@ -49,7 +49,7 @@ impl Appservice { ) } - pub fn iter_ids(&self) -> Result> + Send + Sync + '_> { + pub fn iter_ids(&self) -> Result> + Send + '_> { Ok(self.id_appserviceregistrations.iter().map(|(id, _)| { utils::string_from_bytes(&id) .map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations.")) @@ -58,7 +58,7 @@ impl Appservice { pub fn iter_all( &self, - ) -> Result> + '_ + Send + Sync> { + ) -> Result> + '_ + Send> { Ok(self.iter_ids()?.filter_map(|id| id.ok()).map(move |id| { Ok(( id.clone(), diff --git a/src/database/pusher.rs b/src/database/pusher.rs index a27bf2c..3210cb1 100644 --- a/src/database/pusher.rs +++ b/src/database/pusher.rs @@ -73,7 +73,7 @@ impl PushData { pub fn get_pusher_senderkeys<'a>( &'a self, sender: &UserId, - ) -> impl Iterator> + 'a { + ) -> impl Iterator> + 'a { let mut prefix = sender.as_bytes().to_vec(); prefix.push(0xff); diff --git a/src/database/rooms.rs b/src/database/rooms.rs index e23b804..eb1a924 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -1078,13 +1078,13 @@ impl Rooms { .scan_prefix(old_shortstatehash.clone()) // Chop the old_shortstatehash out leaving behind the short state key .map(|(k, v)| (k[old_shortstatehash.len()..].to_vec(), v)) - .collect::, Box<[u8]>>>() + .collect::, Vec>>() } else { HashMap::new() }; if let Some(state_key) = &new_pdu.state_key { - let mut new_state: HashMap, Box<[u8]>> = old_state; + let mut new_state: HashMap, Vec> = old_state; let mut new_state_key = new_pdu.kind.as_ref().as_bytes().to_vec(); new_state_key.push(0xff); @@ -1209,13 +1209,13 @@ impl Rooms { redacts, } = pdu_builder; // TODO: Make sure this isn't called twice in parallel - let prev_events = self + let prev_events = dbg!(self .get_pdu_leaves(&room_id)? .into_iter() .take(20) - .collect::>(); + .collect::>()); - let create_event = self.room_state_get(&room_id, &EventType::RoomCreate, "")?; + let create_event = dbg!(self.room_state_get(&room_id, &EventType::RoomCreate, ""))?; let create_event_content = create_event .as_ref() @@ -1450,7 +1450,7 @@ impl Rooms { &'a self, user_id: &UserId, room_id: &RoomId, - ) -> impl Iterator, PduEvent)>> + 'a { + ) -> impl Iterator, PduEvent)>> + 'a { self.pdus_since(user_id, room_id, 0) } @@ -1462,7 +1462,7 @@ impl Rooms { user_id: &UserId, room_id: &RoomId, since: u64, - ) -> impl Iterator, PduEvent)>> + 'a { + ) -> impl Iterator, PduEvent)>> + 'a { let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); @@ -1491,7 +1491,7 @@ impl Rooms { user_id: &UserId, room_id: &RoomId, until: u64, - ) -> impl Iterator, PduEvent)>> + 'a { + ) -> impl Iterator, PduEvent)>> + 'a { // Create the first part of the full pdu id let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); @@ -1523,7 +1523,7 @@ impl Rooms { user_id: &UserId, room_id: &RoomId, from: u64, - ) -> impl Iterator, PduEvent)>> + 'a { + ) -> impl Iterator, PduEvent)>> + 'a { // Create the first part of the full pdu id let mut prefix = room_id.as_bytes().to_vec(); prefix.push(0xff); diff --git a/src/database/sending.rs b/src/database/sending.rs index ecf0761..c2e1397 100644 --- a/src/database/sending.rs +++ b/src/database/sending.rs @@ -357,7 +357,7 @@ impl Sending { } #[tracing::instrument(skip(self))] - pub fn send_push_pdu(&self, pdu_id: &[u8], senderkey: Box<[u8]>) -> Result<()> { + pub fn send_push_pdu(&self, pdu_id: &[u8], senderkey: Vec) -> Result<()> { let mut key = b"$".to_vec(); key.extend_from_slice(&senderkey); key.push(0xff); diff --git a/src/error.rs b/src/error.rs index 501c77d..3091b9d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -35,6 +35,12 @@ pub enum Error { #[from] source: rocksdb::Error, }, + #[cfg(feature = "sqlite")] + #[error("There was a problem with the connection to the sqlite database: {source}")] + SqliteError { + #[from] + source: rusqlite::Error, + }, #[error("Could not generate an image.")] ImageError { #[from] From 0753076e947a48638853b3ef2c7936bc462dca08 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Sat, 3 Jul 2021 21:19:49 +0200 Subject: [PATCH 02/28] chutulu is my copilot --- Cargo.lock | 1 + Cargo.toml | 3 +- src/database.rs | 12 +- src/database/abstraction.rs | 1 + src/database/abstraction/sqlite.rs | 282 ++++++++++++++++++++++++----- src/database/media.rs | 10 +- src/database/rooms.rs | 8 +- 7 files changed, 262 insertions(+), 55 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9266c13..9a03542 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -255,6 +255,7 @@ dependencies = [ "image", "jsonwebtoken", "log", + "num_cpus", "lru-cache", "opentelemetry", "opentelemetry-jaeger", diff --git a/Cargo.toml b/Cargo.toml index 17b1841..e7bb3b8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -77,13 +77,14 @@ lru-cache = "0.1.2" rusqlite = { version = "0.25.3", optional = true } parking_lot = { version = "0.11.1", optional = true } crossbeam = { version = "0.8.1", optional = true } +num_cpus = { version = "1.13.0", optional = true } [features] default = ["conduit_bin", "backend_sqlite"] backend_sled = ["sled"] backend_rocksdb = ["rocksdb"] backend_sqlite = ["sqlite"] -sqlite = ["rusqlite", "parking_lot", "crossbeam"] +sqlite = ["rusqlite", "parking_lot", "crossbeam", "num_cpus"] conduit_bin = [] # TODO: add rocket to this when it is optional [[bin]] diff --git a/src/database.rs b/src/database.rs index 76e4ed0..de4f441 100644 --- a/src/database.rs +++ b/src/database.rs @@ -93,6 +93,7 @@ pub type Engine = abstraction::rocksdb::RocksDbEngine; pub type Engine = abstraction::sqlite::SqliteEngine; pub struct Database { + _db: Arc, pub globals: globals::Globals, pub users: users::Users, pub uiaa: uiaa::Uiaa, @@ -132,6 +133,7 @@ impl Database { let (sending_sender, sending_receiver) = mpsc::unbounded(); let db = Arc::new(Self { + _db: builder.clone(), users: users::Users { userid_password: builder.open_tree("userid_password")?, userid_displayname: builder.open_tree("userid_displayname")?, @@ -421,8 +423,12 @@ impl Database { } pub async fn flush(&self) -> Result<()> { - // noop while we don't use sled 1.0 - //self._db.flush_async().await?; - Ok(()) + let start = std::time::Instant::now(); + + let res = self._db.flush(); + + log::debug!("flush: took {:?}", start.elapsed()); + + res } } diff --git a/src/database/abstraction.rs b/src/database/abstraction.rs index 1ab5dea..fb11ba0 100644 --- a/src/database/abstraction.rs +++ b/src/database/abstraction.rs @@ -15,6 +15,7 @@ pub mod sqlite; pub trait DatabaseEngine: Sized { fn open(config: &Config) -> Result>; fn open_tree(self: &Arc, name: &'static str) -> Result>; + fn flush(self: &Arc) -> Result<()>; } pub trait Tree: Send + Sync { diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index 35078f9..be5ce6c 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -1,19 +1,27 @@ -use std::{future::Future, pin::Pin, sync::Arc, thread}; +use std::{ + future::Future, + ops::Deref, + path::{Path, PathBuf}, + pin::Pin, + sync::{Arc, Weak}, + thread, + time::{Duration, Instant}, +}; use crate::{database::Config, Result}; use super::{DatabaseEngine, Tree}; -use std::{collections::BTreeMap, sync::RwLock}; +use std::collections::BTreeMap; + +use log::debug; use crossbeam::channel::{bounded, Sender as ChannelSender}; -use parking_lot::{Mutex, MutexGuard}; -use rusqlite::{params, Connection, OptionalExtension}; +use parking_lot::{Mutex, MutexGuard, RwLock}; +use rusqlite::{params, Connection, DatabaseName::Main, OptionalExtension}; use tokio::sync::oneshot::Sender; -type SqliteHandle = Arc>; - // const SQL_CREATE_TABLE: &str = // "CREATE TABLE IF NOT EXISTS {} {{ \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL }}"; // const SQL_SELECT: &str = "SELECT value FROM {} WHERE key = ?"; @@ -25,23 +33,137 @@ type SqliteHandle = Arc>; // const SQL_SELECT_ITER_FROM_BACKWARDS: &str = // "SELECT key, value FROM {} WHERE key <= ? ORDER BY DESC"; +struct Pool { + writer: Mutex, + readers: Vec>, + reader_rwlock: RwLock<()>, + path: PathBuf, +} + +pub const MILLI: Duration = Duration::from_millis(1); + +enum HoldingConn<'a> { + FromGuard(MutexGuard<'a, Connection>), + FromOwned(Connection), +} + +impl<'a> Deref for HoldingConn<'a> { + type Target = Connection; + + fn deref(&self) -> &Self::Target { + match self { + HoldingConn::FromGuard(guard) => guard.deref(), + HoldingConn::FromOwned(conn) => conn, + } + } +} + +impl Pool { + fn new>(path: P, num_readers: usize) -> Result { + let writer = Mutex::new(Self::prepare_conn(&path)?); + + let mut readers = Vec::new(); + + for _ in 0..num_readers { + readers.push(Mutex::new(Self::prepare_conn(&path)?)) + } + + Ok(Self { + writer, + readers, + reader_rwlock: RwLock::new(()), + path: path.as_ref().to_path_buf(), + }) + } + + fn prepare_conn>(path: P) -> Result { + let conn = Connection::open(path)?; + + conn.pragma_update(Some(Main), "journal_mode", &"WAL".to_owned())?; + + // conn.pragma_update(Some(Main), "wal_autocheckpoint", &250)?; + + // conn.pragma_update(Some(Main), "wal_checkpoint", &"FULL".to_owned())?; + + conn.pragma_update(Some(Main), "synchronous", &"OFF".to_owned())?; + + Ok(conn) + } + + fn write_lock(&self) -> MutexGuard<'_, Connection> { + self.writer.lock() + } + + fn read_lock(&self) -> HoldingConn<'_> { + let _guard = self.reader_rwlock.read(); + + for r in &self.readers { + if let Some(reader) = r.try_lock() { + return HoldingConn::FromGuard(reader); + } + } + + drop(_guard); + + log::warn!("all readers locked, creating spillover reader..."); + + let spilled = Self::prepare_conn(&self.path).unwrap(); + + return HoldingConn::FromOwned(spilled); + } +} + pub struct SqliteEngine { - handle: SqliteHandle, + pool: Pool, + iterator_lock: RwLock<()>, } impl DatabaseEngine for SqliteEngine { fn open(config: &Config) -> Result> { - let conn = Connection::open(format!("{}/conduit.db", &config.database_path))?; + let pool = Pool::new( + format!("{}/conduit.db", &config.database_path), + num_cpus::get(), + )?; - conn.pragma_update(None, "journal_mode", &"WAL".to_owned())?; + pool.write_lock() + .execute("CREATE TABLE IF NOT EXISTS _noop (\"key\" INT)", params![])?; - let handle = Arc::new(Mutex::new(conn)); + let arc = Arc::new(SqliteEngine { + pool, + iterator_lock: RwLock::new(()), + }); + + let weak: Weak = Arc::downgrade(&arc); + + thread::spawn(move || { + let r = crossbeam::channel::tick(Duration::from_secs(60)); + + let weak = weak; + + loop { + let _ = r.recv(); + + if let Some(arc) = Weak::upgrade(&weak) { + log::warn!("wal-trunc: locking..."); + let iterator_guard = arc.iterator_lock.write(); + let read_guard = arc.pool.reader_rwlock.write(); + log::warn!("wal-trunc: locked, flushing..."); + let start = Instant::now(); + arc.flush_wal().unwrap(); + log::warn!("wal-trunc: locked, flushed in {:?}", start.elapsed()); + drop(read_guard); + drop(iterator_guard); + } else { + break; + } + } + }); - Ok(Arc::new(SqliteEngine { handle })) + Ok(arc) } fn open_tree(self: &Arc, name: &str) -> Result> { - self.handle.lock().execute(format!("CREATE TABLE IF NOT EXISTS {} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )", name).as_str(), [])?; + self.pool.write_lock().execute(format!("CREATE TABLE IF NOT EXISTS {} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )", name).as_str(), [])?; Ok(Arc::new(SqliteTable { engine: Arc::clone(self), @@ -49,6 +171,40 @@ impl DatabaseEngine for SqliteEngine { watchers: RwLock::new(BTreeMap::new()), })) } + + fn flush(self: &Arc) -> Result<()> { + self.pool + .write_lock() + .execute_batch( + " + PRAGMA synchronous=FULL; + BEGIN; + DELETE FROM _noop; + INSERT INTO _noop VALUES (1); + COMMIT; + PRAGMA synchronous=OFF; + ", + ) + .map_err(Into::into) + } +} + +impl SqliteEngine { + fn flush_wal(self: &Arc) -> Result<()> { + self.pool + .write_lock() + .execute_batch( + " + PRAGMA synchronous=FULL; PRAGMA wal_checkpoint=TRUNCATE; + BEGIN; + DELETE FROM _noop; + INSERT INTO _noop VALUES (1); + COMMIT; + PRAGMA wal_checkpoint=PASSIVE; PRAGMA synchronous=OFF; + ", + ) + .map_err(Into::into) + } } pub struct SqliteTable { @@ -60,26 +216,17 @@ pub struct SqliteTable { type TupleOfBytes = (Vec, Vec); impl SqliteTable { - fn get_with_guard( - &self, - guard: &MutexGuard<'_, Connection>, - key: &[u8], - ) -> Result>> { + fn get_with_guard(&self, guard: &Connection, key: &[u8]) -> Result>> { Ok(guard .prepare(format!("SELECT value FROM {} WHERE key = ?", self.name).as_str())? .query_row([key], |row| row.get(0)) .optional()?) } - fn insert_with_guard( - &self, - guard: &MutexGuard<'_, Connection>, - key: &[u8], - value: &[u8], - ) -> Result<()> { + fn insert_with_guard(&self, guard: &Connection, key: &[u8], value: &[u8]) -> Result<()> { guard.execute( format!( - "INSERT OR REPLACE INTO {} (key, value) VALUES (?, ?)", + "INSERT INTO {} (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value", self.name ) .as_str(), @@ -88,18 +235,18 @@ impl SqliteTable { Ok(()) } - fn _iter_from_thread( - &self, - mutex: Arc>, - f: F, - ) -> Box + Send> + fn _iter_from_thread(&self, f: F) -> Box + Send> where - F: (FnOnce(MutexGuard<'_, Connection>, ChannelSender)) + Send + 'static, + F: (for<'a> FnOnce(&'a Connection, ChannelSender)) + Send + 'static, { let (s, r) = bounded::(5); + let engine = self.engine.clone(); + thread::spawn(move || { - let _ = f(mutex.lock(), s); + let guard = engine.iterator_lock.read(); + let _ = f(&engine.pool.read_lock(), s); + drop(guard); }); Box::new(r.into_iter()) @@ -108,7 +255,7 @@ impl SqliteTable { macro_rules! iter_from_thread { ($self:expr, $sql:expr, $param:expr) => { - $self._iter_from_thread($self.engine.handle.clone(), move |guard, s| { + $self._iter_from_thread(move |guard, s| { let _ = guard .prepare($sql) .unwrap() @@ -122,13 +269,33 @@ macro_rules! iter_from_thread { impl Tree for SqliteTable { fn get(&self, key: &[u8]) -> Result>> { - self.get_with_guard(&self.engine.handle.lock(), key) + let guard = self.engine.pool.read_lock(); + + // let start = Instant::now(); + + let val = self.get_with_guard(&guard, key); + + // debug!("get: took {:?}", start.elapsed()); + // debug!("get key: {:?}", &key) + + val } fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { - self.insert_with_guard(&self.engine.handle.lock(), key, value)?; + { + let guard = self.engine.pool.write_lock(); - let watchers = self.watchers.read().unwrap(); + let start = Instant::now(); + + self.insert_with_guard(&guard, key, value)?; + + let elapsed = start.elapsed(); + if elapsed > MILLI { + debug!("insert: took {:012?} : {}", elapsed, &self.name); + } + } + + let watchers = self.watchers.read(); let mut triggered = Vec::new(); for length in 0..=key.len() { @@ -140,7 +307,7 @@ impl Tree for SqliteTable { drop(watchers); if !triggered.is_empty() { - let mut watchers = self.watchers.write().unwrap(); + let mut watchers = self.watchers.write(); for prefix in triggered { if let Some(txs) = watchers.remove(prefix) { for tx in txs { @@ -154,10 +321,22 @@ impl Tree for SqliteTable { } fn remove(&self, key: &[u8]) -> Result<()> { - self.engine.handle.lock().execute( + let guard = self.engine.pool.write_lock(); + + let start = Instant::now(); + + guard.execute( format!("DELETE FROM {} WHERE key = ?", self.name).as_str(), [key], )?; + + let elapsed = start.elapsed(); + + if elapsed > MILLI { + debug!("remove: took {:012?} : {}", elapsed, &self.name); + } + // debug!("remove key: {:?}", &key); + Ok(()) } @@ -201,7 +380,9 @@ impl Tree for SqliteTable { } fn increment(&self, key: &[u8]) -> Result> { - let guard = self.engine.handle.lock(); + let guard = self.engine.pool.write_lock(); + + let start = Instant::now(); let old = self.get_with_guard(&guard, key)?; @@ -210,11 +391,16 @@ impl Tree for SqliteTable { self.insert_with_guard(&guard, key, &new)?; + let elapsed = start.elapsed(); + + if elapsed > MILLI { + debug!("increment: took {:012?} : {}", elapsed, &self.name); + } + // debug!("increment key: {:?}", &key); + Ok(new) } - // TODO: make this use take_while - fn scan_prefix<'a>( &'a self, prefix: Vec, @@ -229,7 +415,10 @@ impl Tree for SqliteTable { // .as_str(), // [prefix] // ) - Box::new(self.iter_from(&prefix, false).take_while(move |(key, _)| key.starts_with(&prefix))) + Box::new( + self.iter_from(&prefix, false) + .take_while(move |(key, _)| key.starts_with(&prefix)), + ) } fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>> { @@ -237,7 +426,6 @@ impl Tree for SqliteTable { self.watchers .write() - .unwrap() .entry(prefix.to_vec()) .or_default() .push(tx); @@ -249,10 +437,12 @@ impl Tree for SqliteTable { } fn clear(&self) -> Result<()> { - self.engine.handle.lock().execute( - format!("DELETE FROM {}", self.name).as_str(), - [], - )?; + debug!("clear: running"); + self.engine + .pool + .write_lock() + .execute(format!("DELETE FROM {}", self.name).as_str(), [])?; + debug!("clear: ran"); Ok(()) } } diff --git a/src/database/media.rs b/src/database/media.rs index a1fe26e..404a6c0 100644 --- a/src/database/media.rs +++ b/src/database/media.rs @@ -189,7 +189,10 @@ impl Media { original_prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail original_prefix.push(0xff); - if let Some((key, _)) = self.mediaid_file.scan_prefix(thumbnail_prefix).next() { + if let Some((key, _)) = { + /* scoped to explicitly drop iterator */ + self.mediaid_file.scan_prefix(thumbnail_prefix).next() + } { // Using saved thumbnail let path = globals.get_media_file(&key); let mut file = Vec::new(); @@ -224,7 +227,10 @@ impl Media { content_type, file: file.to_vec(), })) - } else if let Some((key, _)) = self.mediaid_file.scan_prefix(original_prefix).next() { + } else if let Some((key, _)) = { + /* scoped to explicitly drop iterator */ + self.mediaid_file.scan_prefix(original_prefix).next() + } { // Generate a thumbnail let path = globals.get_media_file(&key); let mut file = Vec::new(); diff --git a/src/database/rooms.rs b/src/database/rooms.rs index eb1a924..23cd570 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -733,6 +733,8 @@ impl Rooms { .filter(|user_id| user_id.server_name() == db.globals.server_name()) .filter(|user_id| !db.users.is_deactivated(user_id).unwrap_or(false)) .filter(|user_id| self.is_joined(&user_id, &pdu.room_id).unwrap_or(false)) + .collect::>() + /* to consume iterator */ { // Don't notify the user of their own events if user == pdu.sender { @@ -1209,13 +1211,13 @@ impl Rooms { redacts, } = pdu_builder; // TODO: Make sure this isn't called twice in parallel - let prev_events = dbg!(self + let prev_events = self .get_pdu_leaves(&room_id)? .into_iter() .take(20) - .collect::>()); + .collect::>(); - let create_event = dbg!(self.room_state_get(&room_id, &EventType::RoomCreate, ""))?; + let create_event = self.room_state_get(&room_id, &EventType::RoomCreate, "")?; let create_event_content = create_event .as_ref() From 9df86c2c1ebc278ac804043851c12687bef9caaf Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Sat, 3 Jul 2021 21:30:40 +0200 Subject: [PATCH 03/28] lock update --- Cargo.lock | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9a03542..4458d71 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -12,7 +12,7 @@ version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "43bb833f0bf979d8475d38fbf09ed3b8a55e1885fe93ad3f93239fc6a4f17b98" dependencies = [ - "getrandom 0.2.3", + "getrandom 0.2.2", "once_cell", "version_check", ] @@ -255,8 +255,8 @@ dependencies = [ "image", "jsonwebtoken", "log", - "num_cpus", "lru-cache", + "num_cpus", "opentelemetry", "opentelemetry-jaeger", "parking_lot", From 14e6afc45e320fc420273cd5de9ceb8a690a1806 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Sun, 4 Jul 2021 01:18:06 +0200 Subject: [PATCH 04/28] remove eldrich being and install good being --- .gitignore | 1 + src/client_server/account.rs | 10 +- src/client_server/alias.rs | 8 +- src/client_server/backup.rs | 30 ++--- src/client_server/config.rs | 10 +- src/client_server/context.rs | 4 +- src/client_server/device.rs | 12 +- src/client_server/directory.rs | 10 +- src/client_server/keys.rs | 14 +-- src/client_server/media.rs | 14 +-- src/client_server/membership.rs | 23 ++-- src/client_server/message.rs | 6 +- src/client_server/presence.rs | 6 +- src/client_server/profile.rs | 12 +- src/client_server/push.rs | 22 ++-- src/client_server/read_marker.rs | 6 +- src/client_server/redact.rs | 4 +- src/client_server/room.rs | 11 +- src/client_server/search.rs | 4 +- src/client_server/session.rs | 8 +- src/client_server/state.rs | 12 +- src/client_server/sync.rs | 16 +-- src/client_server/tag.rs | 8 +- src/client_server/to_device.rs | 4 +- src/client_server/typing.rs | 4 +- src/client_server/user_directory.rs | 4 +- src/database.rs | 185 +++++++++++++++++----------- src/database/abstraction/sqlite.rs | 44 +------ src/database/admin.rs | 54 ++++---- src/database/sending.rs | 43 ++++--- src/lib.rs | 3 +- src/main.rs | 33 ++++- src/ruma_wrapper.rs | 4 +- src/server_server.rs | 41 +++--- 34 files changed, 363 insertions(+), 307 deletions(-) diff --git a/.gitignore b/.gitignore index e2f4e88..1f5f395 100644 --- a/.gitignore +++ b/.gitignore @@ -59,6 +59,7 @@ $RECYCLE.BIN/ # Conduit Rocket.toml conduit.toml +conduit.db # Etc. **/*.rs.bk diff --git a/src/client_server/account.rs b/src/client_server/account.rs index f495e28..dad1b2a 100644 --- a/src/client_server/account.rs +++ b/src/client_server/account.rs @@ -1,7 +1,7 @@ use std::{collections::BTreeMap, convert::TryInto, sync::Arc}; use super::{State, DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; -use crate::{pdu::PduBuilder, utils, ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, pdu::PduBuilder, utils, ConduitResult, Database, Error, Ruma}; use log::info; use ruma::{ api::client::{ @@ -42,7 +42,7 @@ const GUEST_NAME_LENGTH: usize = 10; )] #[tracing::instrument(skip(db, body))] pub async fn get_register_available_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { // Validate user id @@ -85,7 +85,7 @@ pub async fn get_register_available_route( )] #[tracing::instrument(skip(db, body))] pub async fn register_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_registration() && !body.from_appservice { @@ -496,7 +496,7 @@ pub async fn register_route( )] #[tracing::instrument(skip(db, body))] pub async fn change_password_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -588,7 +588,7 @@ pub async fn whoami_route(body: Ruma) -> ConduitResult>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/alias.rs b/src/client_server/alias.rs index a54bd36..e9e3a23 100644 --- a/src/client_server/alias.rs +++ b/src/client_server/alias.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use super::State; -use crate::{ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Database, Error, Ruma}; use regex::Regex; use ruma::{ api::{ @@ -24,7 +24,7 @@ use rocket::{delete, get, put}; )] #[tracing::instrument(skip(db, body))] pub async fn create_alias_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if db.rooms.id_from_alias(&body.room_alias)?.is_some() { @@ -45,7 +45,7 @@ pub async fn create_alias_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_alias_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { db.rooms.set_alias(&body.room_alias, None, &db.globals)?; @@ -61,7 +61,7 @@ pub async fn delete_alias_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_alias_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { get_alias_helper(&db, &body.room_alias).await diff --git a/src/client_server/backup.rs b/src/client_server/backup.rs index fcca676..a50e71e 100644 --- a/src/client_server/backup.rs +++ b/src/client_server/backup.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use super::State; -use crate::{ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Database, Error, Ruma}; use ruma::api::client::{ error::ErrorKind, r0::backup::{ @@ -21,7 +21,7 @@ use rocket::{delete, get, post, put}; )] #[tracing::instrument(skip(db, body))] pub async fn create_backup_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -40,7 +40,7 @@ pub async fn create_backup_route( )] #[tracing::instrument(skip(db, body))] pub async fn update_backup_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -58,7 +58,7 @@ pub async fn update_backup_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_latest_backup_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -86,7 +86,7 @@ pub async fn get_latest_backup_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_backup_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -113,7 +113,7 @@ pub async fn get_backup_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_backup_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -132,7 +132,7 @@ pub async fn delete_backup_route( )] #[tracing::instrument(skip(db, body))] pub async fn add_backup_keys_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -166,7 +166,7 @@ pub async fn add_backup_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn add_backup_key_sessions_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -198,7 +198,7 @@ pub async fn add_backup_key_sessions_route( )] #[tracing::instrument(skip(db, body))] pub async fn add_backup_key_session_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -227,7 +227,7 @@ pub async fn add_backup_key_session_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_backup_keys_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -243,7 +243,7 @@ pub async fn get_backup_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_backup_key_sessions_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -261,7 +261,7 @@ pub async fn get_backup_key_sessions_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_backup_key_session_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -283,7 +283,7 @@ pub async fn get_backup_key_session_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_backup_keys_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -306,7 +306,7 @@ pub async fn delete_backup_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_backup_key_sessions_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -329,7 +329,7 @@ pub async fn delete_backup_key_sessions_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_backup_key_session_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/config.rs b/src/client_server/config.rs index 829bf94..ed2628a 100644 --- a/src/client_server/config.rs +++ b/src/client_server/config.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use super::State; -use crate::{ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Database, Error, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -25,7 +25,7 @@ use rocket::{get, put}; )] #[tracing::instrument(skip(db, body))] pub async fn set_global_account_data_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -60,7 +60,7 @@ pub async fn set_global_account_data_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_room_account_data_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -92,7 +92,7 @@ pub async fn set_room_account_data_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_global_account_data_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -119,7 +119,7 @@ pub async fn get_global_account_data_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_room_account_data_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/context.rs b/src/client_server/context.rs index b86fd0b..7a3c083 100644 --- a/src/client_server/context.rs +++ b/src/client_server/context.rs @@ -1,5 +1,5 @@ use super::State; -use crate::{ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Database, Error, Ruma}; use ruma::api::client::{error::ErrorKind, r0::context::get_context}; use std::{convert::TryFrom, sync::Arc}; @@ -12,7 +12,7 @@ use rocket::get; )] #[tracing::instrument(skip(db, body))] pub async fn get_context_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/device.rs b/src/client_server/device.rs index 2441524..361af68 100644 --- a/src/client_server/device.rs +++ b/src/client_server/device.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use super::State; -use crate::{utils, ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, utils, ConduitResult, Database, Error, Ruma}; use ruma::api::client::{ error::ErrorKind, r0::{ @@ -20,7 +20,7 @@ use rocket::{delete, get, post, put}; )] #[tracing::instrument(skip(db, body))] pub async fn get_devices_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -40,7 +40,7 @@ pub async fn get_devices_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_device_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -59,7 +59,7 @@ pub async fn get_device_route( )] #[tracing::instrument(skip(db, body))] pub async fn update_device_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -85,7 +85,7 @@ pub async fn update_device_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_device_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -139,7 +139,7 @@ pub async fn delete_device_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_devices_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/directory.rs b/src/client_server/directory.rs index 1b6b1d7..3c96ec1 100644 --- a/src/client_server/directory.rs +++ b/src/client_server/directory.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use super::State; -use crate::{ConduitResult, Database, Error, Result, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Database, Error, Result, Ruma}; use log::info; use ruma::{ api::{ @@ -35,7 +35,7 @@ use rocket::{get, post, put}; )] #[tracing::instrument(skip(db, body))] pub async fn get_public_rooms_filtered_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { get_public_rooms_filtered_helper( @@ -55,7 +55,7 @@ pub async fn get_public_rooms_filtered_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_public_rooms_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let response = get_public_rooms_filtered_helper( @@ -84,7 +84,7 @@ pub async fn get_public_rooms_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_room_visibility_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -114,7 +114,7 @@ pub async fn set_room_visibility_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_room_visibility_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { Ok(get_room_visibility::Response { diff --git a/src/client_server/keys.rs b/src/client_server/keys.rs index f80a329..310fd62 100644 --- a/src/client_server/keys.rs +++ b/src/client_server/keys.rs @@ -1,5 +1,5 @@ use super::{State, SESSION_ID_LENGTH}; -use crate::{utils, ConduitResult, Database, Error, Result, Ruma}; +use crate::{database::ReadGuard, utils, ConduitResult, Database, Error, Result, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -28,7 +28,7 @@ use rocket::{get, post}; )] #[tracing::instrument(skip(db, body))] pub async fn upload_keys_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -77,7 +77,7 @@ pub async fn upload_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_keys_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -98,7 +98,7 @@ pub async fn get_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn claim_keys_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { let response = claim_keys_helper(&body.one_time_keys, &db)?; @@ -114,7 +114,7 @@ pub async fn claim_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn upload_signing_keys_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -177,7 +177,7 @@ pub async fn upload_signing_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn upload_signatures_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -238,7 +238,7 @@ pub async fn upload_signatures_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_key_changes_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/media.rs b/src/client_server/media.rs index 0b1fbd7..cd7f714 100644 --- a/src/client_server/media.rs +++ b/src/client_server/media.rs @@ -1,5 +1,7 @@ use super::State; -use crate::{database::media::FileMeta, utils, ConduitResult, Database, Error, Ruma}; +use crate::{ + database::media::FileMeta, database::ReadGuard, utils, ConduitResult, Database, Error, Ruma, +}; use ruma::api::client::{ error::ErrorKind, r0::media::{create_content, get_content, get_content_thumbnail, get_media_config}, @@ -13,9 +15,7 @@ const MXC_LENGTH: usize = 32; #[cfg_attr(feature = "conduit_bin", get("/_matrix/media/r0/config"))] #[tracing::instrument(skip(db))] -pub async fn get_media_config_route( - db: State<'_, Arc>, -) -> ConduitResult { +pub async fn get_media_config_route(db: ReadGuard) -> ConduitResult { Ok(get_media_config::Response { upload_size: db.globals.max_request_size().into(), } @@ -28,7 +28,7 @@ pub async fn get_media_config_route( )] #[tracing::instrument(skip(db, body))] pub async fn create_content_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let mxc = format!( @@ -66,7 +66,7 @@ pub async fn create_content_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_content_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); @@ -119,7 +119,7 @@ pub async fn get_content_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_content_thumbnail_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs index 5c57b68..b3f0a0e 100644 --- a/src/client_server/membership.rs +++ b/src/client_server/membership.rs @@ -1,6 +1,7 @@ use super::State; use crate::{ client_server, + database::ReadGuard, pdu::{PduBuilder, PduEvent}, server_server, utils, ConduitResult, Database, Error, Result, Ruma, }; @@ -44,7 +45,7 @@ use rocket::{get, post}; )] #[tracing::instrument(skip(db, body))] pub async fn join_room_by_id_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -81,7 +82,7 @@ pub async fn join_room_by_id_route( )] #[tracing::instrument(skip(db, body))] pub async fn join_room_by_id_or_alias_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -135,7 +136,7 @@ pub async fn join_room_by_id_or_alias_route( )] #[tracing::instrument(skip(db, body))] pub async fn leave_room_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -153,7 +154,7 @@ pub async fn leave_room_route( )] #[tracing::instrument(skip(db, body))] pub async fn invite_user_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -173,7 +174,7 @@ pub async fn invite_user_route( )] #[tracing::instrument(skip(db, body))] pub async fn kick_user_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -223,7 +224,7 @@ pub async fn kick_user_route( )] #[tracing::instrument(skip(db, body))] pub async fn ban_user_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -281,7 +282,7 @@ pub async fn ban_user_route( )] #[tracing::instrument(skip(db, body))] pub async fn unban_user_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -330,7 +331,7 @@ pub async fn unban_user_route( )] #[tracing::instrument(skip(db, body))] pub async fn forget_room_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -348,7 +349,7 @@ pub async fn forget_room_route( )] #[tracing::instrument(skip(db, body))] pub async fn joined_rooms_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -369,7 +370,7 @@ pub async fn joined_rooms_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_member_events_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -399,7 +400,7 @@ pub async fn get_member_events_route( )] #[tracing::instrument(skip(db, body))] pub async fn joined_members_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/message.rs b/src/client_server/message.rs index 0d19f34..9764d53 100644 --- a/src/client_server/message.rs +++ b/src/client_server/message.rs @@ -1,5 +1,5 @@ use super::State; -use crate::{pdu::PduBuilder, utils, ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, pdu::PduBuilder, utils, ConduitResult, Database, Error, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -23,7 +23,7 @@ use rocket::{get, put}; )] #[tracing::instrument(skip(db, body))] pub async fn send_message_event_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -86,7 +86,7 @@ pub async fn send_message_event_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_message_events_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/presence.rs b/src/client_server/presence.rs index ce80dfd..69cde56 100644 --- a/src/client_server/presence.rs +++ b/src/client_server/presence.rs @@ -1,5 +1,5 @@ use super::State; -use crate::{utils, ConduitResult, Database, Ruma}; +use crate::{database::ReadGuard, utils, ConduitResult, Database, Ruma}; use ruma::api::client::r0::presence::{get_presence, set_presence}; use std::{convert::TryInto, sync::Arc, time::Duration}; @@ -12,7 +12,7 @@ use rocket::{get, put}; )] #[tracing::instrument(skip(db, body))] pub async fn set_presence_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -53,7 +53,7 @@ pub async fn set_presence_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_presence_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/profile.rs b/src/client_server/profile.rs index 4e9a37b..b7e7998 100644 --- a/src/client_server/profile.rs +++ b/src/client_server/profile.rs @@ -1,5 +1,5 @@ use super::State; -use crate::{pdu::PduBuilder, utils, ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, pdu::PduBuilder, utils, ConduitResult, Database, Error, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -21,7 +21,7 @@ use std::{convert::TryInto, sync::Arc}; )] #[tracing::instrument(skip(db, body))] pub async fn set_displayname_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -108,7 +108,7 @@ pub async fn set_displayname_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_displayname_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { Ok(get_display_name::Response { @@ -123,7 +123,7 @@ pub async fn get_displayname_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_avatar_url_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -210,7 +210,7 @@ pub async fn set_avatar_url_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_avatar_url_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { Ok(get_avatar_url::Response { @@ -225,7 +225,7 @@ pub async fn get_avatar_url_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_profile_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.users.exists(&body.user_id)? { diff --git a/src/client_server/push.rs b/src/client_server/push.rs index d6f6212..8d4564c 100644 --- a/src/client_server/push.rs +++ b/src/client_server/push.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use super::State; -use crate::{ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Database, Error, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -24,7 +24,7 @@ use rocket::{delete, get, post, put}; )] #[tracing::instrument(skip(db, body))] pub async fn get_pushrules_all_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -49,7 +49,7 @@ pub async fn get_pushrules_all_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_pushrule_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -103,7 +103,7 @@ pub async fn get_pushrule_route( )] #[tracing::instrument(skip(db, req))] pub async fn set_pushrule_route( - db: State<'_, Arc>, + db: ReadGuard, req: Ruma>, ) -> ConduitResult { let sender_user = req.sender_user.as_ref().expect("user is authenticated"); @@ -206,7 +206,7 @@ pub async fn set_pushrule_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_pushrule_actions_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -265,7 +265,7 @@ pub async fn get_pushrule_actions_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_pushrule_actions_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -339,7 +339,7 @@ pub async fn set_pushrule_actions_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_pushrule_enabled_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -400,7 +400,7 @@ pub async fn get_pushrule_enabled_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_pushrule_enabled_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -479,7 +479,7 @@ pub async fn set_pushrule_enabled_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_pushrule_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -548,7 +548,7 @@ pub async fn delete_pushrule_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_pushers_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -565,7 +565,7 @@ pub async fn get_pushers_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_pushers_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/read_marker.rs b/src/client_server/read_marker.rs index 837170f..7ab367f 100644 --- a/src/client_server/read_marker.rs +++ b/src/client_server/read_marker.rs @@ -1,5 +1,5 @@ use super::State; -use crate::{ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Database, Error, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -20,7 +20,7 @@ use std::{collections::BTreeMap, sync::Arc}; )] #[tracing::instrument(skip(db, body))] pub async fn set_read_marker_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -87,7 +87,7 @@ pub async fn set_read_marker_route( )] #[tracing::instrument(skip(db, body))] pub async fn create_receipt_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/redact.rs b/src/client_server/redact.rs index e193082..98e01e7 100644 --- a/src/client_server/redact.rs +++ b/src/client_server/redact.rs @@ -1,5 +1,5 @@ use super::State; -use crate::{pdu::PduBuilder, ConduitResult, Database, Ruma}; +use crate::{database::ReadGuard, pdu::PduBuilder, ConduitResult, Database, Ruma}; use ruma::{ api::client::r0::redact::redact_event, events::{room::redaction, EventType}, @@ -15,7 +15,7 @@ use rocket::put; )] #[tracing::instrument(skip(db, body))] pub async fn redact_event_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/room.rs b/src/client_server/room.rs index b33b550..a0b7eb8 100644 --- a/src/client_server/room.rs +++ b/src/client_server/room.rs @@ -1,5 +1,8 @@ use super::State; -use crate::{client_server::invite_helper, pdu::PduBuilder, ConduitResult, Database, Error, Ruma}; +use crate::{ + client_server::invite_helper, database::ReadGuard, pdu::PduBuilder, ConduitResult, Database, + Error, Ruma, +}; use log::info; use ruma::{ api::client::{ @@ -24,7 +27,7 @@ use rocket::{get, post}; )] #[tracing::instrument(skip(db, body))] pub async fn create_room_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -294,7 +297,7 @@ pub async fn create_room_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_room_event_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -322,7 +325,7 @@ pub async fn get_room_event_route( )] #[tracing::instrument(skip(db, body))] pub async fn upgrade_room_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, _room_id: String, ) -> ConduitResult { diff --git a/src/client_server/search.rs b/src/client_server/search.rs index 5fc64d0..25b0458 100644 --- a/src/client_server/search.rs +++ b/src/client_server/search.rs @@ -1,5 +1,5 @@ use super::State; -use crate::{ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Database, Error, Ruma}; use ruma::api::client::{error::ErrorKind, r0::search::search_events}; use std::sync::Arc; @@ -14,7 +14,7 @@ use std::collections::BTreeMap; )] #[tracing::instrument(skip(db, body))] pub async fn search_events_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/session.rs b/src/client_server/session.rs index dd504f1..ff018d2 100644 --- a/src/client_server/session.rs +++ b/src/client_server/session.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use super::{State, DEVICE_ID_LENGTH, TOKEN_LENGTH}; -use crate::{utils, ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, utils, ConduitResult, Database, Error, Ruma}; use log::info; use ruma::{ api::client::{ @@ -52,7 +52,7 @@ pub async fn get_login_types_route() -> ConduitResult )] #[tracing::instrument(skip(db, body))] pub async fn login_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { // Validate login method @@ -169,7 +169,7 @@ pub async fn login_route( )] #[tracing::instrument(skip(db, body))] pub async fn logout_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -197,7 +197,7 @@ pub async fn logout_route( )] #[tracing::instrument(skip(db, body))] pub async fn logout_all_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/state.rs b/src/client_server/state.rs index be52834..1798536 100644 --- a/src/client_server/state.rs +++ b/src/client_server/state.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use super::State; -use crate::{pdu::PduBuilder, ConduitResult, Database, Error, Result, Ruma}; +use crate::{database::ReadGuard, pdu::PduBuilder, ConduitResult, Database, Error, Result, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -27,7 +27,7 @@ use rocket::{get, put}; )] #[tracing::instrument(skip(db, body))] pub async fn send_state_event_for_key_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -53,7 +53,7 @@ pub async fn send_state_event_for_key_route( )] #[tracing::instrument(skip(db, body))] pub async fn send_state_event_for_empty_key_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -79,7 +79,7 @@ pub async fn send_state_event_for_empty_key_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_state_events_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -126,7 +126,7 @@ pub async fn get_state_events_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_state_events_for_key_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -177,7 +177,7 @@ pub async fn get_state_events_for_key_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_state_events_for_empty_key_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/sync.rs b/src/client_server/sync.rs index 69511fa..6df8af8 100644 --- a/src/client_server/sync.rs +++ b/src/client_server/sync.rs @@ -1,5 +1,5 @@ use super::State; -use crate::{ConduitResult, Database, Error, Result, Ruma, RumaResponse}; +use crate::{database::ReadGuard, ConduitResult, Database, Error, Result, Ruma, RumaResponse}; use log::error; use ruma::{ api::client::r0::{sync::sync_events, uiaa::UiaaResponse}, @@ -35,13 +35,15 @@ use rocket::{get, tokio}; )] #[tracing::instrument(skip(db, body))] pub async fn sync_events_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> std::result::Result, RumaResponse> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); - let mut rx = match db + let arc_db = Arc::new(db); + + let mut rx = match arc_db .globals .sync_receivers .write() @@ -52,7 +54,7 @@ pub async fn sync_events_route( let (tx, rx) = tokio::sync::watch::channel(None); tokio::spawn(sync_helper_wrapper( - Arc::clone(&db), + Arc::clone(&arc_db), sender_user.clone(), sender_device.clone(), body.since.clone(), @@ -68,7 +70,7 @@ pub async fn sync_events_route( let (tx, rx) = tokio::sync::watch::channel(None); tokio::spawn(sync_helper_wrapper( - Arc::clone(&db), + Arc::clone(&arc_db), sender_user.clone(), sender_device.clone(), body.since.clone(), @@ -104,7 +106,7 @@ pub async fn sync_events_route( } pub async fn sync_helper_wrapper( - db: Arc, + db: Arc, sender_user: UserId, sender_device: Box, since: Option, @@ -146,7 +148,7 @@ pub async fn sync_helper_wrapper( } async fn sync_helper( - db: Arc, + db: Arc, sender_user: UserId, sender_device: Box, since: Option, diff --git a/src/client_server/tag.rs b/src/client_server/tag.rs index 2382fe0..cc0d487 100644 --- a/src/client_server/tag.rs +++ b/src/client_server/tag.rs @@ -1,5 +1,5 @@ use super::State; -use crate::{ConduitResult, Database, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Database, Ruma}; use ruma::{ api::client::r0::tag::{create_tag, delete_tag, get_tags}, events::EventType, @@ -15,7 +15,7 @@ use rocket::{delete, get, put}; )] #[tracing::instrument(skip(db, body))] pub async fn update_tag_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -52,7 +52,7 @@ pub async fn update_tag_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_tag_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -86,7 +86,7 @@ pub async fn delete_tag_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_tags_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/to_device.rs b/src/client_server/to_device.rs index ada0c9a..2814a9d 100644 --- a/src/client_server/to_device.rs +++ b/src/client_server/to_device.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use super::State; -use crate::{ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Database, Error, Ruma}; use ruma::{ api::client::{error::ErrorKind, r0::to_device::send_event_to_device}, to_device::DeviceIdOrAllDevices, @@ -16,7 +16,7 @@ use rocket::put; )] #[tracing::instrument(skip(db, body))] pub async fn send_event_to_device_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/typing.rs b/src/client_server/typing.rs index a0a5d43..f39ef37 100644 --- a/src/client_server/typing.rs +++ b/src/client_server/typing.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use super::State; -use crate::{utils, ConduitResult, Database, Ruma}; +use crate::{database::ReadGuard, utils, ConduitResult, Database, Ruma}; use create_typing_event::Typing; use ruma::api::client::r0::typing::create_typing_event; @@ -14,7 +14,7 @@ use rocket::put; )] #[tracing::instrument(skip(db, body))] pub fn create_typing_event_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/user_directory.rs b/src/client_server/user_directory.rs index d7c16d7..ce382b0 100644 --- a/src/client_server/user_directory.rs +++ b/src/client_server/user_directory.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use super::State; -use crate::{ConduitResult, Database, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Database, Ruma}; use ruma::api::client::r0::user_directory::search_users; #[cfg(feature = "conduit_bin")] @@ -13,7 +13,7 @@ use rocket::post; )] #[tracing::instrument(skip(db, body))] pub async fn search_users_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { let limit = u64::from(body.limit) as usize; diff --git a/src/database.rs b/src/database.rs index de4f441..fda8a95 100644 --- a/src/database.rs +++ b/src/database.rs @@ -19,16 +19,22 @@ use abstraction::DatabaseEngine; use directories::ProjectDirs; use log::error; use lru_cache::LruCache; -use rocket::futures::{channel::mpsc, stream::FuturesUnordered, StreamExt}; +use rocket::{ + futures::{channel::mpsc, stream::FuturesUnordered, StreamExt}, + outcome::IntoOutcome, + request::{FromRequest, Request}, + try_outcome, State, +}; use ruma::{DeviceId, ServerName, UserId}; use serde::Deserialize; use std::{ collections::HashMap, fs::{self, remove_dir_all}, io::Write, + ops::Deref, sync::{Arc, RwLock}, }; -use tokio::sync::Semaphore; +use tokio::sync::{OwnedRwLockReadGuard, RwLock as TokioRwLock, RwLockReadGuard, Semaphore}; use self::proxy::ProxyConfig; @@ -122,7 +128,7 @@ impl Database { } /// Load an existing database or create a new one. - pub async fn load_or_create(config: Config) -> Result> { + pub async fn load_or_create(config: Config) -> Result>> { let builder = Engine::open(&config)?; if config.max_request_size < 1024 { @@ -132,7 +138,7 @@ impl Database { let (admin_sender, admin_receiver) = mpsc::unbounded(); let (sending_sender, sending_receiver) = mpsc::unbounded(); - let db = Arc::new(Self { + let db = Arc::new(TokioRwLock::from(Self { _db: builder.clone(), users: users::Users { userid_password: builder.open_tree("userid_password")?, @@ -238,98 +244,105 @@ impl Database { builder.open_tree("server_signingkeys")?, config, )?, - }); - - // MIGRATIONS - // TODO: database versions of new dbs should probably not be 0 - if db.globals.database_version()? < 1 { - for (roomserverid, _) in db.rooms.roomserverids.iter() { - let mut parts = roomserverid.split(|&b| b == 0xff); - let room_id = parts.next().expect("split always returns one element"); - let servername = match parts.next() { - Some(s) => s, - None => { - error!("Migration: Invalid roomserverid in db."); - continue; - } - }; - let mut serverroomid = servername.to_vec(); - serverroomid.push(0xff); - serverroomid.extend_from_slice(room_id); + })); + + { + let db = db.read().await; + // MIGRATIONS + // TODO: database versions of new dbs should probably not be 0 + if db.globals.database_version()? < 1 { + for (roomserverid, _) in db.rooms.roomserverids.iter() { + let mut parts = roomserverid.split(|&b| b == 0xff); + let room_id = parts.next().expect("split always returns one element"); + let servername = match parts.next() { + Some(s) => s, + None => { + error!("Migration: Invalid roomserverid in db."); + continue; + } + }; + let mut serverroomid = servername.to_vec(); + serverroomid.push(0xff); + serverroomid.extend_from_slice(room_id); - db.rooms.serverroomids.insert(&serverroomid, &[])?; - } + db.rooms.serverroomids.insert(&serverroomid, &[])?; + } - db.globals.bump_database_version(1)?; + db.globals.bump_database_version(1)?; - println!("Migration: 0 -> 1 finished"); - } + println!("Migration: 0 -> 1 finished"); + } - if db.globals.database_version()? < 2 { - // We accidentally inserted hashed versions of "" into the db instead of just "" - for (userid, password) in db.users.userid_password.iter() { - let password = utils::string_from_bytes(&password); + if db.globals.database_version()? < 2 { + // We accidentally inserted hashed versions of "" into the db instead of just "" + for (userid, password) in db.users.userid_password.iter() { + let password = utils::string_from_bytes(&password); - let empty_hashed_password = password.map_or(false, |password| { - argon2::verify_encoded(&password, b"").unwrap_or(false) - }); + let empty_hashed_password = password.map_or(false, |password| { + argon2::verify_encoded(&password, b"").unwrap_or(false) + }); - if empty_hashed_password { - db.users.userid_password.insert(&userid, b"")?; + if empty_hashed_password { + db.users.userid_password.insert(&userid, b"")?; + } } - } - db.globals.bump_database_version(2)?; + db.globals.bump_database_version(2)?; - println!("Migration: 1 -> 2 finished"); - } + println!("Migration: 1 -> 2 finished"); + } - if db.globals.database_version()? < 3 { - // Move media to filesystem - for (key, content) in db.media.mediaid_file.iter() { - if content.len() == 0 { - continue; - } + if db.globals.database_version()? < 3 { + // Move media to filesystem + for (key, content) in db.media.mediaid_file.iter() { + if content.len() == 0 { + continue; + } - let path = db.globals.get_media_file(&key); - let mut file = fs::File::create(path)?; - file.write_all(&content)?; - db.media.mediaid_file.insert(&key, &[])?; - } + let path = db.globals.get_media_file(&key); + let mut file = fs::File::create(path)?; + file.write_all(&content)?; + db.media.mediaid_file.insert(&key, &[])?; + } - db.globals.bump_database_version(3)?; + db.globals.bump_database_version(3)?; - println!("Migration: 2 -> 3 finished"); - } + println!("Migration: 2 -> 3 finished"); + } - if db.globals.database_version()? < 4 { - // Add federated users to db as deactivated - for our_user in db.users.iter() { - let our_user = our_user?; - if db.users.is_deactivated(&our_user)? { - continue; - } - for room in db.rooms.rooms_joined(&our_user) { - for user in db.rooms.room_members(&room?) { - let user = user?; - if user.server_name() != db.globals.server_name() { - println!("Migration: Creating user {}", user); - db.users.create(&user, None)?; + if db.globals.database_version()? < 4 { + // Add federated users to db as deactivated + for our_user in db.users.iter() { + let our_user = our_user?; + if db.users.is_deactivated(&our_user)? { + continue; + } + for room in db.rooms.rooms_joined(&our_user) { + for user in db.rooms.room_members(&room?) { + let user = user?; + if user.server_name() != db.globals.server_name() { + println!("Migration: Creating user {}", user); + db.users.create(&user, None)?; + } } } } - } - db.globals.bump_database_version(4)?; + db.globals.bump_database_version(4)?; - println!("Migration: 3 -> 4 finished"); + println!("Migration: 3 -> 4 finished"); + } } + let guard = db.read().await; + // This data is probably outdated - db.rooms.edus.presenceid_presence.clear()?; + guard.rooms.edus.presenceid_presence.clear()?; + + guard.admin.start_handler(Arc::clone(&db), admin_receiver); + guard.sending.start_handler(Arc::clone(&db), sending_receiver); - db.admin.start_handler(Arc::clone(&db), admin_receiver); - db.sending.start_handler(Arc::clone(&db), sending_receiver); + drop(guard); Ok(db) } @@ -431,4 +444,30 @@ impl Database { res } + + pub fn flush_wal(&self) -> Result<()> { + self._db.flush_wal() + } +} + +pub struct ReadGuard(OwnedRwLockReadGuard); + +impl Deref for ReadGuard { + type Target = OwnedRwLockReadGuard; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[cfg(feature = "conduit_bin")] +#[rocket::async_trait] +impl<'r> FromRequest<'r> for ReadGuard { + type Error = (); + + async fn from_request(req: &'r Request<'_>) -> rocket::request::Outcome { + let db = try_outcome!(req.guard::>>>().await); + + Ok(ReadGuard(Arc::clone(&db).read_owned().await)).or_forward(()) + } } diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index be5ce6c..164d985 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -3,7 +3,7 @@ use std::{ ops::Deref, path::{Path, PathBuf}, pin::Pin, - sync::{Arc, Weak}, + sync::Arc, thread, time::{Duration, Instant}, }; @@ -36,7 +36,6 @@ use tokio::sync::oneshot::Sender; struct Pool { writer: Mutex, readers: Vec>, - reader_rwlock: RwLock<()>, path: PathBuf, } @@ -71,7 +70,6 @@ impl Pool { Ok(Self { writer, readers, - reader_rwlock: RwLock::new(()), path: path.as_ref().to_path_buf(), }) } @@ -95,16 +93,12 @@ impl Pool { } fn read_lock(&self) -> HoldingConn<'_> { - let _guard = self.reader_rwlock.read(); - for r in &self.readers { if let Some(reader) = r.try_lock() { return HoldingConn::FromGuard(reader); } } - drop(_guard); - log::warn!("all readers locked, creating spillover reader..."); let spilled = Self::prepare_conn(&self.path).unwrap(); @@ -115,7 +109,6 @@ impl Pool { pub struct SqliteEngine { pool: Pool, - iterator_lock: RwLock<()>, } impl DatabaseEngine for SqliteEngine { @@ -128,36 +121,7 @@ impl DatabaseEngine for SqliteEngine { pool.write_lock() .execute("CREATE TABLE IF NOT EXISTS _noop (\"key\" INT)", params![])?; - let arc = Arc::new(SqliteEngine { - pool, - iterator_lock: RwLock::new(()), - }); - - let weak: Weak = Arc::downgrade(&arc); - - thread::spawn(move || { - let r = crossbeam::channel::tick(Duration::from_secs(60)); - - let weak = weak; - - loop { - let _ = r.recv(); - - if let Some(arc) = Weak::upgrade(&weak) { - log::warn!("wal-trunc: locking..."); - let iterator_guard = arc.iterator_lock.write(); - let read_guard = arc.pool.reader_rwlock.write(); - log::warn!("wal-trunc: locked, flushing..."); - let start = Instant::now(); - arc.flush_wal().unwrap(); - log::warn!("wal-trunc: locked, flushed in {:?}", start.elapsed()); - drop(read_guard); - drop(iterator_guard); - } else { - break; - } - } - }); + let arc = Arc::new(SqliteEngine { pool }); Ok(arc) } @@ -190,7 +154,7 @@ impl DatabaseEngine for SqliteEngine { } impl SqliteEngine { - fn flush_wal(self: &Arc) -> Result<()> { + pub fn flush_wal(self: &Arc) -> Result<()> { self.pool .write_lock() .execute_batch( @@ -244,9 +208,7 @@ impl SqliteTable { let engine = self.engine.clone(); thread::spawn(move || { - let guard = engine.iterator_lock.read(); let _ = f(&engine.pool.read_lock(), s); - drop(guard); }); Box::new(r.into_iter()) diff --git a/src/database/admin.rs b/src/database/admin.rs index 7826cfe..f79b789 100644 --- a/src/database/admin.rs +++ b/src/database/admin.rs @@ -10,6 +10,7 @@ use ruma::{ events::{room::message, EventType}, UserId, }; +use tokio::sync::RwLock; pub enum AdminCommand { RegisterAppservice(serde_yaml::Value), @@ -25,20 +26,22 @@ pub struct Admin { impl Admin { pub fn start_handler( &self, - db: Arc, + db: Arc>, mut receiver: mpsc::UnboundedReceiver, ) { tokio::spawn(async move { // TODO: Use futures when we have long admin commands //let mut futures = FuturesUnordered::new(); - let conduit_user = UserId::try_from(format!("@conduit:{}", db.globals.server_name())) + let guard = db.read().await; + + let conduit_user = UserId::try_from(format!("@conduit:{}", guard.globals.server_name())) .expect("@conduit:server_name is valid"); - let conduit_room = db + let conduit_room = guard .rooms .id_from_alias( - &format!("#admins:{}", db.globals.server_name()) + &format!("#admins:{}", guard.globals.server_name()) .try_into() .expect("#admins:server_name is a valid room alias"), ) @@ -48,24 +51,10 @@ impl Admin { warn!("Conduit instance does not have an #admins room. Logging to that room will not work. Restart Conduit after creating a user to fix this."); } + drop(guard); + let send_message = |message: message::MessageEventContent| { - if let Some(conduit_room) = &conduit_room { - db.rooms - .build_and_append_pdu( - PduBuilder { - event_type: EventType::RoomMessage, - content: serde_json::to_value(message) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: None, - redacts: None, - }, - &conduit_user, - &conduit_room, - &db, - ) - .unwrap(); - } + }; loop { @@ -73,10 +62,10 @@ impl Admin { Some(event) = receiver.next() => { match event { AdminCommand::RegisterAppservice(yaml) => { - db.appservice.register_appservice(yaml).unwrap(); // TODO handle error + db.read().await.appservice.register_appservice(yaml).unwrap(); // TODO handle error } AdminCommand::ListAppservices => { - if let Ok(appservices) = db.appservice.iter_ids().map(|ids| ids.collect::>()) { + if let Ok(appservices) = db.read().await.appservice.iter_ids().map(|ids| ids.collect::>()) { let count = appservices.len(); let output = format!( "Appservices ({}): {}", @@ -89,7 +78,24 @@ impl Admin { } } AdminCommand::SendMessage(message) => { - send_message(message); + if let Some(conduit_room) = &conduit_room { + let guard = db.read().await; + guard.rooms + .build_and_append_pdu( + PduBuilder { + event_type: EventType::RoomMessage, + content: serde_json::to_value(message) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: None, + redacts: None, + }, + &conduit_user, + &conduit_room, + &guard, + ) + .unwrap(); + } } } } diff --git a/src/database/sending.rs b/src/database/sending.rs index c2e1397..102fb15 100644 --- a/src/database/sending.rs +++ b/src/database/sending.rs @@ -30,7 +30,7 @@ use ruma::{ receipt::ReceiptType, MilliSecondsSinceUnixEpoch, ServerName, UInt, UserId, }; -use tokio::{select, sync::Semaphore}; +use tokio::{select, sync::{Semaphore, RwLock}}; use super::abstraction::Tree; @@ -90,7 +90,7 @@ enum TransactionStatus { } impl Sending { - pub fn start_handler(&self, db: Arc, mut receiver: mpsc::UnboundedReceiver>) { + pub fn start_handler(&self, db: Arc>, mut receiver: mpsc::UnboundedReceiver>) { tokio::spawn(async move { let mut futures = FuturesUnordered::new(); @@ -98,8 +98,11 @@ impl Sending { // Retry requests we could not finish yet let mut initial_transactions = HashMap::>::new(); + + let guard = db.read().await; + for (key, outgoing_kind, event) in - db.sending + guard.sending .servercurrentevents .iter() .filter_map(|(key, _)| { @@ -117,17 +120,19 @@ impl Sending { "Dropping some current events: {:?} {:?} {:?}", key, outgoing_kind, event ); - db.sending.servercurrentevents.remove(&key).unwrap(); + guard.sending.servercurrentevents.remove(&key).unwrap(); continue; } entry.push(event); } + drop(guard); + for (outgoing_kind, events) in initial_transactions { current_transaction_status .insert(outgoing_kind.get_prefix(), TransactionStatus::Running); - futures.push(Self::handle_events(outgoing_kind.clone(), events, &db)); + futures.push(Self::handle_events(outgoing_kind.clone(), events, Arc::clone(&db))); } loop { @@ -135,15 +140,17 @@ impl Sending { Some(response) = futures.next() => { match response { Ok(outgoing_kind) => { + let guard = db.read().await; + let prefix = outgoing_kind.get_prefix(); - for (key, _) in db.sending.servercurrentevents + for (key, _) in guard.sending.servercurrentevents .scan_prefix(prefix.clone()) { - db.sending.servercurrentevents.remove(&key).unwrap(); + guard.sending.servercurrentevents.remove(&key).unwrap(); } // Find events that have been added since starting the last request - let new_events = db.sending.servernamepduids + let new_events = guard.sending.servernamepduids .scan_prefix(prefix.clone()) .map(|(k, _)| { SendingEventType::Pdu(k[prefix.len()..].to_vec()) @@ -161,17 +168,19 @@ impl Sending { SendingEventType::Pdu(b) | SendingEventType::Edu(b) => { current_key.extend_from_slice(&b); - db.sending.servercurrentevents.insert(¤t_key, &[]).unwrap(); - db.sending.servernamepduids.remove(¤t_key).unwrap(); + guard.sending.servercurrentevents.insert(¤t_key, &[]).unwrap(); + guard.sending.servernamepduids.remove(¤t_key).unwrap(); } } } + drop(guard); + futures.push( Self::handle_events( outgoing_kind.clone(), new_events, - &db, + Arc::clone(&db), ) ); } else { @@ -192,13 +201,15 @@ impl Sending { }, Some(key) = receiver.next() => { if let Ok((outgoing_kind, event)) = Self::parse_servercurrentevent(&key) { + let guard = db.read().await; + if let Ok(Some(events)) = Self::select_events( &outgoing_kind, vec![(event, key)], &mut current_transaction_status, - &db + &guard ) { - futures.push(Self::handle_events(outgoing_kind, events, &db)); + futures.push(Self::handle_events(outgoing_kind, events, Arc::clone(&db))); } } } @@ -403,8 +414,10 @@ impl Sending { async fn handle_events( kind: OutgoingKind, events: Vec, - db: &Database, + db: Arc>, ) -> std::result::Result { + let db = db.read().await; + match &kind { OutgoingKind::Appservice(server) => { let mut pdu_jsons = Vec::new(); @@ -543,7 +556,7 @@ impl Sending { &pusher, rules_for_user, &pdu, - db, + &db, ) .await .map(|_response| kind.clone()) diff --git a/src/lib.rs b/src/lib.rs index 50ca6ea..c1f0c4b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,7 +15,8 @@ pub use error::{Error, Result}; pub use pdu::PduEvent; pub use rocket::Config; pub use ruma_wrapper::{ConduitResult, Ruma, RumaResponse}; -use std::ops::Deref; +use std::{ops::Deref, sync::Arc}; +use tokio::sync::RwLock; pub struct State<'r, T: Send + Sync + 'static>(pub &'r T); diff --git a/src/main.rs b/src/main.rs index 99d4560..826f182 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,7 +12,7 @@ mod pdu; mod ruma_wrapper; mod utils; -use std::sync::Arc; +use std::{sync::{Arc, Weak}, time::{Duration, Instant}}; use database::Config; pub use database::Database; @@ -30,10 +30,14 @@ use rocket::{ }, routes, Request, }; +use tokio::{ + sync::RwLock, + time::{interval, Interval}, +}; use tracing::span; use tracing_subscriber::{prelude::*, Registry}; -fn setup_rocket(config: Figment, data: Arc) -> rocket::Rocket { +fn setup_rocket(config: Figment, data: Arc>) -> rocket::Rocket { rocket::custom(config) .manage(data) .mount( @@ -201,6 +205,31 @@ async fn main() { .await .expect("config is valid"); + { + let weak: Weak> = Arc::downgrade(&db); + + tokio::spawn(async { + let weak = weak; + + let mut i = interval(Duration::from_secs(10)); + + loop { + i.tick().await; + + if let Some(arc) = Weak::upgrade(&weak) { + log::warn!("wal-trunc: locking..."); + let guard = arc.write().await; + log::warn!("wal-trunc: locked, flushing..."); + let start = Instant::now(); + guard.flush_wal(); + log::warn!("wal-trunc: locked, flushed in {:?}", start.elapsed()); + } else { + break; + } + } + }); + } + if config.allow_jaeger { let (tracer, _uninstall) = opentelemetry_jaeger::new_pipeline() .with_service_name("conduit") diff --git a/src/ruma_wrapper.rs b/src/ruma_wrapper.rs index 8c22f79..647d25e 100644 --- a/src/ruma_wrapper.rs +++ b/src/ruma_wrapper.rs @@ -1,4 +1,4 @@ -use crate::Error; +use crate::{database::ReadGuard, Error}; use ruma::{ api::{client::r0::uiaa::UiaaResponse, OutgoingResponse}, identifiers::{DeviceId, UserId}, @@ -49,7 +49,7 @@ where async fn from_data(request: &'a Request<'_>, data: Data) -> data::Outcome { let metadata = T::Incoming::METADATA; let db = request - .guard::>>() + .guard::() .await .expect("database was loaded"); diff --git a/src/server_server.rs b/src/server_server.rs index 2bcfd2b..b3e8f1b 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -1,5 +1,6 @@ use crate::{ client_server::{self, claim_keys_helper, get_keys_helper}, + database::ReadGuard, utils, ConduitResult, Database, Error, PduEvent, Result, Ruma, }; use get_profile_information::v1::ProfileField; @@ -431,9 +432,7 @@ pub async fn request_well_known( #[cfg_attr(feature = "conduit_bin", get("/_matrix/federation/v1/version"))] #[tracing::instrument(skip(db))] -pub fn get_server_version_route( - db: State<'_, Arc>, -) -> ConduitResult { +pub fn get_server_version_route(db: ReadGuard) -> ConduitResult { if !db.globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } @@ -450,7 +449,7 @@ pub fn get_server_version_route( // Response type for this endpoint is Json because we need to calculate a signature for the response #[cfg_attr(feature = "conduit_bin", get("/_matrix/key/v2/server"))] #[tracing::instrument(skip(db))] -pub fn get_server_keys_route(db: State<'_, Arc>) -> Json { +pub fn get_server_keys_route(db: ReadGuard) -> Json { if !db.globals.allow_federation() { // TODO: Use proper types return Json("Federation is disabled.".to_owned()); @@ -497,7 +496,7 @@ pub fn get_server_keys_route(db: State<'_, Arc>) -> Json { #[cfg_attr(feature = "conduit_bin", get("/_matrix/key/v2/server/<_>"))] #[tracing::instrument(skip(db))] -pub fn get_server_keys_deprecated_route(db: State<'_, Arc>) -> Json { +pub fn get_server_keys_deprecated_route(db: ReadGuard) -> Json { get_server_keys_route(db) } @@ -507,7 +506,7 @@ pub fn get_server_keys_deprecated_route(db: State<'_, Arc>) -> Json>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -551,7 +550,7 @@ pub async fn get_public_rooms_filtered_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_public_rooms_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -595,7 +594,7 @@ pub async fn get_public_rooms_route( )] #[tracing::instrument(skip(db, body))] pub async fn send_transaction_message_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -1674,7 +1673,7 @@ pub(crate) fn append_incoming_pdu( )] #[tracing::instrument(skip(db, body))] pub fn get_event_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -1699,7 +1698,7 @@ pub fn get_event_route( )] #[tracing::instrument(skip(db, body))] pub fn get_missing_events_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -1748,7 +1747,7 @@ pub fn get_missing_events_route( )] #[tracing::instrument(skip(db, body))] pub fn get_event_authorization_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -1792,7 +1791,7 @@ pub fn get_event_authorization_route( )] #[tracing::instrument(skip(db, body))] pub fn get_room_state_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -1855,7 +1854,7 @@ pub fn get_room_state_route( )] #[tracing::instrument(skip(db, body))] pub fn get_room_state_ids_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -1907,7 +1906,7 @@ pub fn get_room_state_ids_route( )] #[tracing::instrument(skip(db, body))] pub fn create_join_event_template_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -2076,7 +2075,7 @@ pub fn create_join_event_template_route( )] #[tracing::instrument(skip(db, body))] pub async fn create_join_event_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -2184,7 +2183,7 @@ pub async fn create_join_event_route( )] #[tracing::instrument(skip(db, body))] pub async fn create_invite_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -2289,7 +2288,7 @@ pub async fn create_invite_route( )] #[tracing::instrument(skip(db, body))] pub fn get_devices_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -2329,7 +2328,7 @@ pub fn get_devices_route( )] #[tracing::instrument(skip(db, body))] pub fn get_room_information_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -2357,7 +2356,7 @@ pub fn get_room_information_route( )] #[tracing::instrument(skip(db, body))] pub fn get_profile_information_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -2391,7 +2390,7 @@ pub fn get_profile_information_route( )] #[tracing::instrument(skip(db, body))] pub fn get_keys_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -2419,7 +2418,7 @@ pub fn get_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn claim_keys_route( - db: State<'_, Arc>, + db: ReadGuard, body: Ruma, ) -> ConduitResult { if !db.globals.allow_federation() { From a55dec9035bbc05b35ef7b3e5f24cfb6df5dc62c Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Sun, 4 Jul 2021 02:03:46 +0200 Subject: [PATCH 05/28] add better performance around syncs --- src/client_server/sync.rs | 2 ++ src/main.rs | 15 +++++++++++---- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/client_server/sync.rs b/src/client_server/sync.rs index 6df8af8..ddc4870 100644 --- a/src/client_server/sync.rs +++ b/src/client_server/sync.rs @@ -144,6 +144,8 @@ pub async fn sync_helper_wrapper( } } + drop(db); + let _ = tx.send(Some(r.map(|(r, _)| r.into()))); } diff --git a/src/main.rs b/src/main.rs index 826f182..44444d6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -32,7 +32,7 @@ use rocket::{ }; use tokio::{ sync::RwLock, - time::{interval, Interval}, + time::{interval, timeout}, }; use tracing::span; use tracing_subscriber::{prelude::*, Registry}; @@ -211,18 +211,25 @@ async fn main() { tokio::spawn(async { let weak = weak; - let mut i = interval(Duration::from_secs(10)); + let mut i = interval(Duration::from_secs(60)); loop { i.tick().await; if let Some(arc) = Weak::upgrade(&weak) { log::warn!("wal-trunc: locking..."); - let guard = arc.write().await; + let guard = { + if let Ok(guard) = timeout(Duration::from_secs(5), arc.write()).await { + guard + } else { + log::warn!("wal-trunc: lock failed in timeout, canceled."); + continue; + } + }; log::warn!("wal-trunc: locked, flushing..."); let start = Instant::now(); guard.flush_wal(); - log::warn!("wal-trunc: locked, flushed in {:?}", start.elapsed()); + log::warn!("wal-trunc: flushed in {:?}", start.elapsed()); } else { break; } From 5ec0be2b41005a6f1ed39383cae0e391ba55dbce Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Sun, 4 Jul 2021 02:15:04 +0200 Subject: [PATCH 06/28] fmt --- src/database.rs | 4 +++- src/database/admin.rs | 9 ++++----- src/database/sending.rs | 20 ++++++++++++++++---- src/main.rs | 5 ++++- 4 files changed, 27 insertions(+), 11 deletions(-) diff --git a/src/database.rs b/src/database.rs index fda8a95..7c60b10 100644 --- a/src/database.rs +++ b/src/database.rs @@ -340,7 +340,9 @@ impl Database { guard.rooms.edus.presenceid_presence.clear()?; guard.admin.start_handler(Arc::clone(&db), admin_receiver); - guard.sending.start_handler(Arc::clone(&db), sending_receiver); + guard + .sending + .start_handler(Arc::clone(&db), sending_receiver); drop(guard); diff --git a/src/database/admin.rs b/src/database/admin.rs index f79b789..fcc13c3 100644 --- a/src/database/admin.rs +++ b/src/database/admin.rs @@ -35,8 +35,9 @@ impl Admin { let guard = db.read().await; - let conduit_user = UserId::try_from(format!("@conduit:{}", guard.globals.server_name())) - .expect("@conduit:server_name is valid"); + let conduit_user = + UserId::try_from(format!("@conduit:{}", guard.globals.server_name())) + .expect("@conduit:server_name is valid"); let conduit_room = guard .rooms @@ -53,9 +54,7 @@ impl Admin { drop(guard); - let send_message = |message: message::MessageEventContent| { - - }; + let send_message = |message: message::MessageEventContent| {}; loop { tokio::select! { diff --git a/src/database/sending.rs b/src/database/sending.rs index 102fb15..7c9cf64 100644 --- a/src/database/sending.rs +++ b/src/database/sending.rs @@ -30,7 +30,10 @@ use ruma::{ receipt::ReceiptType, MilliSecondsSinceUnixEpoch, ServerName, UInt, UserId, }; -use tokio::{select, sync::{Semaphore, RwLock}}; +use tokio::{ + select, + sync::{RwLock, Semaphore}, +}; use super::abstraction::Tree; @@ -90,7 +93,11 @@ enum TransactionStatus { } impl Sending { - pub fn start_handler(&self, db: Arc>, mut receiver: mpsc::UnboundedReceiver>) { + pub fn start_handler( + &self, + db: Arc>, + mut receiver: mpsc::UnboundedReceiver>, + ) { tokio::spawn(async move { let mut futures = FuturesUnordered::new(); @@ -102,7 +109,8 @@ impl Sending { let guard = db.read().await; for (key, outgoing_kind, event) in - guard.sending + guard + .sending .servercurrentevents .iter() .filter_map(|(key, _)| { @@ -132,7 +140,11 @@ impl Sending { for (outgoing_kind, events) in initial_transactions { current_transaction_status .insert(outgoing_kind.get_prefix(), TransactionStatus::Running); - futures.push(Self::handle_events(outgoing_kind.clone(), events, Arc::clone(&db))); + futures.push(Self::handle_events( + outgoing_kind.clone(), + events, + Arc::clone(&db), + )); } loop { diff --git a/src/main.rs b/src/main.rs index 44444d6..9968b87 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,7 +12,10 @@ mod pdu; mod ruma_wrapper; mod utils; -use std::{sync::{Arc, Weak}, time::{Duration, Instant}}; +use std::{ + sync::{Arc, Weak}, + time::{Duration, Instant}, +}; use database::Config; pub use database::Database; From e5a26de606b8bc2ecf84e3bf149604f59d76f21a Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Sun, 4 Jul 2021 13:14:01 +0200 Subject: [PATCH 07/28] misc cleanup --- src/client_server/account.rs | 6 ++-- src/client_server/alias.rs | 3 -- src/client_server/backup.rs | 5 +-- src/client_server/config.rs | 5 +-- src/client_server/context.rs | 5 ++- src/client_server/device.rs | 5 +-- src/client_server/directory.rs | 3 -- src/client_server/keys.rs | 7 ++-- src/client_server/media.rs | 7 ++-- src/client_server/membership.rs | 1 - src/client_server/message.rs | 4 +-- src/client_server/mod.rs | 4 +-- src/client_server/presence.rs | 5 ++- src/client_server/profile.rs | 5 ++- src/client_server/push.rs | 5 +-- src/client_server/read_marker.rs | 5 ++- src/client_server/redact.rs | 4 +-- src/client_server/room.rs | 6 ++-- src/client_server/search.rs | 4 +-- src/client_server/session.rs | 6 ++-- src/client_server/state.rs | 3 -- src/client_server/sync.rs | 1 - src/client_server/tag.rs | 5 ++- src/client_server/to_device.rs | 5 +-- src/client_server/typing.rs | 5 +-- src/client_server/user_directory.rs | 5 +-- src/database.rs | 9 +++-- src/database/admin.rs | 53 ++++++++++++++++------------- src/lib.rs | 3 +- src/main.rs | 7 ++-- src/ruma_wrapper.rs | 5 ++- src/server_server.rs | 2 +- 32 files changed, 77 insertions(+), 121 deletions(-) diff --git a/src/client_server/account.rs b/src/client_server/account.rs index dad1b2a..d76816b 100644 --- a/src/client_server/account.rs +++ b/src/client_server/account.rs @@ -1,7 +1,7 @@ -use std::{collections::BTreeMap, convert::TryInto, sync::Arc}; +use std::{collections::BTreeMap, convert::TryInto}; -use super::{State, DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; -use crate::{database::ReadGuard, pdu::PduBuilder, utils, ConduitResult, Database, Error, Ruma}; +use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; +use crate::{database::ReadGuard, pdu::PduBuilder, utils, ConduitResult, Error, Ruma}; use log::info; use ruma::{ api::client::{ diff --git a/src/client_server/alias.rs b/src/client_server/alias.rs index e9e3a23..440b505 100644 --- a/src/client_server/alias.rs +++ b/src/client_server/alias.rs @@ -1,6 +1,3 @@ -use std::sync::Arc; - -use super::State; use crate::{database::ReadGuard, ConduitResult, Database, Error, Ruma}; use regex::Regex; use ruma::{ diff --git a/src/client_server/backup.rs b/src/client_server/backup.rs index a50e71e..8412778 100644 --- a/src/client_server/backup.rs +++ b/src/client_server/backup.rs @@ -1,7 +1,4 @@ -use std::sync::Arc; - -use super::State; -use crate::{database::ReadGuard, ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Error, Ruma}; use ruma::api::client::{ error::ErrorKind, r0::backup::{ diff --git a/src/client_server/config.rs b/src/client_server/config.rs index ed2628a..da504e1 100644 --- a/src/client_server/config.rs +++ b/src/client_server/config.rs @@ -1,7 +1,4 @@ -use std::sync::Arc; - -use super::State; -use crate::{database::ReadGuard, ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Error, Ruma}; use ruma::{ api::client::{ error::ErrorKind, diff --git a/src/client_server/context.rs b/src/client_server/context.rs index 7a3c083..7d857e1 100644 --- a/src/client_server/context.rs +++ b/src/client_server/context.rs @@ -1,7 +1,6 @@ -use super::State; -use crate::{database::ReadGuard, ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Error, Ruma}; use ruma::api::client::{error::ErrorKind, r0::context::get_context}; -use std::{convert::TryFrom, sync::Arc}; +use std::convert::TryFrom; #[cfg(feature = "conduit_bin")] use rocket::get; diff --git a/src/client_server/device.rs b/src/client_server/device.rs index 361af68..4d8d16c 100644 --- a/src/client_server/device.rs +++ b/src/client_server/device.rs @@ -1,7 +1,4 @@ -use std::sync::Arc; - -use super::State; -use crate::{database::ReadGuard, utils, ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, utils, ConduitResult, Error, Ruma}; use ruma::api::client::{ error::ErrorKind, r0::{ diff --git a/src/client_server/directory.rs b/src/client_server/directory.rs index 3c96ec1..5cfc458 100644 --- a/src/client_server/directory.rs +++ b/src/client_server/directory.rs @@ -1,6 +1,3 @@ -use std::sync::Arc; - -use super::State; use crate::{database::ReadGuard, ConduitResult, Database, Error, Result, Ruma}; use log::info; use ruma::{ diff --git a/src/client_server/keys.rs b/src/client_server/keys.rs index 310fd62..39cf7be 100644 --- a/src/client_server/keys.rs +++ b/src/client_server/keys.rs @@ -1,4 +1,4 @@ -use super::{State, SESSION_ID_LENGTH}; +use super::SESSION_ID_LENGTH; use crate::{database::ReadGuard, utils, ConduitResult, Database, Error, Result, Ruma}; use ruma::{ api::client::{ @@ -14,10 +14,7 @@ use ruma::{ encryption::UnsignedDeviceInfo, DeviceId, DeviceKeyAlgorithm, UserId, }; -use std::{ - collections::{BTreeMap, HashSet}, - sync::Arc, -}; +use std::collections::{BTreeMap, HashSet}; #[cfg(feature = "conduit_bin")] use rocket::{get, post}; diff --git a/src/client_server/media.rs b/src/client_server/media.rs index cd7f714..588e5f7 100644 --- a/src/client_server/media.rs +++ b/src/client_server/media.rs @@ -1,15 +1,12 @@ -use super::State; -use crate::{ - database::media::FileMeta, database::ReadGuard, utils, ConduitResult, Database, Error, Ruma, -}; +use crate::{database::media::FileMeta, database::ReadGuard, utils, ConduitResult, Error, Ruma}; use ruma::api::client::{ error::ErrorKind, r0::media::{create_content, get_content, get_content_thumbnail, get_media_config}, }; +use std::convert::TryInto; #[cfg(feature = "conduit_bin")] use rocket::{get, post}; -use std::{convert::TryInto, sync::Arc}; const MXC_LENGTH: usize = 32; diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs index b3f0a0e..8743940 100644 --- a/src/client_server/membership.rs +++ b/src/client_server/membership.rs @@ -1,4 +1,3 @@ -use super::State; use crate::{ client_server, database::ReadGuard, diff --git a/src/client_server/message.rs b/src/client_server/message.rs index 9764d53..d439535 100644 --- a/src/client_server/message.rs +++ b/src/client_server/message.rs @@ -1,5 +1,4 @@ -use super::State; -use crate::{database::ReadGuard, pdu::PduBuilder, utils, ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, pdu::PduBuilder, utils, ConduitResult, Error, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -11,7 +10,6 @@ use ruma::{ use std::{ collections::BTreeMap, convert::{TryFrom, TryInto}, - sync::Arc, }; #[cfg(feature = "conduit_bin")] diff --git a/src/client_server/mod.rs b/src/client_server/mod.rs index 825dbbb..f211a57 100644 --- a/src/client_server/mod.rs +++ b/src/client_server/mod.rs @@ -64,9 +64,7 @@ pub use voip::*; use super::State; #[cfg(feature = "conduit_bin")] use { - crate::ConduitResult, - rocket::{options, State}, - ruma::api::client::r0::to_device::send_event_to_device, + crate::ConduitResult, rocket::options, ruma::api::client::r0::to_device::send_event_to_device, }; pub const DEVICE_ID_LENGTH: usize = 10; diff --git a/src/client_server/presence.rs b/src/client_server/presence.rs index 69cde56..c96e62b 100644 --- a/src/client_server/presence.rs +++ b/src/client_server/presence.rs @@ -1,7 +1,6 @@ -use super::State; -use crate::{database::ReadGuard, utils, ConduitResult, Database, Ruma}; +use crate::{database::ReadGuard, utils, ConduitResult, Ruma}; use ruma::api::client::r0::presence::{get_presence, set_presence}; -use std::{convert::TryInto, sync::Arc, time::Duration}; +use std::{convert::TryInto, time::Duration}; #[cfg(feature = "conduit_bin")] use rocket::{get, put}; diff --git a/src/client_server/profile.rs b/src/client_server/profile.rs index b7e7998..f516f91 100644 --- a/src/client_server/profile.rs +++ b/src/client_server/profile.rs @@ -1,5 +1,4 @@ -use super::State; -use crate::{database::ReadGuard, pdu::PduBuilder, utils, ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, pdu::PduBuilder, utils, ConduitResult, Error, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -10,10 +9,10 @@ use ruma::{ events::EventType, serde::Raw, }; +use std::convert::TryInto; #[cfg(feature = "conduit_bin")] use rocket::{get, put}; -use std::{convert::TryInto, sync::Arc}; #[cfg_attr( feature = "conduit_bin", diff --git a/src/client_server/push.rs b/src/client_server/push.rs index 8d4564c..e718287 100644 --- a/src/client_server/push.rs +++ b/src/client_server/push.rs @@ -1,7 +1,4 @@ -use std::sync::Arc; - -use super::State; -use crate::{database::ReadGuard, ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Error, Ruma}; use ruma::{ api::client::{ error::ErrorKind, diff --git a/src/client_server/read_marker.rs b/src/client_server/read_marker.rs index 7ab367f..64f8e26 100644 --- a/src/client_server/read_marker.rs +++ b/src/client_server/read_marker.rs @@ -1,5 +1,4 @@ -use super::State; -use crate::{database::ReadGuard, ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Error, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -9,10 +8,10 @@ use ruma::{ receipt::ReceiptType, MilliSecondsSinceUnixEpoch, }; +use std::collections::BTreeMap; #[cfg(feature = "conduit_bin")] use rocket::post; -use std::{collections::BTreeMap, sync::Arc}; #[cfg_attr( feature = "conduit_bin", diff --git a/src/client_server/redact.rs b/src/client_server/redact.rs index 98e01e7..af1b242 100644 --- a/src/client_server/redact.rs +++ b/src/client_server/redact.rs @@ -1,10 +1,8 @@ -use super::State; -use crate::{database::ReadGuard, pdu::PduBuilder, ConduitResult, Database, Ruma}; +use crate::{database::ReadGuard, pdu::PduBuilder, ConduitResult, Ruma}; use ruma::{ api::client::r0::redact::redact_event, events::{room::redaction, EventType}, }; -use std::sync::Arc; #[cfg(feature = "conduit_bin")] use rocket::put; diff --git a/src/client_server/room.rs b/src/client_server/room.rs index a0b7eb8..c01c5f8 100644 --- a/src/client_server/room.rs +++ b/src/client_server/room.rs @@ -1,7 +1,5 @@ -use super::State; use crate::{ - client_server::invite_helper, database::ReadGuard, pdu::PduBuilder, ConduitResult, Database, - Error, Ruma, + client_server::invite_helper, database::ReadGuard, pdu::PduBuilder, ConduitResult, Error, Ruma, }; use log::info; use ruma::{ @@ -16,7 +14,7 @@ use ruma::{ serde::Raw, RoomAliasId, RoomId, RoomVersionId, }; -use std::{cmp::max, collections::BTreeMap, convert::TryFrom, sync::Arc}; +use std::{cmp::max, collections::BTreeMap, convert::TryFrom}; #[cfg(feature = "conduit_bin")] use rocket::{get, post}; diff --git a/src/client_server/search.rs b/src/client_server/search.rs index 25b0458..cfe1345 100644 --- a/src/client_server/search.rs +++ b/src/client_server/search.rs @@ -1,7 +1,5 @@ -use super::State; -use crate::{database::ReadGuard, ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Error, Ruma}; use ruma::api::client::{error::ErrorKind, r0::search::search_events}; -use std::sync::Arc; #[cfg(feature = "conduit_bin")] use rocket::post; diff --git a/src/client_server/session.rs b/src/client_server/session.rs index ff018d2..844ef0c 100644 --- a/src/client_server/session.rs +++ b/src/client_server/session.rs @@ -1,7 +1,5 @@ -use std::sync::Arc; - -use super::{State, DEVICE_ID_LENGTH, TOKEN_LENGTH}; -use crate::{database::ReadGuard, utils, ConduitResult, Database, Error, Ruma}; +use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH}; +use crate::{database::ReadGuard, utils, ConduitResult, Error, Ruma}; use log::info; use ruma::{ api::client::{ diff --git a/src/client_server/state.rs b/src/client_server/state.rs index 1798536..dea52aa 100644 --- a/src/client_server/state.rs +++ b/src/client_server/state.rs @@ -1,6 +1,3 @@ -use std::sync::Arc; - -use super::State; use crate::{database::ReadGuard, pdu::PduBuilder, ConduitResult, Database, Error, Result, Ruma}; use ruma::{ api::client::{ diff --git a/src/client_server/sync.rs b/src/client_server/sync.rs index ddc4870..09ac8f0 100644 --- a/src/client_server/sync.rs +++ b/src/client_server/sync.rs @@ -1,4 +1,3 @@ -use super::State; use crate::{database::ReadGuard, ConduitResult, Database, Error, Result, Ruma, RumaResponse}; use log::error; use ruma::{ diff --git a/src/client_server/tag.rs b/src/client_server/tag.rs index cc0d487..8fa1d05 100644 --- a/src/client_server/tag.rs +++ b/src/client_server/tag.rs @@ -1,10 +1,9 @@ -use super::State; -use crate::{database::ReadGuard, ConduitResult, Database, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Ruma}; use ruma::{ api::client::r0::tag::{create_tag, delete_tag, get_tags}, events::EventType, }; -use std::{collections::BTreeMap, sync::Arc}; +use std::collections::BTreeMap; #[cfg(feature = "conduit_bin")] use rocket::{delete, get, put}; diff --git a/src/client_server/to_device.rs b/src/client_server/to_device.rs index 2814a9d..761a359 100644 --- a/src/client_server/to_device.rs +++ b/src/client_server/to_device.rs @@ -1,7 +1,4 @@ -use std::sync::Arc; - -use super::State; -use crate::{database::ReadGuard, ConduitResult, Database, Error, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Error, Ruma}; use ruma::{ api::client::{error::ErrorKind, r0::to_device::send_event_to_device}, to_device::DeviceIdOrAllDevices, diff --git a/src/client_server/typing.rs b/src/client_server/typing.rs index f39ef37..8f3a643 100644 --- a/src/client_server/typing.rs +++ b/src/client_server/typing.rs @@ -1,7 +1,4 @@ -use std::sync::Arc; - -use super::State; -use crate::{database::ReadGuard, utils, ConduitResult, Database, Ruma}; +use crate::{database::ReadGuard, utils, ConduitResult, Ruma}; use create_typing_event::Typing; use ruma::api::client::r0::typing::create_typing_event; diff --git a/src/client_server/user_directory.rs b/src/client_server/user_directory.rs index ce382b0..226fe6e 100644 --- a/src/client_server/user_directory.rs +++ b/src/client_server/user_directory.rs @@ -1,7 +1,4 @@ -use std::sync::Arc; - -use super::State; -use crate::{database::ReadGuard, ConduitResult, Database, Ruma}; +use crate::{database::ReadGuard, ConduitResult, Ruma}; use ruma::api::client::r0::user_directory::search_users; #[cfg(feature = "conduit_bin")] diff --git a/src/database.rs b/src/database.rs index 7c60b10..9986651 100644 --- a/src/database.rs +++ b/src/database.rs @@ -34,7 +34,7 @@ use std::{ ops::Deref, sync::{Arc, RwLock}, }; -use tokio::sync::{OwnedRwLockReadGuard, RwLock as TokioRwLock, RwLockReadGuard, Semaphore}; +use tokio::sync::{OwnedRwLockReadGuard, RwLock as TokioRwLock, Semaphore}; use self::proxy::ProxyConfig; @@ -462,7 +462,6 @@ impl Deref for ReadGuard { } } -#[cfg(feature = "conduit_bin")] #[rocket::async_trait] impl<'r> FromRequest<'r> for ReadGuard { type Error = (); @@ -473,3 +472,9 @@ impl<'r> FromRequest<'r> for ReadGuard { Ok(ReadGuard(Arc::clone(&db).read_owned().await)).or_forward(()) } } + +impl Into for OwnedRwLockReadGuard { + fn into(self) -> ReadGuard { + ReadGuard(self) + } +} diff --git a/src/database/admin.rs b/src/database/admin.rs index fcc13c3..cd5fa84 100644 --- a/src/database/admin.rs +++ b/src/database/admin.rs @@ -10,7 +10,7 @@ use ruma::{ events::{room::message, EventType}, UserId, }; -use tokio::sync::RwLock; +use tokio::sync::{RwLock, RwLockReadGuard}; pub enum AdminCommand { RegisterAppservice(serde_yaml::Value), @@ -54,47 +54,52 @@ impl Admin { drop(guard); - let send_message = |message: message::MessageEventContent| {}; + let send_message = + |message: message::MessageEventContent, guard: RwLockReadGuard<'_, Database>| { + if let Some(conduit_room) = &conduit_room { + guard + .rooms + .build_and_append_pdu( + PduBuilder { + event_type: EventType::RoomMessage, + content: serde_json::to_value(message) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: None, + redacts: None, + }, + &conduit_user, + &conduit_room, + &guard, + ) + .unwrap(); + } + }; loop { tokio::select! { Some(event) = receiver.next() => { + let guard = db.read().await; + match event { AdminCommand::RegisterAppservice(yaml) => { - db.read().await.appservice.register_appservice(yaml).unwrap(); // TODO handle error + guard.appservice.register_appservice(yaml).unwrap(); // TODO handle error } AdminCommand::ListAppservices => { - if let Ok(appservices) = db.read().await.appservice.iter_ids().map(|ids| ids.collect::>()) { + if let Ok(appservices) = guard.appservice.iter_ids().map(|ids| ids.collect::>()) { let count = appservices.len(); let output = format!( "Appservices ({}): {}", count, appservices.into_iter().filter_map(|r| r.ok()).collect::>().join(", ") ); - send_message(message::MessageEventContent::text_plain(output)); + send_message(message::MessageEventContent::text_plain(output), guard); } else { - send_message(message::MessageEventContent::text_plain("Failed to get appservices.")); + send_message(message::MessageEventContent::text_plain("Failed to get appservices."), guard); } } AdminCommand::SendMessage(message) => { - if let Some(conduit_room) = &conduit_room { - let guard = db.read().await; - guard.rooms - .build_and_append_pdu( - PduBuilder { - event_type: EventType::RoomMessage, - content: serde_json::to_value(message) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: None, - redacts: None, - }, - &conduit_user, - &conduit_room, - &guard, - ) - .unwrap(); - } + send_message(message, guard) } } } diff --git a/src/lib.rs b/src/lib.rs index c1f0c4b..50ca6ea 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,8 +15,7 @@ pub use error::{Error, Result}; pub use pdu::PduEvent; pub use rocket::Config; pub use ruma_wrapper::{ConduitResult, Ruma, RumaResponse}; -use std::{ops::Deref, sync::Arc}; -use tokio::sync::RwLock; +use std::ops::Deref; pub struct State<'r, T: Send + Sync + 'static>(pub &'r T); diff --git a/src/main.rs b/src/main.rs index 9968b87..5598728 100644 --- a/src/main.rs +++ b/src/main.rs @@ -231,8 +231,11 @@ async fn main() { }; log::warn!("wal-trunc: locked, flushing..."); let start = Instant::now(); - guard.flush_wal(); - log::warn!("wal-trunc: flushed in {:?}", start.elapsed()); + if let Err(e) = guard.flush_wal() { + log::warn!("wal-trunc: errored: {}", e); + } else { + log::warn!("wal-trunc: flushed in {:?}", start.elapsed()); + } } else { break; } diff --git a/src/ruma_wrapper.rs b/src/ruma_wrapper.rs index 647d25e..d06f224 100644 --- a/src/ruma_wrapper.rs +++ b/src/ruma_wrapper.rs @@ -9,7 +9,7 @@ use std::ops::Deref; #[cfg(feature = "conduit_bin")] use { - crate::{server_server, Database}, + crate::server_server, log::{debug, warn}, rocket::{ data::{self, ByteUnit, Data, FromData}, @@ -17,13 +17,12 @@ use { outcome::Outcome::*, response::{self, Responder}, tokio::io::AsyncReadExt, - Request, State, + Request, }, ruma::api::{AuthScheme, IncomingRequest}, std::collections::BTreeMap, std::convert::TryFrom, std::io::Cursor, - std::sync::Arc, }; /// This struct converts rocket requests into ruma structs by converting them into http requests diff --git a/src/server_server.rs b/src/server_server.rs index b3e8f1b..306afd2 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -7,7 +7,7 @@ use get_profile_information::v1::ProfileField; use http::header::{HeaderValue, AUTHORIZATION, HOST}; use log::{debug, error, info, trace, warn}; use regex::Regex; -use rocket::{response::content::Json, State}; +use rocket::response::content::Json; use ruma::{ api::{ client::error::{Error as RumaError, ErrorKind}, From f81018ab2d7e514a861ac7faa230254b5900d7ce Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Sun, 4 Jul 2021 13:30:47 +0200 Subject: [PATCH 08/28] reverse iterator funk --- src/database/media.rs | 10 ++-------- src/database/rooms.rs | 2 -- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/src/database/media.rs b/src/database/media.rs index 404a6c0..a1fe26e 100644 --- a/src/database/media.rs +++ b/src/database/media.rs @@ -189,10 +189,7 @@ impl Media { original_prefix.extend_from_slice(&0_u32.to_be_bytes()); // Height = 0 if it's not a thumbnail original_prefix.push(0xff); - if let Some((key, _)) = { - /* scoped to explicitly drop iterator */ - self.mediaid_file.scan_prefix(thumbnail_prefix).next() - } { + if let Some((key, _)) = self.mediaid_file.scan_prefix(thumbnail_prefix).next() { // Using saved thumbnail let path = globals.get_media_file(&key); let mut file = Vec::new(); @@ -227,10 +224,7 @@ impl Media { content_type, file: file.to_vec(), })) - } else if let Some((key, _)) = { - /* scoped to explicitly drop iterator */ - self.mediaid_file.scan_prefix(original_prefix).next() - } { + } else if let Some((key, _)) = self.mediaid_file.scan_prefix(original_prefix).next() { // Generate a thumbnail let path = globals.get_media_file(&key); let mut file = Vec::new(); diff --git a/src/database/rooms.rs b/src/database/rooms.rs index 23cd570..7b64c46 100644 --- a/src/database/rooms.rs +++ b/src/database/rooms.rs @@ -733,8 +733,6 @@ impl Rooms { .filter(|user_id| user_id.server_name() == db.globals.server_name()) .filter(|user_id| !db.users.is_deactivated(user_id).unwrap_or(false)) .filter(|user_id| self.is_joined(&user_id, &pdu.room_id).unwrap_or(false)) - .collect::>() - /* to consume iterator */ { // Don't notify the user of their own events if user == pdu.sender { From dc5f1f41fd4751b3de4f13a44c74ce090d0abc8e Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Wed, 7 Jul 2021 14:04:11 +0200 Subject: [PATCH 09/28] some more fixes to allow sled to work --- src/database.rs | 1 + src/database/abstraction/sled.rs | 8 ++++++-- src/main.rs | 18 ++++++++++-------- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/src/database.rs b/src/database.rs index 9986651..bd58299 100644 --- a/src/database.rs +++ b/src/database.rs @@ -447,6 +447,7 @@ impl Database { res } + #[cfg(feature = "sqlite")] pub fn flush_wal(&self) -> Result<()> { self._db.flush_wal() } diff --git a/src/database/abstraction/sled.rs b/src/database/abstraction/sled.rs index 2f3fb34..557e8a0 100644 --- a/src/database/abstraction/sled.rs +++ b/src/database/abstraction/sled.rs @@ -23,6 +23,10 @@ impl DatabaseEngine for SledEngine { fn open_tree(self: &Arc, name: &'static str) -> Result> { Ok(Arc::new(SledEngineTree(self.0.open_tree(name)?))) } + + fn flush(self: &Arc) -> Result<()> { + Ok(()) // noop + } } impl Tree for SledEngineTree { @@ -40,7 +44,7 @@ impl Tree for SledEngineTree { Ok(()) } - fn iter<'a>(&'a self) -> Box, Vec)> + Send + Sync + 'a> { + fn iter<'a>(&'a self) -> Box, Vec)> + Send + 'a> { Box::new( self.0 .iter() @@ -58,7 +62,7 @@ impl Tree for SledEngineTree { &self, from: &[u8], backwards: bool, - ) -> Box, Vec)>> { + ) -> Box, Vec)> + Send> { let iter = if backwards { self.0.range(..from) } else { diff --git a/src/main.rs b/src/main.rs index 5598728..5c5ea84 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,10 +12,7 @@ mod pdu; mod ruma_wrapper; mod utils; -use std::{ - sync::{Arc, Weak}, - time::{Duration, Instant}, -}; +use std::sync::Arc; use database::Config; pub use database::Database; @@ -33,10 +30,7 @@ use rocket::{ }, routes, Request, }; -use tokio::{ - sync::RwLock, - time::{interval, timeout}, -}; +use tokio::sync::RwLock; use tracing::span; use tracing_subscriber::{prelude::*, Registry}; @@ -208,7 +202,15 @@ async fn main() { .await .expect("config is valid"); + #[cfg(feature = "sqlite")] { + use tokio::time::{interval, timeout}; + + use std::{ + sync::Weak, + time::{Duration, Instant}, + }; + let weak: Weak> = Arc::downgrade(&db); tokio::spawn(async { From 0c23874194101f2ce2e812fb258025ce66e3e124 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Wed, 7 Jul 2021 20:36:41 +0200 Subject: [PATCH 10/28] add config and optimise --- Cargo.lock | 1 + Cargo.toml | 4 +- src/database.rs | 81 ++++++++++++++++++++++++++++++ src/database/abstraction/sqlite.rs | 34 +++++++++---- src/main.rs | 42 +--------------- 5 files changed, 109 insertions(+), 53 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4458d71..0d73e5e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1178,6 +1178,7 @@ version = "0.22.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "290b64917f8b0cb885d9de0f9959fe1f775d7fa12f1da2db9001c1c8ab60f89d" dependencies = [ + "cc", "pkg-config", "vcpkg", ] diff --git a/Cargo.toml b/Cargo.toml index e7bb3b8..7edf641 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -74,7 +74,7 @@ tracing-opentelemetry = "0.11.0" opentelemetry-jaeger = "0.11.0" pretty_env_logger = "0.4.0" lru-cache = "0.1.2" -rusqlite = { version = "0.25.3", optional = true } +rusqlite = { version = "0.25.3", optional = true, features = ["bundled"] } parking_lot = { version = "0.11.1", optional = true } crossbeam = { version = "0.8.1", optional = true } num_cpus = { version = "1.13.0", optional = true } @@ -84,7 +84,7 @@ default = ["conduit_bin", "backend_sqlite"] backend_sled = ["sled"] backend_rocksdb = ["rocksdb"] backend_sqlite = ["sqlite"] -sqlite = ["rusqlite", "parking_lot", "crossbeam", "num_cpus"] +sqlite = ["rusqlite", "parking_lot", "crossbeam", "num_cpus", "tokio/signal"] conduit_bin = [] # TODO: add rocket to this when it is optional [[bin]] diff --git a/src/database.rs b/src/database.rs index bd58299..e5bb865 100644 --- a/src/database.rs +++ b/src/database.rs @@ -44,6 +44,16 @@ pub struct Config { database_path: String, #[serde(default = "default_cache_capacity")] cache_capacity: u32, + #[serde(default = "default_sqlite_cache_kib")] + sqlite_cache_kib: u32, + #[serde(default = "default_sqlite_read_pool_size")] + sqlite_read_pool_size: usize, + #[serde(default = "false_fn")] + sqlite_wal_clean_timer: bool, + #[serde(default = "default_sqlite_wal_clean_second_interval")] + sqlite_wal_clean_second_interval: u32, + #[serde(default = "default_sqlite_wal_clean_second_timeout")] + sqlite_wal_clean_second_timeout: u32, #[serde(default = "default_max_request_size")] max_request_size: u32, #[serde(default = "default_max_concurrent_requests")] @@ -77,6 +87,22 @@ fn default_cache_capacity() -> u32 { 1024 * 1024 * 1024 } +fn default_sqlite_cache_kib() -> u32 { + 2000 +} + +fn default_sqlite_read_pool_size() -> usize { + num_cpus::get().max(1) +} + +fn default_sqlite_wal_clean_second_interval() -> u32 { + 60 +} + +fn default_sqlite_wal_clean_second_timeout() -> u32 { + 2 +} + fn default_max_request_size() -> u32 { 20 * 1024 * 1024 // Default to 20 MB } @@ -451,6 +477,61 @@ impl Database { pub fn flush_wal(&self) -> Result<()> { self._db.flush_wal() } + + #[cfg(feature = "sqlite")] + pub async fn start_wal_clean_task(lock: &Arc>, config: &Config) { + use tokio::{ + signal::unix::{signal, SignalKind}, + time::{interval, timeout}, + }; + + use std::{ + sync::Weak, + time::{Duration, Instant}, + }; + + let weak: Weak> = Arc::downgrade(&lock); + + let lock_timeout = Duration::from_secs(config.sqlite_wal_clean_second_timeout as u64); + let timer_interval = Duration::from_secs(config.sqlite_wal_clean_second_interval as u64); + let do_timer = config.sqlite_wal_clean_timer; + + tokio::spawn(async move { + let mut i = interval(timer_interval); + let mut s = signal(SignalKind::hangup()).unwrap(); + + loop { + if do_timer { + i.tick().await; + log::info!(target: "wal-trunc", "Timer ticked") + } else { + s.recv().await; + log::info!(target: "wal-trunc", "Received SIGHUP") + } + + if let Some(arc) = Weak::upgrade(&weak) { + log::info!(target: "wal-trunc", "Locking..."); + let guard = { + if let Ok(guard) = timeout(lock_timeout, arc.write()).await { + guard + } else { + log::info!(target: "wal-trunc", "Lock failed in timeout, canceled."); + continue; + } + }; + log::info!(target: "wal-trunc", "Locked, flushing..."); + let start = Instant::now(); + if let Err(e) = guard.flush_wal() { + log::error!(target: "wal-trunc", "Errored: {}", e); + } else { + log::info!(target: "wal-trunc", "Flushed in {:?}", start.elapsed()); + } + } else { + break; + } + } + }); + } } pub struct ReadGuard(OwnedRwLockReadGuard); diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index 164d985..7e5490c 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -36,6 +36,7 @@ use tokio::sync::oneshot::Sender; struct Pool { writer: Mutex, readers: Vec>, + spill_tracker: Arc<()>, path: PathBuf, } @@ -43,7 +44,7 @@ pub const MILLI: Duration = Duration::from_millis(1); enum HoldingConn<'a> { FromGuard(MutexGuard<'a, Connection>), - FromOwned(Connection), + FromOwned(Connection, Arc<()>), } impl<'a> Deref for HoldingConn<'a> { @@ -52,29 +53,30 @@ impl<'a> Deref for HoldingConn<'a> { fn deref(&self) -> &Self::Target { match self { HoldingConn::FromGuard(guard) => guard.deref(), - HoldingConn::FromOwned(conn) => conn, + HoldingConn::FromOwned(conn, _) => conn, } } } impl Pool { - fn new>(path: P, num_readers: usize) -> Result { - let writer = Mutex::new(Self::prepare_conn(&path)?); + fn new>(path: P, num_readers: usize, cache_size: u32) -> Result { + let writer = Mutex::new(Self::prepare_conn(&path, Some(cache_size))?); let mut readers = Vec::new(); for _ in 0..num_readers { - readers.push(Mutex::new(Self::prepare_conn(&path)?)) + readers.push(Mutex::new(Self::prepare_conn(&path, Some(cache_size))?)) } Ok(Self { writer, readers, + spill_tracker: Arc::new(()), path: path.as_ref().to_path_buf(), }) } - fn prepare_conn>(path: P) -> Result { + fn prepare_conn>(path: P, cache_size: Option) -> Result { let conn = Connection::open(path)?; conn.pragma_update(Some(Main), "journal_mode", &"WAL".to_owned())?; @@ -85,6 +87,10 @@ impl Pool { conn.pragma_update(Some(Main), "synchronous", &"OFF".to_owned())?; + if let Some(cache_kib) = cache_size { + conn.pragma_update(Some(Main), "cache_size", &(-Into::::into(cache_kib)))?; + } + Ok(conn) } @@ -99,11 +105,18 @@ impl Pool { } } - log::warn!("all readers locked, creating spillover reader..."); + let spill_arc = self.spill_tracker.clone(); + let now_count = Arc::strong_count(&spill_arc) - 1 /* because one is held by the pool */; + + log::warn!("read_lock: all readers locked, creating spillover reader..."); + + if now_count > 1 { + log::warn!("read_lock: now {} spillover readers exist", now_count); + } - let spilled = Self::prepare_conn(&self.path).unwrap(); + let spilled = Self::prepare_conn(&self.path, None).unwrap(); - return HoldingConn::FromOwned(spilled); + return HoldingConn::FromOwned(spilled, spill_arc); } } @@ -115,7 +128,8 @@ impl DatabaseEngine for SqliteEngine { fn open(config: &Config) -> Result> { let pool = Pool::new( format!("{}/conduit.db", &config.database_path), - num_cpus::get(), + config.sqlite_read_pool_size, + config.sqlite_cache_kib, )?; pool.write_lock() diff --git a/src/main.rs b/src/main.rs index 5c5ea84..22c44b5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -203,47 +203,7 @@ async fn main() { .expect("config is valid"); #[cfg(feature = "sqlite")] - { - use tokio::time::{interval, timeout}; - - use std::{ - sync::Weak, - time::{Duration, Instant}, - }; - - let weak: Weak> = Arc::downgrade(&db); - - tokio::spawn(async { - let weak = weak; - - let mut i = interval(Duration::from_secs(60)); - - loop { - i.tick().await; - - if let Some(arc) = Weak::upgrade(&weak) { - log::warn!("wal-trunc: locking..."); - let guard = { - if let Ok(guard) = timeout(Duration::from_secs(5), arc.write()).await { - guard - } else { - log::warn!("wal-trunc: lock failed in timeout, canceled."); - continue; - } - }; - log::warn!("wal-trunc: locked, flushing..."); - let start = Instant::now(); - if let Err(e) = guard.flush_wal() { - log::warn!("wal-trunc: errored: {}", e); - } else { - log::warn!("wal-trunc: flushed in {:?}", start.elapsed()); - } - } else { - break; - } - } - }); - } + Database::start_wal_clean_task(&db, &config).await; if config.allow_jaeger { let (tracer, _uninstall) = opentelemetry_jaeger::new_pipeline() From 494585267ad6f8751c9b951846b0904bb8320f95 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Wed, 7 Jul 2021 20:43:30 +0200 Subject: [PATCH 11/28] remove rjbench --- rjbench_testing/docker-compose.yml | 32 ------------------------------ 1 file changed, 32 deletions(-) delete mode 100644 rjbench_testing/docker-compose.yml diff --git a/rjbench_testing/docker-compose.yml b/rjbench_testing/docker-compose.yml deleted file mode 100644 index 44da5b5..0000000 --- a/rjbench_testing/docker-compose.yml +++ /dev/null @@ -1,32 +0,0 @@ -# Conduit -version: '3' - -services: - homeserver: - image: ubuntu:21.04 - restart: unless-stopped - working_dir: "/srv/conduit" - entrypoint: /srv/conduit/conduit - ports: - - 8448:8000 - volumes: - - ../target/db:/srv/conduit/.local/share/conduit - - ../target/debug/conduit:/srv/conduit/conduit - - ./conduit.toml:/srv/conduit/conduit.toml:ro - environment: - # CONDUIT_SERVER_NAME: localhost:8000 # replace with your own name - # CONDUIT_TRUSTED_SERVERS: '["matrix.org"]' - ### Uncomment and change values as desired - # CONDUIT_ADDRESS: 127.0.0.1 - # CONDUIT_PORT: 8000 - CONDUIT_CONFIG: '/srv/conduit/conduit.toml' - # Available levels are: error, warn, info, debug, trace - more info at: https://docs.rs/env_logger/*/env_logger/#enabling-logging - # CONDUIT_LOG: debug # default is: "info,rocket=off,_=off,sled=off" - # CONDUIT_ALLOW_JAEGER: 'false' - # CONDUIT_ALLOW_REGISTRATION : 'false' - # CONDUIT_ALLOW_ENCRYPTION: 'false' - # CONDUIT_ALLOW_FEDERATION: 'false' - # CONDUIT_DATABASE_PATH: /srv/conduit/.local/share/conduit - # CONDUIT_WORKERS: 10 - # CONDUIT_MAX_REQUEST_SIZE: 20_000_000 # in bytes, ~20 MB - From 0719377c6a0ab67da9449f164234586f923530dd Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Thu, 8 Jul 2021 12:13:39 +0200 Subject: [PATCH 12/28] merge one more {use} --- src/database/abstraction/sqlite.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index 7e5490c..54b2ffd 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -1,4 +1,5 @@ use std::{ + collections::BTreeMap, future::Future, ops::Deref, path::{Path, PathBuf}, @@ -12,8 +13,6 @@ use crate::{database::Config, Result}; use super::{DatabaseEngine, Tree}; -use std::collections::BTreeMap; - use log::debug; use crossbeam::channel::{bounded, Sender as ChannelSender}; From 7c82213ee7addd508b6593b7dff9e14965d60c20 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Thu, 8 Jul 2021 14:27:28 +0200 Subject: [PATCH 13/28] change to use path joining properly --- src/database/abstraction/sqlite.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index 54b2ffd..ff06fab 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -126,7 +126,7 @@ pub struct SqliteEngine { impl DatabaseEngine for SqliteEngine { fn open(config: &Config) -> Result> { let pool = Pool::new( - format!("{}/conduit.db", &config.database_path), + Path::new(&config.database_path).join("conduit.db"), config.sqlite_read_pool_size, config.sqlite_cache_kib, )?; From f4aabbdaa7002256c2ebc94b9d1ecaf261f24be5 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Thu, 8 Jul 2021 17:17:42 +0200 Subject: [PATCH 14/28] add some flushes --- src/client_server/membership.rs | 8 ++++++-- src/server_server.rs | 10 +++++++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs index 8743940..ae08af0 100644 --- a/src/client_server/membership.rs +++ b/src/client_server/membership.rs @@ -65,14 +65,18 @@ pub async fn join_room_by_id_route( servers.insert(body.room_id.server_name().to_owned()); - join_room_by_id_helper( + let ret = join_room_by_id_helper( &db, body.sender_user.as_ref(), &body.room_id, &servers, body.third_party_signed.as_ref(), ) - .await + .await; + + db.flush().await?; + + ret } #[cfg_attr( diff --git a/src/server_server.rs b/src/server_server.rs index 306afd2..bbb2f69 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -774,6 +774,8 @@ pub async fn send_transaction_message_route( } } + db.flush().await?; + Ok(send_transaction_message::v1::Response { pdus: resolved_map }.into()) } @@ -2160,6 +2162,8 @@ pub async fn create_join_event_route( db.sending.send_pdu(&server, &pdu_id)?; } + db.flush().await?; + Ok(create_join_event::v2::Response { room_state: RoomState { auth_chain: auth_chain_ids @@ -2276,6 +2280,8 @@ pub async fn create_invite_route( )?; } + db.flush().await?; + Ok(create_invite::v2::Response { event: PduEvent::convert_to_outgoing_federation_event(signed_event), } @@ -2389,7 +2395,7 @@ pub fn get_profile_information_route( post("/_matrix/federation/v1/user/keys/query", data = "") )] #[tracing::instrument(skip(db, body))] -pub fn get_keys_route( +pub async fn get_keys_route( db: ReadGuard, body: Ruma, ) -> ConduitResult { @@ -2404,6 +2410,8 @@ pub fn get_keys_route( &db, )?; + db.flush().await?; + Ok(get_keys::v1::Response { device_keys: result.device_keys, master_keys: result.master_keys, From 318d9c1a358197b2e1ac256eea1e54fd37179ee4 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Fri, 9 Jul 2021 14:43:04 +0200 Subject: [PATCH 15/28] revert docker-compose.yml file --- docker-compose.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docker-compose.yml b/docker-compose.yml index fe13fdc..a9a820f 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -14,13 +14,13 @@ services: args: CREATED: '2021-03-16T08:18:27Z' VERSION: '0.1.0' - LOCAL: 'true' + LOCAL: 'false' GIT_REF: origin/master restart: unless-stopped ports: - 8448:8000 volumes: - - ./target/db:/srv/conduit/.local/share/conduit + - db:/srv/conduit/.local/share/conduit ### Uncomment if you want to use conduit.toml to configure Conduit ### Note: Set env vars will override conduit.toml values # - ./conduit.toml:/srv/conduit/conduit.toml @@ -55,5 +55,5 @@ services: # depends_on: # - homeserver -# volumes: -# db: +volumes: + db: From 3a76fda92b8694c12d2489f35eaa92213725035c Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Sun, 11 Jul 2021 15:41:10 +0200 Subject: [PATCH 16/28] incorperate feedback --- Cargo.toml | 1 - src/database.rs | 39 ++++++++++++++--------------- src/database/abstraction/rocksdb.rs | 8 +++--- src/database/abstraction/sled.rs | 6 ++--- src/database/abstraction/sqlite.rs | 28 ++++++++++----------- src/main.rs | 3 --- 6 files changed, 40 insertions(+), 45 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7edf641..896140c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,7 +25,6 @@ tokio = "1.2.0" # Used for storing data permanently sled = { version = "0.34.6", features = ["compression", "no_metrics"], optional = true } rocksdb = { version = "0.16.0", features = ["multi-threaded-cf"], optional = true } -# sqlx = { version = "0.5.5", features = ["sqlite", "runtime-tokio-rustls"], optional = true } #sled = { git = "https://github.com/spacejam/sled.git", rev = "e4640e0773595229f398438886f19bca6f7326a2", features = ["compression"] } # Used for the http request / response body type for Ruma endpoints used with reqwest diff --git a/src/database.rs b/src/database.rs index e5bb865..7fcee02 100644 --- a/src/database.rs +++ b/src/database.rs @@ -42,10 +42,8 @@ use self::proxy::ProxyConfig; pub struct Config { server_name: Box, database_path: String, - #[serde(default = "default_cache_capacity")] - cache_capacity: u32, - #[serde(default = "default_sqlite_cache_kib")] - sqlite_cache_kib: u32, + #[serde(default = "default_db_cache_capacity")] + db_cache_capacity: u32, #[serde(default = "default_sqlite_read_pool_size")] sqlite_read_pool_size: usize, #[serde(default = "false_fn")] @@ -83,14 +81,10 @@ fn true_fn() -> bool { true } -fn default_cache_capacity() -> u32 { +fn default_db_cache_capacity() -> u32 { 1024 * 1024 * 1024 } -fn default_sqlite_cache_kib() -> u32 { - 2000 -} - fn default_sqlite_read_pool_size() -> usize { num_cpus::get().max(1) } @@ -116,13 +110,13 @@ fn default_log() -> String { } #[cfg(feature = "sled")] -pub type Engine = abstraction::sled::SledEngine; +pub type Engine = abstraction::sled::Engine; #[cfg(feature = "rocksdb")] -pub type Engine = abstraction::rocksdb::RocksDbEngine; +pub type Engine = abstraction::rocksdb::Engine; #[cfg(feature = "sqlite")] -pub type Engine = abstraction::sqlite::SqliteEngine; +pub type Engine = abstraction::sqlite::Engine; pub struct Database { _db: Arc, @@ -268,7 +262,7 @@ impl Database { globals: globals::Globals::load( builder.open_tree("global")?, builder.open_tree("server_signingkeys")?, - config, + config.clone(), )?, })); @@ -372,6 +366,9 @@ impl Database { drop(guard); + #[cfg(feature = "sqlite")] + Self::start_wal_clean_task(&db, &config).await; + Ok(db) } @@ -481,6 +478,7 @@ impl Database { #[cfg(feature = "sqlite")] pub async fn start_wal_clean_task(lock: &Arc>, config: &Config) { use tokio::{ + select, signal::unix::{signal, SignalKind}, time::{interval, timeout}, }; @@ -501,13 +499,14 @@ impl Database { let mut s = signal(SignalKind::hangup()).unwrap(); loop { - if do_timer { - i.tick().await; - log::info!(target: "wal-trunc", "Timer ticked") - } else { - s.recv().await; - log::info!(target: "wal-trunc", "Received SIGHUP") - } + select! { + _ = i.tick(), if do_timer => { + log::info!(target: "wal-trunc", "Timer ticked") + } + _ = s.recv() => { + log::info!(target: "wal-trunc", "Received SIGHUP") + } + }; if let Some(arc) = Weak::upgrade(&weak) { log::info!(target: "wal-trunc", "Locking..."); diff --git a/src/database/abstraction/rocksdb.rs b/src/database/abstraction/rocksdb.rs index 88b6297..b996130 100644 --- a/src/database/abstraction/rocksdb.rs +++ b/src/database/abstraction/rocksdb.rs @@ -7,15 +7,15 @@ use super::{DatabaseEngine, Tree}; use std::{collections::BTreeMap, sync::RwLock}; -pub struct RocksDbEngine(rocksdb::DBWithThreadMode); +pub struct Engine(rocksdb::DBWithThreadMode); pub struct RocksDbEngineTree<'a> { - db: Arc, + db: Arc, name: &'a str, watchers: RwLock, Vec>>>, } -impl DatabaseEngine for RocksDbEngine { +impl DatabaseEngine for Engine { fn open(config: &Config) -> Result> { let mut db_opts = rocksdb::Options::default(); db_opts.create_if_missing(true); @@ -45,7 +45,7 @@ impl DatabaseEngine for RocksDbEngine { .map(|name| rocksdb::ColumnFamilyDescriptor::new(name, options.clone())), )?; - Ok(Arc::new(RocksDbEngine(db))) + Ok(Arc::new(Engine(db))) } fn open_tree(self: &Arc, name: &'static str) -> Result> { diff --git a/src/database/abstraction/sled.rs b/src/database/abstraction/sled.rs index 557e8a0..8c7f80d 100644 --- a/src/database/abstraction/sled.rs +++ b/src/database/abstraction/sled.rs @@ -5,13 +5,13 @@ use std::{future::Future, pin::Pin, sync::Arc}; use super::{DatabaseEngine, Tree}; -pub struct SledEngine(sled::Db); +pub struct Engine(sled::Db); pub struct SledEngineTree(sled::Tree); -impl DatabaseEngine for SledEngine { +impl DatabaseEngine for Engine { fn open(config: &Config) -> Result> { - Ok(Arc::new(SledEngine( + Ok(Arc::new(Engine( sled::Config::default() .path(&config.database_path) .cache_capacity(config.cache_capacity as u64) diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index ff06fab..fe54813 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -119,22 +119,22 @@ impl Pool { } } -pub struct SqliteEngine { +pub struct Engine { pool: Pool, } -impl DatabaseEngine for SqliteEngine { +impl DatabaseEngine for Engine { fn open(config: &Config) -> Result> { let pool = Pool::new( Path::new(&config.database_path).join("conduit.db"), config.sqlite_read_pool_size, - config.sqlite_cache_kib, + config.db_cache_capacity / 1024, // bytes -> kb )?; pool.write_lock() .execute("CREATE TABLE IF NOT EXISTS _noop (\"key\" INT)", params![])?; - let arc = Arc::new(SqliteEngine { pool }); + let arc = Arc::new(Engine { pool }); Ok(arc) } @@ -166,7 +166,7 @@ impl DatabaseEngine for SqliteEngine { } } -impl SqliteEngine { +impl Engine { pub fn flush_wal(self: &Arc) -> Result<()> { self.pool .write_lock() @@ -185,7 +185,7 @@ impl SqliteEngine { } pub struct SqliteTable { - engine: Arc, + engine: Arc, name: String, watchers: RwLock, Vec>>>, } @@ -257,19 +257,19 @@ impl Tree for SqliteTable { } fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { - { - let guard = self.engine.pool.write_lock(); + let guard = self.engine.pool.write_lock(); - let start = Instant::now(); + let start = Instant::now(); - self.insert_with_guard(&guard, key, value)?; + self.insert_with_guard(&guard, key, value)?; - let elapsed = start.elapsed(); - if elapsed > MILLI { - debug!("insert: took {:012?} : {}", elapsed, &self.name); - } + let elapsed = start.elapsed(); + if elapsed > MILLI { + debug!("insert: took {:012?} : {}", elapsed, &self.name); } + drop(guard); + let watchers = self.watchers.read(); let mut triggered = Vec::new(); diff --git a/src/main.rs b/src/main.rs index 22c44b5..034c39e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -202,9 +202,6 @@ async fn main() { .await .expect("config is valid"); - #[cfg(feature = "sqlite")] - Database::start_wal_clean_task(&db, &config).await; - if config.allow_jaeger { let (tracer, _uninstall) = opentelemetry_jaeger::new_pipeline() .with_service_name("conduit") From bcfea98457d22fd58423bb336455713e453aa3e5 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Sun, 11 Jul 2021 15:42:06 +0200 Subject: [PATCH 17/28] replace ReadGuard with DatabaseGuard --- src/client_server/account.rs | 10 +++---- src/client_server/alias.rs | 8 +++--- src/client_server/backup.rs | 30 ++++++++++----------- src/client_server/config.rs | 10 +++---- src/client_server/context.rs | 4 +-- src/client_server/device.rs | 12 ++++----- src/client_server/directory.rs | 10 +++---- src/client_server/keys.rs | 14 +++++----- src/client_server/media.rs | 14 ++++++---- src/client_server/membership.rs | 24 ++++++++--------- src/client_server/message.rs | 6 ++--- src/client_server/presence.rs | 6 ++--- src/client_server/profile.rs | 12 ++++----- src/client_server/push.rs | 22 +++++++-------- src/client_server/read_marker.rs | 6 ++--- src/client_server/redact.rs | 4 +-- src/client_server/room.rs | 9 ++++--- src/client_server/search.rs | 4 +-- src/client_server/session.rs | 8 +++--- src/client_server/state.rs | 14 +++++----- src/client_server/sync.rs | 8 +++--- src/client_server/tag.rs | 8 +++--- src/client_server/to_device.rs | 4 +-- src/client_server/typing.rs | 4 +-- src/client_server/user_directory.rs | 4 +-- src/database.rs | 14 +++++----- src/ruma_wrapper.rs | 4 +-- src/server_server.rs | 42 +++++++++++++++-------------- 28 files changed, 162 insertions(+), 153 deletions(-) diff --git a/src/client_server/account.rs b/src/client_server/account.rs index d76816b..e5268ec 100644 --- a/src/client_server/account.rs +++ b/src/client_server/account.rs @@ -1,7 +1,7 @@ use std::{collections::BTreeMap, convert::TryInto}; use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; -use crate::{database::ReadGuard, pdu::PduBuilder, utils, ConduitResult, Error, Ruma}; +use crate::{database::DatabaseGuard, pdu::PduBuilder, utils, ConduitResult, Error, Ruma}; use log::info; use ruma::{ api::client::{ @@ -42,7 +42,7 @@ const GUEST_NAME_LENGTH: usize = 10; )] #[tracing::instrument(skip(db, body))] pub async fn get_register_available_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { // Validate user id @@ -85,7 +85,7 @@ pub async fn get_register_available_route( )] #[tracing::instrument(skip(db, body))] pub async fn register_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_registration() && !body.from_appservice { @@ -496,7 +496,7 @@ pub async fn register_route( )] #[tracing::instrument(skip(db, body))] pub async fn change_password_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -588,7 +588,7 @@ pub async fn whoami_route(body: Ruma) -> ConduitResult>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/alias.rs b/src/client_server/alias.rs index 440b505..f5d9f64 100644 --- a/src/client_server/alias.rs +++ b/src/client_server/alias.rs @@ -1,4 +1,4 @@ -use crate::{database::ReadGuard, ConduitResult, Database, Error, Ruma}; +use crate::{database::DatabaseGuard, ConduitResult, Database, Error, Ruma}; use regex::Regex; use ruma::{ api::{ @@ -21,7 +21,7 @@ use rocket::{delete, get, put}; )] #[tracing::instrument(skip(db, body))] pub async fn create_alias_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if db.rooms.id_from_alias(&body.room_alias)?.is_some() { @@ -42,7 +42,7 @@ pub async fn create_alias_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_alias_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { db.rooms.set_alias(&body.room_alias, None, &db.globals)?; @@ -58,7 +58,7 @@ pub async fn delete_alias_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_alias_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { get_alias_helper(&db, &body.room_alias).await diff --git a/src/client_server/backup.rs b/src/client_server/backup.rs index 8412778..ccb17fa 100644 --- a/src/client_server/backup.rs +++ b/src/client_server/backup.rs @@ -1,4 +1,4 @@ -use crate::{database::ReadGuard, ConduitResult, Error, Ruma}; +use crate::{database::DatabaseGuard, ConduitResult, Error, Ruma}; use ruma::api::client::{ error::ErrorKind, r0::backup::{ @@ -18,7 +18,7 @@ use rocket::{delete, get, post, put}; )] #[tracing::instrument(skip(db, body))] pub async fn create_backup_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -37,7 +37,7 @@ pub async fn create_backup_route( )] #[tracing::instrument(skip(db, body))] pub async fn update_backup_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -55,7 +55,7 @@ pub async fn update_backup_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_latest_backup_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -83,7 +83,7 @@ pub async fn get_latest_backup_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_backup_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -110,7 +110,7 @@ pub async fn get_backup_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_backup_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -129,7 +129,7 @@ pub async fn delete_backup_route( )] #[tracing::instrument(skip(db, body))] pub async fn add_backup_keys_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -163,7 +163,7 @@ pub async fn add_backup_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn add_backup_key_sessions_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -195,7 +195,7 @@ pub async fn add_backup_key_sessions_route( )] #[tracing::instrument(skip(db, body))] pub async fn add_backup_key_session_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -224,7 +224,7 @@ pub async fn add_backup_key_session_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_backup_keys_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -240,7 +240,7 @@ pub async fn get_backup_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_backup_key_sessions_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -258,7 +258,7 @@ pub async fn get_backup_key_sessions_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_backup_key_session_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -280,7 +280,7 @@ pub async fn get_backup_key_session_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_backup_keys_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -303,7 +303,7 @@ pub async fn delete_backup_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_backup_key_sessions_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -326,7 +326,7 @@ pub async fn delete_backup_key_sessions_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_backup_key_session_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/config.rs b/src/client_server/config.rs index da504e1..4f33689 100644 --- a/src/client_server/config.rs +++ b/src/client_server/config.rs @@ -1,4 +1,4 @@ -use crate::{database::ReadGuard, ConduitResult, Error, Ruma}; +use crate::{database::DatabaseGuard, ConduitResult, Error, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -22,7 +22,7 @@ use rocket::{get, put}; )] #[tracing::instrument(skip(db, body))] pub async fn set_global_account_data_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -57,7 +57,7 @@ pub async fn set_global_account_data_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_room_account_data_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -89,7 +89,7 @@ pub async fn set_room_account_data_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_global_account_data_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -116,7 +116,7 @@ pub async fn get_global_account_data_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_room_account_data_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/context.rs b/src/client_server/context.rs index 7d857e1..dbc121e 100644 --- a/src/client_server/context.rs +++ b/src/client_server/context.rs @@ -1,4 +1,4 @@ -use crate::{database::ReadGuard, ConduitResult, Error, Ruma}; +use crate::{database::DatabaseGuard, ConduitResult, Error, Ruma}; use ruma::api::client::{error::ErrorKind, r0::context::get_context}; use std::convert::TryFrom; @@ -11,7 +11,7 @@ use rocket::get; )] #[tracing::instrument(skip(db, body))] pub async fn get_context_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/device.rs b/src/client_server/device.rs index 4d8d16c..46ee7d7 100644 --- a/src/client_server/device.rs +++ b/src/client_server/device.rs @@ -1,4 +1,4 @@ -use crate::{database::ReadGuard, utils, ConduitResult, Error, Ruma}; +use crate::{database::DatabaseGuard, utils, ConduitResult, Error, Ruma}; use ruma::api::client::{ error::ErrorKind, r0::{ @@ -17,7 +17,7 @@ use rocket::{delete, get, post, put}; )] #[tracing::instrument(skip(db, body))] pub async fn get_devices_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -37,7 +37,7 @@ pub async fn get_devices_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_device_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -56,7 +56,7 @@ pub async fn get_device_route( )] #[tracing::instrument(skip(db, body))] pub async fn update_device_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -82,7 +82,7 @@ pub async fn update_device_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_device_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -136,7 +136,7 @@ pub async fn delete_device_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_devices_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/directory.rs b/src/client_server/directory.rs index 5cfc458..4a440fd 100644 --- a/src/client_server/directory.rs +++ b/src/client_server/directory.rs @@ -1,4 +1,4 @@ -use crate::{database::ReadGuard, ConduitResult, Database, Error, Result, Ruma}; +use crate::{database::DatabaseGuard, ConduitResult, Database, Error, Result, Ruma}; use log::info; use ruma::{ api::{ @@ -32,7 +32,7 @@ use rocket::{get, post, put}; )] #[tracing::instrument(skip(db, body))] pub async fn get_public_rooms_filtered_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { get_public_rooms_filtered_helper( @@ -52,7 +52,7 @@ pub async fn get_public_rooms_filtered_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_public_rooms_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let response = get_public_rooms_filtered_helper( @@ -81,7 +81,7 @@ pub async fn get_public_rooms_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_room_visibility_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -111,7 +111,7 @@ pub async fn set_room_visibility_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_room_visibility_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { Ok(get_room_visibility::Response { diff --git a/src/client_server/keys.rs b/src/client_server/keys.rs index 39cf7be..6726ba5 100644 --- a/src/client_server/keys.rs +++ b/src/client_server/keys.rs @@ -1,5 +1,5 @@ use super::SESSION_ID_LENGTH; -use crate::{database::ReadGuard, utils, ConduitResult, Database, Error, Result, Ruma}; +use crate::{database::DatabaseGuard, utils, ConduitResult, Database, Error, Result, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -25,7 +25,7 @@ use rocket::{get, post}; )] #[tracing::instrument(skip(db, body))] pub async fn upload_keys_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -74,7 +74,7 @@ pub async fn upload_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_keys_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -95,7 +95,7 @@ pub async fn get_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn claim_keys_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { let response = claim_keys_helper(&body.one_time_keys, &db)?; @@ -111,7 +111,7 @@ pub async fn claim_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn upload_signing_keys_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -174,7 +174,7 @@ pub async fn upload_signing_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn upload_signatures_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -235,7 +235,7 @@ pub async fn upload_signatures_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_key_changes_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/media.rs b/src/client_server/media.rs index 588e5f7..eaaf939 100644 --- a/src/client_server/media.rs +++ b/src/client_server/media.rs @@ -1,4 +1,6 @@ -use crate::{database::media::FileMeta, database::ReadGuard, utils, ConduitResult, Error, Ruma}; +use crate::{ + database::media::FileMeta, database::DatabaseGuard, utils, ConduitResult, Error, Ruma, +}; use ruma::api::client::{ error::ErrorKind, r0::media::{create_content, get_content, get_content_thumbnail, get_media_config}, @@ -12,7 +14,9 @@ const MXC_LENGTH: usize = 32; #[cfg_attr(feature = "conduit_bin", get("/_matrix/media/r0/config"))] #[tracing::instrument(skip(db))] -pub async fn get_media_config_route(db: ReadGuard) -> ConduitResult { +pub async fn get_media_config_route( + db: DatabaseGuard, +) -> ConduitResult { Ok(get_media_config::Response { upload_size: db.globals.max_request_size().into(), } @@ -25,7 +29,7 @@ pub async fn get_media_config_route(db: ReadGuard) -> ConduitResult>, ) -> ConduitResult { let mxc = format!( @@ -63,7 +67,7 @@ pub async fn create_content_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_content_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); @@ -116,7 +120,7 @@ pub async fn get_content_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_content_thumbnail_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); diff --git a/src/client_server/membership.rs b/src/client_server/membership.rs index ae08af0..4667f25 100644 --- a/src/client_server/membership.rs +++ b/src/client_server/membership.rs @@ -1,6 +1,6 @@ use crate::{ client_server, - database::ReadGuard, + database::DatabaseGuard, pdu::{PduBuilder, PduEvent}, server_server, utils, ConduitResult, Database, Error, Result, Ruma, }; @@ -44,7 +44,7 @@ use rocket::{get, post}; )] #[tracing::instrument(skip(db, body))] pub async fn join_room_by_id_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -85,7 +85,7 @@ pub async fn join_room_by_id_route( )] #[tracing::instrument(skip(db, body))] pub async fn join_room_by_id_or_alias_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -139,7 +139,7 @@ pub async fn join_room_by_id_or_alias_route( )] #[tracing::instrument(skip(db, body))] pub async fn leave_room_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -157,7 +157,7 @@ pub async fn leave_room_route( )] #[tracing::instrument(skip(db, body))] pub async fn invite_user_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -177,7 +177,7 @@ pub async fn invite_user_route( )] #[tracing::instrument(skip(db, body))] pub async fn kick_user_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -227,7 +227,7 @@ pub async fn kick_user_route( )] #[tracing::instrument(skip(db, body))] pub async fn ban_user_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -285,7 +285,7 @@ pub async fn ban_user_route( )] #[tracing::instrument(skip(db, body))] pub async fn unban_user_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -334,7 +334,7 @@ pub async fn unban_user_route( )] #[tracing::instrument(skip(db, body))] pub async fn forget_room_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -352,7 +352,7 @@ pub async fn forget_room_route( )] #[tracing::instrument(skip(db, body))] pub async fn joined_rooms_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -373,7 +373,7 @@ pub async fn joined_rooms_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_member_events_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -403,7 +403,7 @@ pub async fn get_member_events_route( )] #[tracing::instrument(skip(db, body))] pub async fn joined_members_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/message.rs b/src/client_server/message.rs index d439535..7e898b1 100644 --- a/src/client_server/message.rs +++ b/src/client_server/message.rs @@ -1,4 +1,4 @@ -use crate::{database::ReadGuard, pdu::PduBuilder, utils, ConduitResult, Error, Ruma}; +use crate::{database::DatabaseGuard, pdu::PduBuilder, utils, ConduitResult, Error, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -21,7 +21,7 @@ use rocket::{get, put}; )] #[tracing::instrument(skip(db, body))] pub async fn send_message_event_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -84,7 +84,7 @@ pub async fn send_message_event_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_message_events_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/presence.rs b/src/client_server/presence.rs index c96e62b..bfe638f 100644 --- a/src/client_server/presence.rs +++ b/src/client_server/presence.rs @@ -1,4 +1,4 @@ -use crate::{database::ReadGuard, utils, ConduitResult, Ruma}; +use crate::{database::DatabaseGuard, utils, ConduitResult, Ruma}; use ruma::api::client::r0::presence::{get_presence, set_presence}; use std::{convert::TryInto, time::Duration}; @@ -11,7 +11,7 @@ use rocket::{get, put}; )] #[tracing::instrument(skip(db, body))] pub async fn set_presence_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -52,7 +52,7 @@ pub async fn set_presence_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_presence_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/profile.rs b/src/client_server/profile.rs index f516f91..5281a4a 100644 --- a/src/client_server/profile.rs +++ b/src/client_server/profile.rs @@ -1,4 +1,4 @@ -use crate::{database::ReadGuard, pdu::PduBuilder, utils, ConduitResult, Error, Ruma}; +use crate::{database::DatabaseGuard, pdu::PduBuilder, utils, ConduitResult, Error, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -20,7 +20,7 @@ use rocket::{get, put}; )] #[tracing::instrument(skip(db, body))] pub async fn set_displayname_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -107,7 +107,7 @@ pub async fn set_displayname_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_displayname_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { Ok(get_display_name::Response { @@ -122,7 +122,7 @@ pub async fn get_displayname_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_avatar_url_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -209,7 +209,7 @@ pub async fn set_avatar_url_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_avatar_url_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { Ok(get_avatar_url::Response { @@ -224,7 +224,7 @@ pub async fn get_avatar_url_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_profile_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.users.exists(&body.user_id)? { diff --git a/src/client_server/push.rs b/src/client_server/push.rs index e718287..794cbce 100644 --- a/src/client_server/push.rs +++ b/src/client_server/push.rs @@ -1,4 +1,4 @@ -use crate::{database::ReadGuard, ConduitResult, Error, Ruma}; +use crate::{database::DatabaseGuard, ConduitResult, Error, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -21,7 +21,7 @@ use rocket::{delete, get, post, put}; )] #[tracing::instrument(skip(db, body))] pub async fn get_pushrules_all_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -46,7 +46,7 @@ pub async fn get_pushrules_all_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_pushrule_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -100,7 +100,7 @@ pub async fn get_pushrule_route( )] #[tracing::instrument(skip(db, req))] pub async fn set_pushrule_route( - db: ReadGuard, + db: DatabaseGuard, req: Ruma>, ) -> ConduitResult { let sender_user = req.sender_user.as_ref().expect("user is authenticated"); @@ -203,7 +203,7 @@ pub async fn set_pushrule_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_pushrule_actions_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -262,7 +262,7 @@ pub async fn get_pushrule_actions_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_pushrule_actions_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -336,7 +336,7 @@ pub async fn set_pushrule_actions_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_pushrule_enabled_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -397,7 +397,7 @@ pub async fn get_pushrule_enabled_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_pushrule_enabled_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -476,7 +476,7 @@ pub async fn set_pushrule_enabled_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_pushrule_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -545,7 +545,7 @@ pub async fn delete_pushrule_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_pushers_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -562,7 +562,7 @@ pub async fn get_pushers_route( )] #[tracing::instrument(skip(db, body))] pub async fn set_pushers_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/read_marker.rs b/src/client_server/read_marker.rs index 64f8e26..fe49af9 100644 --- a/src/client_server/read_marker.rs +++ b/src/client_server/read_marker.rs @@ -1,4 +1,4 @@ -use crate::{database::ReadGuard, ConduitResult, Error, Ruma}; +use crate::{database::DatabaseGuard, ConduitResult, Error, Ruma}; use ruma::{ api::client::{ error::ErrorKind, @@ -19,7 +19,7 @@ use rocket::post; )] #[tracing::instrument(skip(db, body))] pub async fn set_read_marker_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -86,7 +86,7 @@ pub async fn set_read_marker_route( )] #[tracing::instrument(skip(db, body))] pub async fn create_receipt_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/redact.rs b/src/client_server/redact.rs index af1b242..3db2771 100644 --- a/src/client_server/redact.rs +++ b/src/client_server/redact.rs @@ -1,4 +1,4 @@ -use crate::{database::ReadGuard, pdu::PduBuilder, ConduitResult, Ruma}; +use crate::{database::DatabaseGuard, pdu::PduBuilder, ConduitResult, Ruma}; use ruma::{ api::client::r0::redact::redact_event, events::{room::redaction, EventType}, @@ -13,7 +13,7 @@ use rocket::put; )] #[tracing::instrument(skip(db, body))] pub async fn redact_event_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/room.rs b/src/client_server/room.rs index c01c5f8..43625fe 100644 --- a/src/client_server/room.rs +++ b/src/client_server/room.rs @@ -1,5 +1,6 @@ use crate::{ - client_server::invite_helper, database::ReadGuard, pdu::PduBuilder, ConduitResult, Error, Ruma, + client_server::invite_helper, database::DatabaseGuard, pdu::PduBuilder, ConduitResult, Error, + Ruma, }; use log::info; use ruma::{ @@ -25,7 +26,7 @@ use rocket::{get, post}; )] #[tracing::instrument(skip(db, body))] pub async fn create_room_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -295,7 +296,7 @@ pub async fn create_room_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_room_event_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -323,7 +324,7 @@ pub async fn get_room_event_route( )] #[tracing::instrument(skip(db, body))] pub async fn upgrade_room_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, _room_id: String, ) -> ConduitResult { diff --git a/src/client_server/search.rs b/src/client_server/search.rs index cfe1345..ec23dd4 100644 --- a/src/client_server/search.rs +++ b/src/client_server/search.rs @@ -1,4 +1,4 @@ -use crate::{database::ReadGuard, ConduitResult, Error, Ruma}; +use crate::{database::DatabaseGuard, ConduitResult, Error, Ruma}; use ruma::api::client::{error::ErrorKind, r0::search::search_events}; #[cfg(feature = "conduit_bin")] @@ -12,7 +12,7 @@ use std::collections::BTreeMap; )] #[tracing::instrument(skip(db, body))] pub async fn search_events_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/session.rs b/src/client_server/session.rs index 844ef0c..7ad792b 100644 --- a/src/client_server/session.rs +++ b/src/client_server/session.rs @@ -1,5 +1,5 @@ use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH}; -use crate::{database::ReadGuard, utils, ConduitResult, Error, Ruma}; +use crate::{database::DatabaseGuard, utils, ConduitResult, Error, Ruma}; use log::info; use ruma::{ api::client::{ @@ -50,7 +50,7 @@ pub async fn get_login_types_route() -> ConduitResult )] #[tracing::instrument(skip(db, body))] pub async fn login_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { // Validate login method @@ -167,7 +167,7 @@ pub async fn login_route( )] #[tracing::instrument(skip(db, body))] pub async fn logout_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -195,7 +195,7 @@ pub async fn logout_route( )] #[tracing::instrument(skip(db, body))] pub async fn logout_all_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/state.rs b/src/client_server/state.rs index dea52aa..68246d5 100644 --- a/src/client_server/state.rs +++ b/src/client_server/state.rs @@ -1,4 +1,6 @@ -use crate::{database::ReadGuard, pdu::PduBuilder, ConduitResult, Database, Error, Result, Ruma}; +use crate::{ + database::DatabaseGuard, pdu::PduBuilder, ConduitResult, Database, Error, Result, Ruma, +}; use ruma::{ api::client::{ error::ErrorKind, @@ -24,7 +26,7 @@ use rocket::{get, put}; )] #[tracing::instrument(skip(db, body))] pub async fn send_state_event_for_key_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -50,7 +52,7 @@ pub async fn send_state_event_for_key_route( )] #[tracing::instrument(skip(db, body))] pub async fn send_state_event_for_empty_key_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -76,7 +78,7 @@ pub async fn send_state_event_for_empty_key_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_state_events_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -123,7 +125,7 @@ pub async fn get_state_events_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_state_events_for_key_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -174,7 +176,7 @@ pub async fn get_state_events_for_key_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_state_events_for_empty_key_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/sync.rs b/src/client_server/sync.rs index 09ac8f0..c57f1da 100644 --- a/src/client_server/sync.rs +++ b/src/client_server/sync.rs @@ -1,4 +1,4 @@ -use crate::{database::ReadGuard, ConduitResult, Database, Error, Result, Ruma, RumaResponse}; +use crate::{database::DatabaseGuard, ConduitResult, Database, Error, Result, Ruma, RumaResponse}; use log::error; use ruma::{ api::client::r0::{sync::sync_events, uiaa::UiaaResponse}, @@ -34,7 +34,7 @@ use rocket::{get, tokio}; )] #[tracing::instrument(skip(db, body))] pub async fn sync_events_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> std::result::Result, RumaResponse> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -105,7 +105,7 @@ pub async fn sync_events_route( } pub async fn sync_helper_wrapper( - db: Arc, + db: Arc, sender_user: UserId, sender_device: Box, since: Option, @@ -149,7 +149,7 @@ pub async fn sync_helper_wrapper( } async fn sync_helper( - db: Arc, + db: Arc, sender_user: UserId, sender_device: Box, since: Option, diff --git a/src/client_server/tag.rs b/src/client_server/tag.rs index 8fa1d05..17df2c2 100644 --- a/src/client_server/tag.rs +++ b/src/client_server/tag.rs @@ -1,4 +1,4 @@ -use crate::{database::ReadGuard, ConduitResult, Ruma}; +use crate::{database::DatabaseGuard, ConduitResult, Ruma}; use ruma::{ api::client::r0::tag::{create_tag, delete_tag, get_tags}, events::EventType, @@ -14,7 +14,7 @@ use rocket::{delete, get, put}; )] #[tracing::instrument(skip(db, body))] pub async fn update_tag_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -51,7 +51,7 @@ pub async fn update_tag_route( )] #[tracing::instrument(skip(db, body))] pub async fn delete_tag_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -85,7 +85,7 @@ pub async fn delete_tag_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_tags_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/to_device.rs b/src/client_server/to_device.rs index 761a359..3bb135e 100644 --- a/src/client_server/to_device.rs +++ b/src/client_server/to_device.rs @@ -1,4 +1,4 @@ -use crate::{database::ReadGuard, ConduitResult, Error, Ruma}; +use crate::{database::DatabaseGuard, ConduitResult, Error, Ruma}; use ruma::{ api::client::{error::ErrorKind, r0::to_device::send_event_to_device}, to_device::DeviceIdOrAllDevices, @@ -13,7 +13,7 @@ use rocket::put; )] #[tracing::instrument(skip(db, body))] pub async fn send_event_to_device_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/typing.rs b/src/client_server/typing.rs index 8f3a643..7a590af 100644 --- a/src/client_server/typing.rs +++ b/src/client_server/typing.rs @@ -1,4 +1,4 @@ -use crate::{database::ReadGuard, utils, ConduitResult, Ruma}; +use crate::{database::DatabaseGuard, utils, ConduitResult, Ruma}; use create_typing_event::Typing; use ruma::api::client::r0::typing::create_typing_event; @@ -11,7 +11,7 @@ use rocket::put; )] #[tracing::instrument(skip(db, body))] pub fn create_typing_event_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); diff --git a/src/client_server/user_directory.rs b/src/client_server/user_directory.rs index 226fe6e..14b85a6 100644 --- a/src/client_server/user_directory.rs +++ b/src/client_server/user_directory.rs @@ -1,4 +1,4 @@ -use crate::{database::ReadGuard, ConduitResult, Ruma}; +use crate::{database::DatabaseGuard, ConduitResult, Ruma}; use ruma::api::client::r0::user_directory::search_users; #[cfg(feature = "conduit_bin")] @@ -10,7 +10,7 @@ use rocket::post; )] #[tracing::instrument(skip(db, body))] pub async fn search_users_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { let limit = u64::from(body.limit) as usize; diff --git a/src/database.rs b/src/database.rs index 7fcee02..036ec59 100644 --- a/src/database.rs +++ b/src/database.rs @@ -533,9 +533,9 @@ impl Database { } } -pub struct ReadGuard(OwnedRwLockReadGuard); +pub struct DatabaseGuard(OwnedRwLockReadGuard); -impl Deref for ReadGuard { +impl Deref for DatabaseGuard { type Target = OwnedRwLockReadGuard; fn deref(&self) -> &Self::Target { @@ -544,18 +544,18 @@ impl Deref for ReadGuard { } #[rocket::async_trait] -impl<'r> FromRequest<'r> for ReadGuard { +impl<'r> FromRequest<'r> for DatabaseGuard { type Error = (); async fn from_request(req: &'r Request<'_>) -> rocket::request::Outcome { let db = try_outcome!(req.guard::>>>().await); - Ok(ReadGuard(Arc::clone(&db).read_owned().await)).or_forward(()) + Ok(DatabaseGuard(Arc::clone(&db).read_owned().await)).or_forward(()) } } -impl Into for OwnedRwLockReadGuard { - fn into(self) -> ReadGuard { - ReadGuard(self) +impl Into for OwnedRwLockReadGuard { + fn into(self) -> DatabaseGuard { + DatabaseGuard(self) } } diff --git a/src/ruma_wrapper.rs b/src/ruma_wrapper.rs index d06f224..347406d 100644 --- a/src/ruma_wrapper.rs +++ b/src/ruma_wrapper.rs @@ -1,4 +1,4 @@ -use crate::{database::ReadGuard, Error}; +use crate::{database::DatabaseGuard, Error}; use ruma::{ api::{client::r0::uiaa::UiaaResponse, OutgoingResponse}, identifiers::{DeviceId, UserId}, @@ -48,7 +48,7 @@ where async fn from_data(request: &'a Request<'_>, data: Data) -> data::Outcome { let metadata = T::Incoming::METADATA; let db = request - .guard::() + .guard::() .await .expect("database was loaded"); diff --git a/src/server_server.rs b/src/server_server.rs index bbb2f69..25d170d 100644 --- a/src/server_server.rs +++ b/src/server_server.rs @@ -1,6 +1,6 @@ use crate::{ client_server::{self, claim_keys_helper, get_keys_helper}, - database::ReadGuard, + database::DatabaseGuard, utils, ConduitResult, Database, Error, PduEvent, Result, Ruma, }; use get_profile_information::v1::ProfileField; @@ -432,7 +432,9 @@ pub async fn request_well_known( #[cfg_attr(feature = "conduit_bin", get("/_matrix/federation/v1/version"))] #[tracing::instrument(skip(db))] -pub fn get_server_version_route(db: ReadGuard) -> ConduitResult { +pub fn get_server_version_route( + db: DatabaseGuard, +) -> ConduitResult { if !db.globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } @@ -449,7 +451,7 @@ pub fn get_server_version_route(db: ReadGuard) -> ConduitResult Json { +pub fn get_server_keys_route(db: DatabaseGuard) -> Json { if !db.globals.allow_federation() { // TODO: Use proper types return Json("Federation is disabled.".to_owned()); @@ -496,7 +498,7 @@ pub fn get_server_keys_route(db: ReadGuard) -> Json { #[cfg_attr(feature = "conduit_bin", get("/_matrix/key/v2/server/<_>"))] #[tracing::instrument(skip(db))] -pub fn get_server_keys_deprecated_route(db: ReadGuard) -> Json { +pub fn get_server_keys_deprecated_route(db: DatabaseGuard) -> Json { get_server_keys_route(db) } @@ -506,7 +508,7 @@ pub fn get_server_keys_deprecated_route(db: ReadGuard) -> Json { )] #[tracing::instrument(skip(db, body))] pub async fn get_public_rooms_filtered_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -550,7 +552,7 @@ pub async fn get_public_rooms_filtered_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_public_rooms_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -594,7 +596,7 @@ pub async fn get_public_rooms_route( )] #[tracing::instrument(skip(db, body))] pub async fn send_transaction_message_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -1675,7 +1677,7 @@ pub(crate) fn append_incoming_pdu( )] #[tracing::instrument(skip(db, body))] pub fn get_event_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -1700,7 +1702,7 @@ pub fn get_event_route( )] #[tracing::instrument(skip(db, body))] pub fn get_missing_events_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -1749,7 +1751,7 @@ pub fn get_missing_events_route( )] #[tracing::instrument(skip(db, body))] pub fn get_event_authorization_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -1793,7 +1795,7 @@ pub fn get_event_authorization_route( )] #[tracing::instrument(skip(db, body))] pub fn get_room_state_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -1856,7 +1858,7 @@ pub fn get_room_state_route( )] #[tracing::instrument(skip(db, body))] pub fn get_room_state_ids_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -1908,7 +1910,7 @@ pub fn get_room_state_ids_route( )] #[tracing::instrument(skip(db, body))] pub fn create_join_event_template_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -2077,7 +2079,7 @@ pub fn create_join_event_template_route( )] #[tracing::instrument(skip(db, body))] pub async fn create_join_event_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -2187,7 +2189,7 @@ pub async fn create_join_event_route( )] #[tracing::instrument(skip(db, body))] pub async fn create_invite_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -2294,7 +2296,7 @@ pub async fn create_invite_route( )] #[tracing::instrument(skip(db, body))] pub fn get_devices_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -2334,7 +2336,7 @@ pub fn get_devices_route( )] #[tracing::instrument(skip(db, body))] pub fn get_room_information_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -2362,7 +2364,7 @@ pub fn get_room_information_route( )] #[tracing::instrument(skip(db, body))] pub fn get_profile_information_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma>, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -2396,7 +2398,7 @@ pub fn get_profile_information_route( )] #[tracing::instrument(skip(db, body))] pub async fn get_keys_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { if !db.globals.allow_federation() { @@ -2426,7 +2428,7 @@ pub async fn get_keys_route( )] #[tracing::instrument(skip(db, body))] pub async fn claim_keys_route( - db: ReadGuard, + db: DatabaseGuard, body: Ruma, ) -> ConduitResult { if !db.globals.allow_federation() { From 7e9014d5c926ba211bdf1e68ee343321c47d86ac Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Sun, 11 Jul 2021 20:10:24 +0200 Subject: [PATCH 18/28] implement sync rotation --- src/database.rs | 7 +++++++ src/database/abstraction/sqlite.rs | 8 ++++---- src/database/globals.rs | 27 ++++++++++++++++++++++++++- 3 files changed, 37 insertions(+), 5 deletions(-) diff --git a/src/database.rs b/src/database.rs index 036ec59..e9646df 100644 --- a/src/database.rs +++ b/src/database.rs @@ -456,6 +456,8 @@ impl Database { .watch_prefix(&userid_bytes), ); + futures.push(Box::pin(self.globals.rotate.watch())); + // Wait until one of them finds something futures.next().await; } @@ -509,6 +511,11 @@ impl Database { }; if let Some(arc) = Weak::upgrade(&weak) { + log::info!(target: "wal-trunc", "Rotating sync helpers..."); + // This actually creates a very small race condition between firing this and trying to acquire the subsequent write lock. + // Though it is not a huge deal if the write lock doesn't "catch", as it'll harmlessly time out. + arc.read().await.globals.rotate.fire(); + log::info!(target: "wal-trunc", "Locking..."); let guard = { if let Ok(guard) = timeout(lock_timeout, arc.write()).await { diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index fe54813..310e03a 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -16,7 +16,7 @@ use super::{DatabaseEngine, Tree}; use log::debug; use crossbeam::channel::{bounded, Sender as ChannelSender}; -use parking_lot::{Mutex, MutexGuard, RwLock}; +use parking_lot::{FairMutex, FairMutexGuard, Mutex, MutexGuard, RwLock}; use rusqlite::{params, Connection, DatabaseName::Main, OptionalExtension}; use tokio::sync::oneshot::Sender; @@ -33,7 +33,7 @@ use tokio::sync::oneshot::Sender; // "SELECT key, value FROM {} WHERE key <= ? ORDER BY DESC"; struct Pool { - writer: Mutex, + writer: FairMutex, readers: Vec>, spill_tracker: Arc<()>, path: PathBuf, @@ -59,7 +59,7 @@ impl<'a> Deref for HoldingConn<'a> { impl Pool { fn new>(path: P, num_readers: usize, cache_size: u32) -> Result { - let writer = Mutex::new(Self::prepare_conn(&path, Some(cache_size))?); + let writer = FairMutex::new(Self::prepare_conn(&path, Some(cache_size))?); let mut readers = Vec::new(); @@ -93,7 +93,7 @@ impl Pool { Ok(conn) } - fn write_lock(&self) -> MutexGuard<'_, Connection> { + fn write_lock(&self) -> FairMutexGuard<'_, Connection> { self.writer.lock() } diff --git a/src/database/globals.rs b/src/database/globals.rs index eef478a..7c53072 100644 --- a/src/database/globals.rs +++ b/src/database/globals.rs @@ -11,11 +11,12 @@ use rustls::{ServerCertVerifier, WebPKIVerifier}; use std::{ collections::{BTreeMap, HashMap}, fs, + future::Future, path::PathBuf, sync::{Arc, RwLock}, time::{Duration, Instant}, }; -use tokio::sync::Semaphore; +use tokio::sync::{broadcast, Semaphore}; use trust_dns_resolver::TokioAsyncResolver; use super::abstraction::Tree; @@ -47,6 +48,7 @@ pub struct Globals { ), // since, rx >, >, + pub rotate: RotationHandler, } struct MatrixServerVerifier { @@ -82,6 +84,28 @@ impl ServerCertVerifier for MatrixServerVerifier { } } +pub struct RotationHandler(broadcast::Sender<()>, broadcast::Receiver<()>); + +impl RotationHandler { + pub fn new() -> Self { + let (s, r) = broadcast::channel::<()>(1); + + Self(s, r) + } + + pub fn watch(&self) -> impl Future { + let mut r = self.0.subscribe(); + + async move { + let _ = r.recv().await; + } + } + + pub fn fire(&self) { + let _ = self.0.send(()); + } +} + impl Globals { pub fn load( globals: Arc, @@ -168,6 +192,7 @@ impl Globals { bad_signature_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), servername_ratelimiter: Arc::new(RwLock::new(BTreeMap::new())), sync_receivers: RwLock::new(BTreeMap::new()), + rotate: RotationHandler::new(), }; fs::create_dir_all(s.get_media_folder())?; From caa0cbfe1d932beb1f2a704f547e9d5b46355bcd Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Sun, 11 Jul 2021 21:48:55 +0200 Subject: [PATCH 19/28] change fairmutex to mutex --- src/database/abstraction/sqlite.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index 310e03a..fe54813 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -16,7 +16,7 @@ use super::{DatabaseEngine, Tree}; use log::debug; use crossbeam::channel::{bounded, Sender as ChannelSender}; -use parking_lot::{FairMutex, FairMutexGuard, Mutex, MutexGuard, RwLock}; +use parking_lot::{Mutex, MutexGuard, RwLock}; use rusqlite::{params, Connection, DatabaseName::Main, OptionalExtension}; use tokio::sync::oneshot::Sender; @@ -33,7 +33,7 @@ use tokio::sync::oneshot::Sender; // "SELECT key, value FROM {} WHERE key <= ? ORDER BY DESC"; struct Pool { - writer: FairMutex, + writer: Mutex, readers: Vec>, spill_tracker: Arc<()>, path: PathBuf, @@ -59,7 +59,7 @@ impl<'a> Deref for HoldingConn<'a> { impl Pool { fn new>(path: P, num_readers: usize, cache_size: u32) -> Result { - let writer = FairMutex::new(Self::prepare_conn(&path, Some(cache_size))?); + let writer = Mutex::new(Self::prepare_conn(&path, Some(cache_size))?); let mut readers = Vec::new(); @@ -93,7 +93,7 @@ impl Pool { Ok(conn) } - fn write_lock(&self) -> FairMutexGuard<'_, Connection> { + fn write_lock(&self) -> MutexGuard<'_, Connection> { self.writer.lock() } From 735d7a08155a476ff627ec78dcb2812bd9067402 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Mon, 12 Jul 2021 00:07:10 +0200 Subject: [PATCH 20/28] database iter_from fix --- src/database/abstraction/sqlite.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index fe54813..d4ab9ad 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -334,8 +334,8 @@ impl Tree for SqliteTable { if backwards { iter_from_thread!( self, - format!( // TODO change to <= on rebase - "SELECT key, value FROM {} WHERE key < ? ORDER BY key DESC", + format!( + "SELECT key, value FROM {} WHERE key <= ? ORDER BY key DESC", name ) .as_str(), From 6e8beb604dbfa0cc89b0eb123711e842029bdfb2 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Mon, 12 Jul 2021 18:05:43 +0200 Subject: [PATCH 21/28] support some deprecations --- src/database.rs | 28 ++++++++++++++++++++++++++-- src/database/abstraction/sqlite.rs | 2 +- src/main.rs | 3 ++- 3 files changed, 29 insertions(+), 4 deletions(-) diff --git a/src/database.rs b/src/database.rs index e9646df..c58a396 100644 --- a/src/database.rs +++ b/src/database.rs @@ -42,8 +42,8 @@ use self::proxy::ProxyConfig; pub struct Config { server_name: Box, database_path: String, - #[serde(default = "default_db_cache_capacity")] - db_cache_capacity: u32, + cache_capacity: Option, // deprecated + db_cache_capacity: Option, #[serde(default = "default_sqlite_read_pool_size")] sqlite_read_pool_size: usize, #[serde(default = "false_fn")] @@ -73,6 +73,30 @@ pub struct Config { pub log: String, } +macro_rules! deprecate_with { + ($self:expr ; $from:ident -> $to:ident) => { + if let Some(v) = $self.$from { + let from = stringify!($from); + let to = stringify!($to); + log::warn!("{} is deprecated, use {} instead", from, to); + + $self.$to.get_or_insert(v); + } + }; + ($self:expr ; $from:ident -> $to:ident or $default:expr) => { + deprecate_with!($self ; $from -> $to); + $self.$to.get_or_insert_with($default); + }; +} + +impl Config { + pub fn fallbacks(mut self) -> Self { + // TODO: have a proper way handle into above struct (maybe serde supports something like this?) + deprecate_with!(self ; cache_capacity -> db_cache_capacity or default_db_cache_capacity); + self + } +} + fn false_fn() -> bool { false } diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index d4ab9ad..6864287 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -128,7 +128,7 @@ impl DatabaseEngine for Engine { let pool = Pool::new( Path::new(&config.database_path).join("conduit.db"), config.sqlite_read_pool_size, - config.db_cache_capacity / 1024, // bytes -> kb + config.db_cache_capacity.expect("fallbacks hasn't been called") / 1024, // bytes -> kb )?; pool.write_lock() diff --git a/src/main.rs b/src/main.rs index 034c39e..fd57468 100644 --- a/src/main.rs +++ b/src/main.rs @@ -196,7 +196,8 @@ async fn main() { let config = raw_config .extract::() - .expect("It looks like your config is invalid. Please take a look at the error"); + .expect("It looks like your config is invalid. Please take a look at the error") + .fallbacks(); let db = Database::load_or_create(config.clone()) .await From 7e0aab785256a3bfe2595e2622bd6e9d671c6fb3 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Mon, 12 Jul 2021 19:09:14 +0200 Subject: [PATCH 22/28] shuffle main.rs to allow deprecation warnings --- src/database.rs | 3 +-- src/main.rs | 34 +++++++++++++++++++--------------- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/src/database.rs b/src/database.rs index c58a396..20049d1 100644 --- a/src/database.rs +++ b/src/database.rs @@ -90,10 +90,9 @@ macro_rules! deprecate_with { } impl Config { - pub fn fallbacks(mut self) -> Self { + pub fn process_fallbacks(&mut self) { // TODO: have a proper way handle into above struct (maybe serde supports something like this?) deprecate_with!(self ; cache_capacity -> db_cache_capacity or default_db_cache_capacity); - self } } diff --git a/src/main.rs b/src/main.rs index fd57468..84dc4cb 100644 --- a/src/main.rs +++ b/src/main.rs @@ -194,14 +194,14 @@ async fn main() { ) .merge(Env::prefixed("CONDUIT_").global()); - let config = raw_config + std::env::set_var("RUST_LOG", "warn"); + + let mut config = raw_config .extract::() - .expect("It looks like your config is invalid. Please take a look at the error") - .fallbacks(); + .expect("It looks like your config is invalid. Please take a look at the error"); - let db = Database::load_or_create(config.clone()) - .await - .expect("config is valid"); + let mut _span: Option = None; + let mut _enter: Option> = None; if config.allow_jaeger { let (tracer, _uninstall) = opentelemetry_jaeger::new_pipeline() @@ -211,18 +211,22 @@ async fn main() { let telemetry = tracing_opentelemetry::layer().with_tracer(tracer); Registry::default().with(telemetry).try_init().unwrap(); - let root = span!(tracing::Level::INFO, "app_start", work_units = 2); - let _enter = root.enter(); - - let rocket = setup_rocket(raw_config, db); - rocket.launch().await.unwrap(); + _span = Some(span!(tracing::Level::INFO, "app_start", work_units = 2)); + _enter = Some(_span.as_ref().unwrap().enter()); } else { - std::env::set_var("RUST_LOG", config.log); + std::env::set_var("RUST_LOG", &config.log); tracing_subscriber::fmt::init(); - - let rocket = setup_rocket(raw_config, db); - rocket.launch().await.unwrap(); } + + // Required here to process fallbacks while logging is enabled, but before config is actually used for anything + config.process_fallbacks(); + + let db = Database::load_or_create(config) + .await + .expect("config is valid"); + + let rocket = setup_rocket(raw_config, db); + rocket.launch().await.unwrap(); } #[catch(404)] From 3260ae01b8a5e43c111908c6715ca091b00fa025 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Mon, 12 Jul 2021 19:53:06 +0200 Subject: [PATCH 23/28] change references of cache_capacity to db_cache_capacity --- DEPLOY.md | 2 +- conduit-example.toml | 2 +- debian/postinst | 2 +- src/database/abstraction/sled.rs | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/DEPLOY.md b/DEPLOY.md index fe8c331..52464e1 100644 --- a/DEPLOY.md +++ b/DEPLOY.md @@ -106,7 +106,7 @@ allow_federation = true trusted_servers = ["matrix.org"] -#cache_capacity = 1073741824 # in bytes, 1024 * 1024 * 1024 +#db_cache_capacity = 1073741824 # in bytes, 1024 * 1024 * 1024 #max_concurrent_requests = 100 # How many requests Conduit sends to other servers at the same time #workers = 4 # default: cpu core count * 2 diff --git a/conduit-example.toml b/conduit-example.toml index db0bbb7..2df8854 100644 --- a/conduit-example.toml +++ b/conduit-example.toml @@ -35,7 +35,7 @@ max_request_size = 20_000_000 # in bytes trusted_servers = ["matrix.org"] -#cache_capacity = 1073741824 # in bytes, 1024 * 1024 * 1024 +#db_cache_capacity = 1073741824 # in bytes, 1024 * 1024 * 1024 #max_concurrent_requests = 100 # How many requests Conduit sends to other servers at the same time #log = "info,state_res=warn,rocket=off,_=off,sled=off" #workers = 4 # default: cpu core count * 2 diff --git a/debian/postinst b/debian/postinst index 6a4cdb8..4c33e7d 100644 --- a/debian/postinst +++ b/debian/postinst @@ -73,7 +73,7 @@ max_request_size = 20_000_000 # in bytes # Enable jaeger to support monitoring and troubleshooting through jaeger. #allow_jaeger = false -#cache_capacity = 1073741824 # in bytes, 1024 * 1024 * 1024 +#db_cache_capacity = 1073741824 # in bytes, 1024 * 1024 * 1024 #max_concurrent_requests = 100 # How many requests Conduit sends to other servers at the same time #log = "info,state_res=warn,rocket=off,_=off,sled=off" #workers = 4 # default: cpu core count * 2 diff --git a/src/database/abstraction/sled.rs b/src/database/abstraction/sled.rs index 8c7f80d..bf5aa2b 100644 --- a/src/database/abstraction/sled.rs +++ b/src/database/abstraction/sled.rs @@ -14,7 +14,7 @@ impl DatabaseEngine for Engine { Ok(Arc::new(Engine( sled::Config::default() .path(&config.database_path) - .cache_capacity(config.cache_capacity as u64) + .cache_capacity(config.db_cache_capacity as u64) .use_compression(true) .open()?, ))) From b89cffed34c5af4d4dfcf376b3a91572edf50313 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Mon, 12 Jul 2021 21:23:20 +0200 Subject: [PATCH 24/28] warn on deprecated keys --- src/database.rs | 42 +++++++++++++++--------------- src/database/abstraction/sqlite.rs | 2 +- src/main.rs | 5 ++-- 3 files changed, 24 insertions(+), 25 deletions(-) diff --git a/src/database.rs b/src/database.rs index 20049d1..a30982e 100644 --- a/src/database.rs +++ b/src/database.rs @@ -26,9 +26,9 @@ use rocket::{ try_outcome, State, }; use ruma::{DeviceId, ServerName, UserId}; -use serde::Deserialize; +use serde::{de::IgnoredAny, Deserialize}; use std::{ - collections::HashMap, + collections::{BTreeMap, HashMap}, fs::{self, remove_dir_all}, io::Write, ops::Deref, @@ -42,8 +42,8 @@ use self::proxy::ProxyConfig; pub struct Config { server_name: Box, database_path: String, - cache_capacity: Option, // deprecated - db_cache_capacity: Option, + #[serde(default = "default_db_cache_capacity")] + db_cache_capacity: u32, #[serde(default = "default_sqlite_read_pool_size")] sqlite_read_pool_size: usize, #[serde(default = "false_fn")] @@ -71,28 +71,28 @@ pub struct Config { trusted_servers: Vec>, #[serde(default = "default_log")] pub log: String, + + #[serde(flatten)] + catchall: BTreeMap, } -macro_rules! deprecate_with { - ($self:expr ; $from:ident -> $to:ident) => { - if let Some(v) = $self.$from { - let from = stringify!($from); - let to = stringify!($to); - log::warn!("{} is deprecated, use {} instead", from, to); +const DEPRECATED_KEYS: &[&str] = &["cache_capacity"]; - $self.$to.get_or_insert(v); +impl Config { + pub fn warn_deprecated(&self) { + let mut was_deprecated = false; + for key in self + .catchall + .keys() + .filter(|key| DEPRECATED_KEYS.iter().any(|s| s == key)) + { + log::warn!("Config parameter {} is deprecated", key); + was_deprecated = true; } - }; - ($self:expr ; $from:ident -> $to:ident or $default:expr) => { - deprecate_with!($self ; $from -> $to); - $self.$to.get_or_insert_with($default); - }; -} -impl Config { - pub fn process_fallbacks(&mut self) { - // TODO: have a proper way handle into above struct (maybe serde supports something like this?) - deprecate_with!(self ; cache_capacity -> db_cache_capacity or default_db_cache_capacity); + if was_deprecated { + log::warn!("Read conduit documentation and check your configuration if any new configuration parameters should be adjusted"); + } } } diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index 6864287..d4ab9ad 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -128,7 +128,7 @@ impl DatabaseEngine for Engine { let pool = Pool::new( Path::new(&config.database_path).join("conduit.db"), config.sqlite_read_pool_size, - config.db_cache_capacity.expect("fallbacks hasn't been called") / 1024, // bytes -> kb + config.db_cache_capacity / 1024, // bytes -> kb )?; pool.write_lock() diff --git a/src/main.rs b/src/main.rs index 84dc4cb..e0d2e3d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -196,7 +196,7 @@ async fn main() { std::env::set_var("RUST_LOG", "warn"); - let mut config = raw_config + let config = raw_config .extract::() .expect("It looks like your config is invalid. Please take a look at the error"); @@ -218,8 +218,7 @@ async fn main() { tracing_subscriber::fmt::init(); } - // Required here to process fallbacks while logging is enabled, but before config is actually used for anything - config.process_fallbacks(); + config.warn_deprecated(); let db = Database::load_or_create(config) .await From 3b594d0b0a6d59bdcc600e2b65839513f6ec3047 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Tue, 13 Jul 2021 17:22:36 +0200 Subject: [PATCH 25/28] add more documentation --- DEPLOY.md | 22 +++++++++++++++++++++- conduit-example.toml | 22 +++++++++++++++++++++- debian/postinst | 22 +++++++++++++++++++++- 3 files changed, 63 insertions(+), 3 deletions(-) diff --git a/DEPLOY.md b/DEPLOY.md index 52464e1..7028648 100644 --- a/DEPLOY.md +++ b/DEPLOY.md @@ -106,11 +106,31 @@ allow_federation = true trusted_servers = ["matrix.org"] -#db_cache_capacity = 1073741824 # in bytes, 1024 * 1024 * 1024 #max_concurrent_requests = 100 # How many requests Conduit sends to other servers at the same time #workers = 4 # default: cpu core count * 2 address = "127.0.0.1" # This makes sure Conduit can only be reached using the reverse proxy + +## sqlite + +# The amount of memory that the database will use, with the following formula; +# (db_cache_capacity * (sqlite_read_pool_size + 1)), in bytes +#db_cache_capacity = 1073741824 # in bytes, 1024 * 1024 * 1024 + +# How many permanent read connections will be open to the database, +# increase this if you see "creating spillover reader" in your logs. +#sqlite_read_pool_size = 2 # default: max(cpu core count, 1) + +# If the database WAL (conduit.db-wal file) should be cleaned on a timer. +#sqlite_wal_clean_timer = false + +# How many seconds should pass before the WAL clean task should fire. +# Note: Is dependant on sqlite_wal_clean_timer being true. +#sqlite_wal_clean_second_interval = 60 + +# How long the WAL clean task should (in seconds) try to wait while +# getting exclusive access to the database (before giving up). +#sqlite_wal_clean_second_timeout = 2 ``` ## Setting the correct file permissions diff --git a/conduit-example.toml b/conduit-example.toml index 2df8854..cbcc58c 100644 --- a/conduit-example.toml +++ b/conduit-example.toml @@ -35,7 +35,6 @@ max_request_size = 20_000_000 # in bytes trusted_servers = ["matrix.org"] -#db_cache_capacity = 1073741824 # in bytes, 1024 * 1024 * 1024 #max_concurrent_requests = 100 # How many requests Conduit sends to other servers at the same time #log = "info,state_res=warn,rocket=off,_=off,sled=off" #workers = 4 # default: cpu core count * 2 @@ -43,3 +42,24 @@ trusted_servers = ["matrix.org"] address = "127.0.0.1" # This makes sure Conduit can only be reached using the reverse proxy proxy = "none" # more examples can be found at src/database/proxy.rs:6 + +## sqlite + +# The amount of memory that the database will use, with the following formula; +# (db_cache_capacity * (sqlite_read_pool_size + 1)), in bytes +#db_cache_capacity = 1073741824 # in bytes, 1024 * 1024 * 1024 + +# How many permanent read connections will be open to the database, +# increase this if you see "creating spillover reader" in your logs. +#sqlite_read_pool_size = 2 # default: max(cpu core count, 1) + +# If the database WAL (conduit.db-wal file) should be cleaned on a timer. +#sqlite_wal_clean_timer = false + +# How many seconds should pass before the WAL clean task should fire. +# Note: Is dependant on sqlite_wal_clean_timer being true. +#sqlite_wal_clean_second_interval = 60 + +# How long the WAL clean task should (in seconds) try to wait while +# getting exclusive access to the database (before giving up). +#sqlite_wal_clean_second_timeout = 2 \ No newline at end of file diff --git a/debian/postinst b/debian/postinst index 4c33e7d..5a05485 100644 --- a/debian/postinst +++ b/debian/postinst @@ -73,10 +73,30 @@ max_request_size = 20_000_000 # in bytes # Enable jaeger to support monitoring and troubleshooting through jaeger. #allow_jaeger = false -#db_cache_capacity = 1073741824 # in bytes, 1024 * 1024 * 1024 #max_concurrent_requests = 100 # How many requests Conduit sends to other servers at the same time #log = "info,state_res=warn,rocket=off,_=off,sled=off" #workers = 4 # default: cpu core count * 2 + +## sqlite + +# The amount of memory that the database will use, with the following formula; +# (db_cache_capacity * (sqlite_read_pool_size + 1)), in bytes +#db_cache_capacity = 1073741824 # in bytes, 1024 * 1024 * 1024 + +# How many permanent read connections will be open to the database, +# increase this if you see "creating spillover reader" in your logs. +#sqlite_read_pool_size = 2 # default: max(cpu core count, 1) + +# If the database WAL (conduit.db-wal file) should be cleaned on a timer. +#sqlite_wal_clean_timer = false + +# How many seconds should pass before the WAL clean task should fire. +# Note: Is dependant on sqlite_wal_clean_timer being true. +#sqlite_wal_clean_second_interval = 60 + +# How long the WAL clean task should (in seconds) try to wait while +# getting exclusive access to the database (before giving up). +#sqlite_wal_clean_second_timeout = 2 EOF fi ;; From cfc61eb35ea80a08bdfd3c5aa1e7687e51820189 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Tue, 13 Jul 2021 17:44:49 +0200 Subject: [PATCH 26/28] exit early when only sled db is detected --- src/database.rs | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/database.rs b/src/database.rs index a30982e..31deafd 100644 --- a/src/database.rs +++ b/src/database.rs @@ -32,6 +32,7 @@ use std::{ fs::{self, remove_dir_all}, io::Write, ops::Deref, + path::Path, sync::{Arc, RwLock}, }; use tokio::sync::{OwnedRwLockReadGuard, RwLock as TokioRwLock, Semaphore}; @@ -170,8 +171,31 @@ impl Database { Ok(()) } + fn check_sled_or_sqlite_db(config: &Config) { + let path = Path::new(&config.database_path); + + let sled_exists = path.join("db").exists(); + let sqlite_exists = path.join("conduit.db").exists(); + + if sled_exists { + if sqlite_exists { + // most likely an in-place directory, only warn + log::warn!("both sled and sqlite databases are detected in database directory"); + log::warn!("currently running from the sqlite database, but consider removing sled database files to free up space") + } else { + log::error!( + "sled database detected, conduit now uses sqlite for database operations" + ); + log::error!("this database must be converted to sqlite, go to https://github.com/ShadowJonathan/conduit_toolbox#conduit_sled_to_sqlite"); + std::process::exit(1); + } + } + } + /// Load an existing database or create a new one. pub async fn load_or_create(config: Config) -> Result>> { + Self::check_sled_or_sqlite_db(&config); + let builder = Engine::open(&config)?; if config.max_request_size < 1024 { From 251d19d06ca5ec4e125b12ff81692fc89aaacc69 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Tue, 13 Jul 2021 18:10:08 +0200 Subject: [PATCH 27/28] remove extra advanced config and make db_cache_capacity_mb total --- DEPLOY.md | 22 ++-------------------- conduit-example.toml | 22 ++-------------------- debian/postinst | 22 ++-------------------- src/database.rs | 8 ++++---- src/database/abstraction/sled.rs | 2 +- src/database/abstraction/sqlite.rs | 10 ++++++++-- 6 files changed, 19 insertions(+), 67 deletions(-) diff --git a/DEPLOY.md b/DEPLOY.md index 7028648..60650f6 100644 --- a/DEPLOY.md +++ b/DEPLOY.md @@ -111,26 +111,8 @@ trusted_servers = ["matrix.org"] address = "127.0.0.1" # This makes sure Conduit can only be reached using the reverse proxy -## sqlite - -# The amount of memory that the database will use, with the following formula; -# (db_cache_capacity * (sqlite_read_pool_size + 1)), in bytes -#db_cache_capacity = 1073741824 # in bytes, 1024 * 1024 * 1024 - -# How many permanent read connections will be open to the database, -# increase this if you see "creating spillover reader" in your logs. -#sqlite_read_pool_size = 2 # default: max(cpu core count, 1) - -# If the database WAL (conduit.db-wal file) should be cleaned on a timer. -#sqlite_wal_clean_timer = false - -# How many seconds should pass before the WAL clean task should fire. -# Note: Is dependant on sqlite_wal_clean_timer being true. -#sqlite_wal_clean_second_interval = 60 - -# How long the WAL clean task should (in seconds) try to wait while -# getting exclusive access to the database (before giving up). -#sqlite_wal_clean_second_timeout = 2 +# The total amount of memory that the database will use. +#db_cache_capacity_mb = 10 ``` ## Setting the correct file permissions diff --git a/conduit-example.toml b/conduit-example.toml index cbcc58c..6fb4d5f 100644 --- a/conduit-example.toml +++ b/conduit-example.toml @@ -43,23 +43,5 @@ address = "127.0.0.1" # This makes sure Conduit can only be reached using the re proxy = "none" # more examples can be found at src/database/proxy.rs:6 -## sqlite - -# The amount of memory that the database will use, with the following formula; -# (db_cache_capacity * (sqlite_read_pool_size + 1)), in bytes -#db_cache_capacity = 1073741824 # in bytes, 1024 * 1024 * 1024 - -# How many permanent read connections will be open to the database, -# increase this if you see "creating spillover reader" in your logs. -#sqlite_read_pool_size = 2 # default: max(cpu core count, 1) - -# If the database WAL (conduit.db-wal file) should be cleaned on a timer. -#sqlite_wal_clean_timer = false - -# How many seconds should pass before the WAL clean task should fire. -# Note: Is dependant on sqlite_wal_clean_timer being true. -#sqlite_wal_clean_second_interval = 60 - -# How long the WAL clean task should (in seconds) try to wait while -# getting exclusive access to the database (before giving up). -#sqlite_wal_clean_second_timeout = 2 \ No newline at end of file +# The total amount of memory that the database will use. +#db_cache_capacity_mb = 10 \ No newline at end of file diff --git a/debian/postinst b/debian/postinst index 5a05485..79b7e73 100644 --- a/debian/postinst +++ b/debian/postinst @@ -77,26 +77,8 @@ max_request_size = 20_000_000 # in bytes #log = "info,state_res=warn,rocket=off,_=off,sled=off" #workers = 4 # default: cpu core count * 2 -## sqlite - -# The amount of memory that the database will use, with the following formula; -# (db_cache_capacity * (sqlite_read_pool_size + 1)), in bytes -#db_cache_capacity = 1073741824 # in bytes, 1024 * 1024 * 1024 - -# How many permanent read connections will be open to the database, -# increase this if you see "creating spillover reader" in your logs. -#sqlite_read_pool_size = 2 # default: max(cpu core count, 1) - -# If the database WAL (conduit.db-wal file) should be cleaned on a timer. -#sqlite_wal_clean_timer = false - -# How many seconds should pass before the WAL clean task should fire. -# Note: Is dependant on sqlite_wal_clean_timer being true. -#sqlite_wal_clean_second_interval = 60 - -# How long the WAL clean task should (in seconds) try to wait while -# getting exclusive access to the database (before giving up). -#sqlite_wal_clean_second_timeout = 2 +# The total amount of memory that the database will use. +#db_cache_capacity_mb = 10 EOF fi ;; diff --git a/src/database.rs b/src/database.rs index 31deafd..c498841 100644 --- a/src/database.rs +++ b/src/database.rs @@ -43,8 +43,8 @@ use self::proxy::ProxyConfig; pub struct Config { server_name: Box, database_path: String, - #[serde(default = "default_db_cache_capacity")] - db_cache_capacity: u32, + #[serde(default = "default_db_cache_capacity_mb")] + db_cache_capacity_mb: f64, #[serde(default = "default_sqlite_read_pool_size")] sqlite_read_pool_size: usize, #[serde(default = "false_fn")] @@ -105,8 +105,8 @@ fn true_fn() -> bool { true } -fn default_db_cache_capacity() -> u32 { - 1024 * 1024 * 1024 +fn default_db_cache_capacity_mb() -> f64 { + 10.0 } fn default_sqlite_read_pool_size() -> usize { diff --git a/src/database/abstraction/sled.rs b/src/database/abstraction/sled.rs index bf5aa2b..271be1e 100644 --- a/src/database/abstraction/sled.rs +++ b/src/database/abstraction/sled.rs @@ -14,7 +14,7 @@ impl DatabaseEngine for Engine { Ok(Arc::new(Engine( sled::Config::default() .path(&config.database_path) - .cache_capacity(config.db_cache_capacity as u64) + .cache_capacity((config.db_cache_capacity_mb * 1024 * 1024) as u64) .use_compression(true) .open()?, ))) diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index d4ab9ad..22a5559 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -58,7 +58,13 @@ impl<'a> Deref for HoldingConn<'a> { } impl Pool { - fn new>(path: P, num_readers: usize, cache_size: u32) -> Result { + fn new>(path: P, num_readers: usize, total_cache_size_mb: f64) -> Result { + // calculates cache-size per permanent connection + // 1. convert MB to KiB + // 2. divide by permanent connections + // 3. round down to nearest integer + let cache_size: u32 = ((total_cache_size_mb * 1024.0) / (num_readers + 1) as f64) as u32; + let writer = Mutex::new(Self::prepare_conn(&path, Some(cache_size))?); let mut readers = Vec::new(); @@ -128,7 +134,7 @@ impl DatabaseEngine for Engine { let pool = Pool::new( Path::new(&config.database_path).join("conduit.db"), config.sqlite_read_pool_size, - config.db_cache_capacity / 1024, // bytes -> kb + config.db_cache_capacity_mb, )?; pool.write_lock() From 6ad3108af2bdd0066454ae8e6094ce2505042f74 Mon Sep 17 00:00:00 2001 From: Jonathan de Jong Date: Tue, 13 Jul 2021 19:50:45 +0200 Subject: [PATCH 28/28] incorperate feedback --- DEPLOY.md | 2 +- conduit-example.toml | 2 +- debian/postinst | 2 +- src/database.rs | 44 +++++++++++++++++++++++------------------ src/database/globals.rs | 3 +++ 5 files changed, 31 insertions(+), 22 deletions(-) diff --git a/DEPLOY.md b/DEPLOY.md index 60650f6..ba7a030 100644 --- a/DEPLOY.md +++ b/DEPLOY.md @@ -112,7 +112,7 @@ trusted_servers = ["matrix.org"] address = "127.0.0.1" # This makes sure Conduit can only be reached using the reverse proxy # The total amount of memory that the database will use. -#db_cache_capacity_mb = 10 +#db_cache_capacity_mb = 200 ``` ## Setting the correct file permissions diff --git a/conduit-example.toml b/conduit-example.toml index 6fb4d5f..d184991 100644 --- a/conduit-example.toml +++ b/conduit-example.toml @@ -44,4 +44,4 @@ address = "127.0.0.1" # This makes sure Conduit can only be reached using the re proxy = "none" # more examples can be found at src/database/proxy.rs:6 # The total amount of memory that the database will use. -#db_cache_capacity_mb = 10 \ No newline at end of file +#db_cache_capacity_mb = 200 \ No newline at end of file diff --git a/debian/postinst b/debian/postinst index 79b7e73..824fd64 100644 --- a/debian/postinst +++ b/debian/postinst @@ -78,7 +78,7 @@ max_request_size = 20_000_000 # in bytes #workers = 4 # default: cpu core count * 2 # The total amount of memory that the database will use. -#db_cache_capacity_mb = 10 +#db_cache_capacity_mb = 200 EOF fi ;; diff --git a/src/database.rs b/src/database.rs index c498841..ac17372 100644 --- a/src/database.rs +++ b/src/database.rs @@ -47,7 +47,7 @@ pub struct Config { db_cache_capacity_mb: f64, #[serde(default = "default_sqlite_read_pool_size")] sqlite_read_pool_size: usize, - #[serde(default = "false_fn")] + #[serde(default = "true_fn")] sqlite_wal_clean_timer: bool, #[serde(default = "default_sqlite_wal_clean_second_interval")] sqlite_wal_clean_second_interval: u32, @@ -106,7 +106,7 @@ fn true_fn() -> bool { } fn default_db_cache_capacity_mb() -> f64 { - 10.0 + 200.0 } fn default_sqlite_read_pool_size() -> usize { @@ -114,7 +114,7 @@ fn default_sqlite_read_pool_size() -> usize { } fn default_sqlite_wal_clean_second_interval() -> u32 { - 60 + 60 * 60 } fn default_sqlite_wal_clean_second_timeout() -> u32 { @@ -171,30 +171,36 @@ impl Database { Ok(()) } - fn check_sled_or_sqlite_db(config: &Config) { + fn check_sled_or_sqlite_db(config: &Config) -> Result<()> { let path = Path::new(&config.database_path); - let sled_exists = path.join("db").exists(); - let sqlite_exists = path.join("conduit.db").exists(); - - if sled_exists { - if sqlite_exists { - // most likely an in-place directory, only warn - log::warn!("both sled and sqlite databases are detected in database directory"); - log::warn!("currently running from the sqlite database, but consider removing sled database files to free up space") - } else { - log::error!( - "sled database detected, conduit now uses sqlite for database operations" - ); - log::error!("this database must be converted to sqlite, go to https://github.com/ShadowJonathan/conduit_toolbox#conduit_sled_to_sqlite"); - std::process::exit(1); + #[cfg(feature = "backend_sqlite")] + { + let sled_exists = path.join("db").exists(); + let sqlite_exists = path.join("conduit.db").exists(); + if sled_exists { + if sqlite_exists { + // most likely an in-place directory, only warn + log::warn!("Both sled and sqlite databases are detected in database directory"); + log::warn!("Currently running from the sqlite database, but consider removing sled database files to free up space") + } else { + log::error!( + "Sled database detected, conduit now uses sqlite for database operations" + ); + log::error!("This database must be converted to sqlite, go to https://github.com/ShadowJonathan/conduit_toolbox#conduit_sled_to_sqlite"); + return Err(Error::bad_config( + "sled database detected, migrate to sqlite", + )); + } } } + + Ok(()) } /// Load an existing database or create a new one. pub async fn load_or_create(config: Config) -> Result>> { - Self::check_sled_or_sqlite_db(&config); + Self::check_sled_or_sqlite_db(&config)?; let builder = Engine::open(&config)?; diff --git a/src/database/globals.rs b/src/database/globals.rs index 7c53072..4242cf5 100644 --- a/src/database/globals.rs +++ b/src/database/globals.rs @@ -84,6 +84,9 @@ impl ServerCertVerifier for MatrixServerVerifier { } } +/// Handles "rotation" of long-polling requests. "Rotation" in this context is similar to "rotation" of log files and the like. +/// +/// This is utilized to have sync workers return early and release read locks on the database. pub struct RotationHandler(broadcast::Sender<()>, broadcast::Receiver<()>); impl RotationHandler {