diff options
author | Timo Kösters <timo@koesters.xyz> | 2022-10-05 15:33:57 +0200 |
---|---|---|
committer | Nyaaori <+@nyaaori.cat> | 2022-10-10 14:02:00 +0200 |
commit | cff52d7ebb5066f3d8e513488b84a431c0093e65 (patch) | |
tree | 597e030b6f52c5282625a51fd0d7e0e799ea7e00 /src/database | |
parent | face766e0f32481fd97a435f1ed8579d8cfc634c (diff) | |
download | conduit-cff52d7ebb5066f3d8e513488b84a431c0093e65.zip |
messing around with arcs
Diffstat (limited to 'src/database')
30 files changed, 344 insertions, 176 deletions
diff --git a/src/database/key_value/account_data.rs b/src/database/key_value/account_data.rs index 49c9170..f0325d2 100644 --- a/src/database/key_value/account_data.rs +++ b/src/database/key_value/account_data.rs @@ -1,19 +1,19 @@ -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; use ruma::{UserId, DeviceId, signatures::CanonicalJsonValue, api::client::{uiaa::UiaaInfo, error::ErrorKind}, events::{RoomAccountDataEventType, AnyEphemeralRoomEvent}, serde::Raw, RoomId}; use serde::{Serialize, de::DeserializeOwned}; use crate::{Result, database::KeyValueDatabase, service, Error, utils, services}; -impl service::account_data::Data for KeyValueDatabase { +impl service::account_data::Data for Arc<KeyValueDatabase> { /// Places one event in the account data of the user and removes the previous entry. #[tracing::instrument(skip(self, room_id, user_id, event_type, data))] - fn update<T: Serialize>( + fn update( &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, - data: &T, + data: &serde_json::Value, ) -> Result<()> { let mut prefix = room_id .map(|r| r.to_string()) @@ -32,8 +32,7 @@ impl service::account_data::Data for KeyValueDatabase { let mut key = prefix; key.extend_from_slice(event_type.to_string().as_bytes()); - let json = serde_json::to_value(data).expect("all types here can be serialized"); // TODO: maybe add error handling - if json.get("type").is_none() || json.get("content").is_none() { + if data.get("type").is_none() || data.get("content").is_none() { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Account data doesn't have all required fields.", @@ -42,7 +41,7 @@ impl service::account_data::Data for KeyValueDatabase { self.roomuserdataid_accountdata.insert( &roomuserdataid, - &serde_json::to_vec(&json).expect("to_vec always works on json values"), + &serde_json::to_vec(&data).expect("to_vec always works on json values"), )?; let prev = self.roomusertype_roomuserdataid.get(&key)?; @@ -60,12 +59,12 @@ impl service::account_data::Data for KeyValueDatabase { /// Searches the account data for a specific kind. #[tracing::instrument(skip(self, room_id, user_id, kind))] - fn get<T: DeserializeOwned>( + fn get( &self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType, - ) -> Result<Option<T>> { + ) -> Result<Option<Box<serde_json::value::RawValue>>> { let mut key = room_id .map(|r| r.to_string()) .unwrap_or_default() diff --git a/src/database/key_value/appservice.rs b/src/database/key_value/appservice.rs index f427ba7..ee6ae20 100644 --- a/src/database/key_value/appservice.rs +++ b/src/database/key_value/appservice.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use crate::{database::KeyValueDatabase, service, utils, Error, Result}; impl service::appservice::Data for KeyValueDatabase { diff --git a/src/database/key_value/globals.rs b/src/database/key_value/globals.rs index e665229..8711920 100644 --- a/src/database/key_value/globals.rs +++ b/src/database/key_value/globals.rs @@ -1,8 +1,136 @@ -use ruma::signatures::Ed25519KeyPair; +use std::{collections::BTreeMap, sync::Arc}; -use crate::{Result, service, database::KeyValueDatabase, Error, utils}; +use async_trait::async_trait; +use futures_util::{stream::FuturesUnordered, StreamExt}; +use ruma::{signatures::Ed25519KeyPair, UserId, DeviceId, ServerName, api::federation::discovery::{ServerSigningKeys, VerifyKey}, ServerSigningKeyId, MilliSecondsSinceUnixEpoch}; + +use crate::{Result, service, database::KeyValueDatabase, Error, utils, services}; + +pub const COUNTER: &[u8] = b"c"; + +#[async_trait] +impl service::globals::Data for Arc<KeyValueDatabase> { + fn next_count(&self) -> Result<u64> { + utils::u64_from_bytes(&self.global.increment(COUNTER)?) + .map_err(|_| Error::bad_database("Count has invalid bytes.")) + } + + fn current_count(&self) -> Result<u64> { + self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Count has invalid bytes.")) + }) + } + + async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { + let userid_bytes = user_id.as_bytes().to_vec(); + let mut userid_prefix = userid_bytes.clone(); + userid_prefix.push(0xff); + + let mut userdeviceid_prefix = userid_prefix.clone(); + userdeviceid_prefix.extend_from_slice(device_id.as_bytes()); + userdeviceid_prefix.push(0xff); + + let mut futures = FuturesUnordered::new(); + + // Return when *any* user changed his key + // TODO: only send for user they share a room with + futures.push( + self.todeviceid_events + .watch_prefix(&userdeviceid_prefix), + ); + + futures.push(self.userroomid_joined.watch_prefix(&userid_prefix)); + futures.push( + self.userroomid_invitestate + .watch_prefix(&userid_prefix), + ); + futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix)); + futures.push( + self.userroomid_notificationcount + .watch_prefix(&userid_prefix), + ); + futures.push( + self.userroomid_highlightcount + .watch_prefix(&userid_prefix), + ); + + // Events for rooms we are in + for room_id in services().rooms.state_cache.rooms_joined(user_id).filter_map(|r| r.ok()) { + let short_roomid = services() + .rooms + .short + .get_shortroomid(&room_id) + .ok() + .flatten() + .expect("room exists") + .to_be_bytes() + .to_vec(); + + let roomid_bytes = room_id.as_bytes().to_vec(); + let mut roomid_prefix = roomid_bytes.clone(); + roomid_prefix.push(0xff); + + // PDUs + futures.push(self.pduid_pdu.watch_prefix(&short_roomid)); + + // EDUs + futures.push( + self.roomid_lasttypingupdate + .watch_prefix(&roomid_bytes), + ); + + futures.push( + self.readreceiptid_readreceipt + .watch_prefix(&roomid_prefix), + ); + + // Key changes + futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix)); + + // Room account data + let mut roomuser_prefix = roomid_prefix.clone(); + roomuser_prefix.extend_from_slice(&userid_prefix); + + futures.push( + self.roomusertype_roomuserdataid + .watch_prefix(&roomuser_prefix), + ); + } + + let mut globaluserdata_prefix = vec![0xff]; + globaluserdata_prefix.extend_from_slice(&userid_prefix); + + futures.push( + self.roomusertype_roomuserdataid + .watch_prefix(&globaluserdata_prefix), + ); + + // More key changes (used when user is not joined to any rooms) + futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix)); + + // One time keys + futures.push( + self.userid_lastonetimekeyupdate + .watch_prefix(&userid_bytes), + ); + + futures.push(Box::pin(services().globals.rotate.watch())); + + // Wait until one of them finds something + futures.next().await; + + Ok(()) + } + + fn cleanup(&self) -> Result<()> { + self._db.cleanup() + } + + fn memory_usage(&self) -> Result<String> { + self._db.memory_usage() + } -impl service::globals::Data for KeyValueDatabase { fn load_keypair(&self) -> Result<Ed25519KeyPair> { let keypair_bytes = self.global.get(b"keypair")?.map_or_else( || { @@ -39,4 +167,81 @@ impl service::globals::Data for KeyValueDatabase { fn remove_keypair(&self) -> Result<()> { self.global.remove(b"keypair") } + + fn add_signing_key( + &self, + origin: &ServerName, + new_keys: ServerSigningKeys, + ) -> Result<BTreeMap<Box<ServerSigningKeyId>, VerifyKey>> { + // Not atomic, but this is not critical + let signingkeys = self.server_signingkeys.get(origin.as_bytes())?; + + let mut keys = signingkeys + .and_then(|keys| serde_json::from_slice(&keys).ok()) + .unwrap_or_else(|| { + // Just insert "now", it doesn't matter + ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now()) + }); + + let ServerSigningKeys { + verify_keys, + old_verify_keys, + .. + } = new_keys; + + keys.verify_keys.extend(verify_keys.into_iter()); + keys.old_verify_keys.extend(old_verify_keys.into_iter()); + + self.server_signingkeys.insert( + origin.as_bytes(), + &serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"), + )?; + + let mut tree = keys.verify_keys; + tree.extend( + keys.old_verify_keys + .into_iter() + .map(|old| (old.0, VerifyKey::new(old.1.key))), + ); + + Ok(tree) + } + + /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server. + fn signing_keys_for( + &self, + origin: &ServerName, + ) -> Result<BTreeMap<Box<ServerSigningKeyId>, VerifyKey>> { + let signingkeys = self + .server_signingkeys + .get(origin.as_bytes())? + .and_then(|bytes| serde_json::from_slice(&bytes).ok()) + .map(|keys: ServerSigningKeys| { + let mut tree = keys.verify_keys; + tree.extend( + keys.old_verify_keys + .into_iter() + .map(|old| (old.0, VerifyKey::new(old.1.key))), + ); + tree + }) + .unwrap_or_else(BTreeMap::new); + + Ok(signingkeys) + } + + fn database_version(&self) -> Result<u64> { + self.global.get(b"version")?.map_or(Ok(0), |version| { + utils::u64_from_bytes(&version) + .map_err(|_| Error::bad_database("Database version id is invalid.")) + }) + } + + fn bump_database_version(&self, new_version: u64) -> Result<()> { + self.global + .insert(b"version", &new_version.to_be_bytes())?; + Ok(()) + } + + } diff --git a/src/database/key_value/key_backups.rs b/src/database/key_value/key_backups.rs index 8171451..c59ed36 100644 --- a/src/database/key_value/key_backups.rs +++ b/src/database/key_value/key_backups.rs @@ -1,10 +1,10 @@ -use std::collections::BTreeMap; +use std::{collections::BTreeMap, sync::Arc}; use ruma::{UserId, serde::Raw, api::client::{backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, error::ErrorKind}, RoomId}; use crate::{Result, service, database::KeyValueDatabase, services, Error, utils}; -impl service::key_backups::Data for KeyValueDatabase { +impl service::key_backups::Data for Arc<KeyValueDatabase> { fn create_backup( &self, user_id: &UserId, diff --git a/src/database/key_value/media.rs b/src/database/key_value/media.rs index a84cbd5..1726755 100644 --- a/src/database/key_value/media.rs +++ b/src/database/key_value/media.rs @@ -1,9 +1,11 @@ +use std::sync::Arc; + use ruma::api::client::error::ErrorKind; use crate::{database::KeyValueDatabase, service, Error, utils, Result}; -impl service::media::Data for KeyValueDatabase { - fn create_file_metadata(&self, mxc: String, width: u32, height: u32, content_disposition: &Option<&str>, content_type: &Option<&str>) -> Result<Vec<u8>> { +impl service::media::Data for Arc<KeyValueDatabase> { + fn create_file_metadata(&self, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>) -> Result<Vec<u8>> { let mut key = mxc.as_bytes().to_vec(); key.push(0xff); key.extend_from_slice(&width.to_be_bytes()); diff --git a/src/database/key_value/pusher.rs b/src/database/key_value/pusher.rs index b05e47b..85d1d86 100644 --- a/src/database/key_value/pusher.rs +++ b/src/database/key_value/pusher.rs @@ -1,8 +1,10 @@ +use std::sync::Arc; + use ruma::{UserId, api::client::push::{set_pusher, get_pushers}}; use crate::{service, database::KeyValueDatabase, Error, Result}; -impl service::pusher::Data for KeyValueDatabase { +impl service::pusher::Data for Arc<KeyValueDatabase> { fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> { let mut key = sender.as_bytes().to_vec(); key.push(0xff); diff --git a/src/database/key_value/rooms/alias.rs b/src/database/key_value/rooms/alias.rs index 0aa8dd4..437902d 100644 --- a/src/database/key_value/rooms/alias.rs +++ b/src/database/key_value/rooms/alias.rs @@ -1,8 +1,10 @@ +use std::sync::Arc; + use ruma::{RoomId, RoomAliasId, api::client::error::ErrorKind}; use crate::{service, database::KeyValueDatabase, utils, Error, services, Result}; -impl service::rooms::alias::Data for KeyValueDatabase { +impl service::rooms::alias::Data for Arc<KeyValueDatabase> { fn set_alias( &self, alias: &RoomAliasId, diff --git a/src/database/key_value/rooms/auth_chain.rs b/src/database/key_value/rooms/auth_chain.rs index 888d472..2dffb04 100644 --- a/src/database/key_value/rooms/auth_chain.rs +++ b/src/database/key_value/rooms/auth_chain.rs @@ -1,28 +1,60 @@ -use std::{collections::HashSet, mem::size_of}; +use std::{collections::HashSet, mem::size_of, sync::Arc}; use crate::{service, database::KeyValueDatabase, Result, utils}; -impl service::rooms::auth_chain::Data for KeyValueDatabase { - fn get_cached_eventid_authchain(&self, shorteventid: u64) -> Result<Option<HashSet<u64>>> { - Ok(self.shorteventid_authchain - .get(&shorteventid.to_be_bytes())? - .map(|chain| { - chain - .chunks_exact(size_of::<u64>()) - .map(|chunk| { - utils::u64_from_bytes(chunk).expect("byte length is correct") - }) - .collect() - })) +impl service::rooms::auth_chain::Data for Arc<KeyValueDatabase> { + fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<HashSet<u64>>>> { + // Check RAM cache + if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) { + return Ok(Some(Arc::clone(result))); + } + + // We only save auth chains for single events in the db + if key.len() == 1 { + // Check DB cache + let chain = self.shorteventid_authchain + .get(&key[0].to_be_bytes())? + .map(|chain| { + chain + .chunks_exact(size_of::<u64>()) + .map(|chunk| { + utils::u64_from_bytes(chunk).expect("byte length is correct") + }) + .collect() + }); + + if let Some(chain) = chain { + let chain = Arc::new(chain); + + // Cache in RAM + self.auth_chain_cache + .lock() + .unwrap() + .insert(vec![key[0]], Arc::clone(&chain)); + + return Ok(Some(chain)); + } + } + + Ok(None) + } - fn cache_eventid_authchain(&self, shorteventid: u64, auth_chain: &HashSet<u64>) -> Result<()> { - self.shorteventid_authchain.insert( - &shorteventid.to_be_bytes(), - &auth_chain - .iter() - .flat_map(|s| s.to_be_bytes().to_vec()) - .collect::<Vec<u8>>(), - ) + fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<HashSet<u64>>) -> Result<()> { + // Only persist single events in db + if key.len() == 1 { + self.shorteventid_authchain.insert( + &key[0].to_be_bytes(), + &auth_chain + .iter() + .flat_map(|s| s.to_be_bytes().to_vec()) + .collect::<Vec<u8>>(), + )?; + } + + // Cache in RAM + self.auth_chain_cache.lock().unwrap().insert(key, auth_chain); + + Ok(()) } } diff --git a/src/database/key_value/rooms/directory.rs b/src/database/key_value/rooms/directory.rs index 727004e..864e75e 100644 --- a/src/database/key_value/rooms/directory.rs +++ b/src/database/key_value/rooms/directory.rs @@ -1,8 +1,10 @@ +use std::sync::Arc; + use ruma::RoomId; use crate::{service, database::KeyValueDatabase, utils, Error, Result}; -impl service::rooms::directory::Data for KeyValueDatabase { +impl service::rooms::directory::Data for Arc<KeyValueDatabase> { fn set_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.insert(room_id.as_bytes(), &[]) } diff --git a/src/database/key_value/rooms/edus/mod.rs b/src/database/key_value/rooms/edus/mod.rs index b5007f8..03e4219 100644 --- a/src/database/key_value/rooms/edus/mod.rs +++ b/src/database/key_value/rooms/edus/mod.rs @@ -2,6 +2,8 @@ mod presence; mod typing; mod read_receipt; +use std::sync::Arc; + use crate::{service, database::KeyValueDatabase}; -impl service::rooms::edus::Data for KeyValueDatabase {} +impl service::rooms::edus::Data for Arc<KeyValueDatabase> {} diff --git a/src/database/key_value/rooms/edus/presence.rs b/src/database/key_value/rooms/edus/presence.rs index 1477c28..5aeb147 100644 --- a/src/database/key_value/rooms/edus/presence.rs +++ b/src/database/key_value/rooms/edus/presence.rs @@ -1,10 +1,10 @@ -use std::collections::HashMap; +use std::{collections::HashMap, sync::Arc}; use ruma::{UserId, RoomId, events::presence::PresenceEvent, presence::PresenceState, UInt}; use crate::{service, database::KeyValueDatabase, utils, Error, services, Result}; -impl service::rooms::edus::presence::Data for KeyValueDatabase { +impl service::rooms::edus::presence::Data for Arc<KeyValueDatabase> { fn update_presence( &self, user_id: &UserId, diff --git a/src/database/key_value/rooms/edus/read_receipt.rs b/src/database/key_value/rooms/edus/read_receipt.rs index a12e265..7fcb8ac 100644 --- a/src/database/key_value/rooms/edus/read_receipt.rs +++ b/src/database/key_value/rooms/edus/read_receipt.rs @@ -1,10 +1,10 @@ -use std::mem; +use std::{mem, sync::Arc}; use ruma::{UserId, RoomId, events::receipt::ReceiptEvent, serde::Raw, signatures::CanonicalJsonObject}; use crate::{database::KeyValueDatabase, service, utils, Error, services, Result}; -impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { +impl service::rooms::edus::read_receipt::Data for Arc<KeyValueDatabase> { fn readreceipt_update( &self, user_id: &UserId, diff --git a/src/database/key_value/rooms/edus/typing.rs b/src/database/key_value/rooms/edus/typing.rs index b7d3596..7f3526d 100644 --- a/src/database/key_value/rooms/edus/typing.rs +++ b/src/database/key_value/rooms/edus/typing.rs @@ -1,10 +1,10 @@ -use std::collections::HashSet; +use std::{collections::HashSet, sync::Arc}; use ruma::{UserId, RoomId}; use crate::{database::KeyValueDatabase, service, utils, Error, services, Result}; -impl service::rooms::edus::typing::Data for KeyValueDatabase { +impl service::rooms::edus::typing::Data for Arc<KeyValueDatabase> { fn typing_add( &self, user_id: &UserId, diff --git a/src/database/key_value/rooms/lazy_load.rs b/src/database/key_value/rooms/lazy_load.rs index 133e1d0..b16657a 100644 --- a/src/database/key_value/rooms/lazy_load.rs +++ b/src/database/key_value/rooms/lazy_load.rs @@ -1,8 +1,10 @@ +use std::sync::Arc; + use ruma::{UserId, DeviceId, RoomId}; use crate::{service, database::KeyValueDatabase, Result}; -impl service::rooms::lazy_loading::Data for KeyValueDatabase { +impl service::rooms::lazy_loading::Data for Arc<KeyValueDatabase> { fn lazy_load_was_sent_before( &self, user_id: &UserId, diff --git a/src/database/key_value/rooms/metadata.rs b/src/database/key_value/rooms/metadata.rs index db2bc69..560beb9 100644 --- a/src/database/key_value/rooms/metadata.rs +++ b/src/database/key_value/rooms/metadata.rs @@ -1,8 +1,10 @@ +use std::sync::Arc; + use ruma::RoomId; use crate::{service, database::KeyValueDatabase, Result, services}; -impl service::rooms::metadata::Data for KeyValueDatabase { +impl service::rooms::metadata::Data for Arc<KeyValueDatabase> { fn exists(&self, room_id: &RoomId) -> Result<bool> { let prefix = match services().rooms.short.get_shortroomid(room_id)? { Some(b) => b.to_be_bytes().to_vec(), diff --git a/src/database/key_value/rooms/mod.rs b/src/database/key_value/rooms/mod.rs index 406943e..97c29e5 100644 --- a/src/database/key_value/rooms/mod.rs +++ b/src/database/key_value/rooms/mod.rs @@ -15,6 +15,8 @@ mod state_compressor; mod timeline; mod user; +use std::sync::Arc; + use crate::{database::KeyValueDatabase, service}; -impl service::rooms::Data for KeyValueDatabase {} +impl service::rooms::Data for Arc<KeyValueDatabase> {} diff --git a/src/database/key_value/rooms/outlier.rs b/src/database/key_value/rooms/outlier.rs index aa97544..b1ae816 100644 --- a/src/database/key_value/rooms/outlier.rs +++ b/src/database/key_value/rooms/outlier.rs @@ -1,8 +1,10 @@ +use std::sync::Arc; + use ruma::{EventId, signatures::CanonicalJsonObject}; use crate::{service, database::KeyValueDatabase, PduEvent, Error, Result}; -impl service::rooms::outlier::Data for KeyValueDatabase { +impl service::rooms::outlier::Data for Arc<KeyValueDatabase> { fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> { self.eventid_outlierpdu .get(event_id.as_bytes())? diff --git a/src/database/key_value/rooms/pdu_metadata.rs b/src/database/key_value/rooms/pdu_metadata.rs index f3ac414..f5e8f76 100644 --- a/src/database/key_value/rooms/pdu_metadata.rs +++ b/src/database/key_value/rooms/pdu_metadata.rs @@ -4,7 +4,7 @@ use ruma::{RoomId, EventId}; use crate::{service, database::KeyValueDatabase, Result}; -impl service::rooms::pdu_metadata::Data for KeyValueDatabase { +impl service::rooms::pdu_metadata::Data for Arc<KeyValueDatabase> { fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> { for prev in event_ids { let mut key = room_id.as_bytes().to_vec(); diff --git a/src/database/key_value/rooms/search.rs b/src/database/key_value/rooms/search.rs index dfbdbc6..7b8d278 100644 --- a/src/database/key_value/rooms/search.rs +++ b/src/database/key_value/rooms/search.rs @@ -1,10 +1,10 @@ -use std::mem::size_of; +use std::{mem::size_of, sync::Arc}; use ruma::RoomId; use crate::{service, database::KeyValueDatabase, utils, Result, services}; -impl service::rooms::search::Data for KeyValueDatabase { +impl service::rooms::search::Data for Arc<KeyValueDatabase> { fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: &[u8], message_body: String) -> Result<()> { let mut batch = message_body .split_terminator(|c: char| !c.is_alphanumeric()) diff --git a/src/database/key_value/rooms/short.rs b/src/database/key_value/rooms/short.rs index 9129638..9a302b5 100644 --- a/src/database/key_value/rooms/short.rs +++ b/src/database/key_value/rooms/short.rs @@ -1,4 +1,6 @@ +use std::sync::Arc; + use crate::{database::KeyValueDatabase, service}; -impl service::rooms::short::Data for KeyValueDatabase { +impl service::rooms::short::Data for Arc<KeyValueDatabase> { } diff --git a/src/database/key_value/rooms/state.rs b/src/database/key_value/rooms/state.rs index 405939d..527c240 100644 --- a/src/database/key_value/rooms/state.rs +++ b/src/database/key_value/rooms/state.rs @@ -1,11 +1,12 @@ use ruma::{RoomId, EventId}; +use tokio::sync::MutexGuard; use std::sync::Arc; -use std::{sync::MutexGuard, collections::HashSet}; +use std::collections::HashSet; use std::fmt::Debug; use crate::{service, database::KeyValueDatabase, utils, Error, Result}; -impl service::rooms::state::Data for KeyValueDatabase { +impl service::rooms::state::Data for Arc<KeyValueDatabase> { fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> { self.roomid_shortstatehash .get(room_id.as_bytes())? @@ -48,7 +49,7 @@ impl service::rooms::state::Data for KeyValueDatabase { fn set_forward_extremities<'a>( &self, room_id: &RoomId, - event_ids: impl IntoIterator<Item = &'a EventId> + Debug, + event_ids: &mut dyn Iterator<Item = &'a EventId>, _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex ) -> Result<()> { let mut prefix = room_id.as_bytes().to_vec(); diff --git a/src/database/key_value/rooms/state_accessor.rs b/src/database/key_value/rooms/state_accessor.rs index 4d5bd4a..9af45db 100644 --- a/src/database/key_value/rooms/state_accessor.rs +++ b/src/database/key_value/rooms/state_accessor.rs @@ -5,7 +5,7 @@ use async_trait::async_trait; use ruma::{EventId, events::StateEventType, RoomId}; #[async_trait] -impl service::rooms::state_accessor::Data for KeyValueDatabase { +impl service::rooms::state_accessor::Data for Arc<KeyValueDatabase> { async fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeMap<u64, Arc<EventId>>> { let full_state = services().rooms.state_compressor .load_shortstatehash_info(shortstatehash)? diff --git a/src/database/key_value/rooms/state_cache.rs b/src/database/key_value/rooms/state_cache.rs index 5f05485..bdb8cf8 100644 --- a/src/database/key_value/rooms/state_cache.rs +++ b/src/database/key_value/rooms/state_cache.rs @@ -1,8 +1,10 @@ +use std::sync::Arc; + use ruma::{UserId, RoomId, events::{AnyStrippedStateEvent, AnySyncStateEvent}, serde::Raw}; use crate::{service, database::KeyValueDatabase, services, Result}; -impl service::rooms::state_cache::Data for KeyValueDatabase { +impl service::rooms::state_cache::Data for Arc<KeyValueDatabase> { fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xff); diff --git a/src/database/key_value/rooms/state_compressor.rs b/src/database/key_value/rooms/state_compressor.rs index aee1890..e1c0280 100644 --- a/src/database/key_value/rooms/state_compressor.rs +++ b/src/database/key_value/rooms/state_compressor.rs @@ -1,8 +1,8 @@ -use std::{collections::HashSet, mem::size_of}; +use std::{collections::HashSet, mem::size_of, sync::Arc}; use crate::{service::{self, rooms::state_compressor::data::StateDiff}, database::KeyValueDatabase, Error, utils, Result}; -impl service::rooms::state_compressor::Data for KeyValueDatabase { +impl service::rooms::state_compressor::Data for Arc<KeyValueDatabase> { fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff> { let value = self .shortstatehash_statediff diff --git a/src/database/key_value/rooms/timeline.rs b/src/database/key_value/rooms/timeline.rs index a3b6c17..2d334b9 100644 --- a/src/database/key_value/rooms/timeline.rs +++ b/src/database/key_value/rooms/timeline.rs @@ -5,7 +5,7 @@ use tracing::error; use crate::{service, database::KeyValueDatabase, utils, Error, PduEvent, Result, services}; -impl service::rooms::timeline::Data for KeyValueDatabase { +impl service::rooms::timeline::Data for Arc<KeyValueDatabase> { fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<u64> { match self .lasttimelinecount_cache diff --git a/src/database/key_value/rooms/user.rs b/src/database/key_value/rooms/user.rs index 66681e3..4d20b00 100644 --- a/src/database/key_value/rooms/user.rs +++ b/src/database/key_value/rooms/user.rs @@ -1,8 +1,10 @@ +use std::sync::Arc; + use ruma::{UserId, RoomId}; use crate::{service, database::KeyValueDatabase, utils, Error, Result, services}; -impl service::rooms::user::Data for KeyValueDatabase { +impl service::rooms::user::Data for Arc<KeyValueDatabase> { fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xff); diff --git a/src/database/key_value/transaction_ids.rs b/src/database/key_value/transaction_ids.rs index a63b3c5..7fa6908 100644 --- a/src/database/key_value/transaction_ids.rs +++ b/src/database/key_value/transaction_ids.rs @@ -1,8 +1,10 @@ +use std::sync::Arc; + use ruma::{UserId, DeviceId, TransactionId}; use crate::{service, database::KeyValueDatabase, Result}; -impl service::transaction_ids::Data for KeyValueDatabase { +impl service::transaction_ids::Data for Arc<KeyValueDatabase> { fn add_txnid( &self, user_id: &UserId, diff --git a/src/database/key_value/uiaa.rs b/src/database/key_value/uiaa.rs index cf242de..8752e55 100644 --- a/src/database/key_value/uiaa.rs +++ b/src/database/key_value/uiaa.rs @@ -1,8 +1,10 @@ +use std::sync::Arc; + use ruma::{UserId, DeviceId, signatures::CanonicalJsonValue, api::client::{uiaa::UiaaInfo, error::ErrorKind}}; use crate::{database::KeyValueDatabase, service, Error, Result}; -impl service::uiaa::Data for KeyValueDatabase { +impl service::uiaa::Data for Arc<KeyValueDatabase> { fn set_uiaa_request( &self, user_id: &UserId, diff --git a/src/database/key_value/users.rs b/src/database/key_value/users.rs index 338d880..1ac85b3 100644 --- a/src/database/key_value/users.rs +++ b/src/database/key_value/users.rs @@ -1,11 +1,11 @@ -use std::{mem::size_of, collections::BTreeMap}; +use std::{mem::size_of, collections::BTreeMap, sync::Arc}; use ruma::{api::client::{filter::IncomingFilterDefinition, error::ErrorKind, device::Device}, UserId, RoomAliasId, MxcUri, DeviceId, MilliSecondsSinceUnixEpoch, DeviceKeyId, encryption::{OneTimeKey, CrossSigningKey, DeviceKeys}, serde::Raw, events::{AnyToDeviceEvent, StateEventType}, DeviceKeyAlgorithm, UInt}; use tracing::warn; use crate::{service::{self, users::clean_signatures}, database::KeyValueDatabase, Error, utils, services, Result}; -impl service::users::Data for KeyValueDatabase { +impl service::users::Data for Arc<KeyValueDatabase> { /// Check if a user has an account on this homeserver. fn exists(&self, user_id: &UserId) -> Result<bool> { Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) @@ -687,10 +687,10 @@ impl service::users::Data for KeyValueDatabase { }) } - fn get_master_key<F: Fn(&UserId) -> bool>( + fn get_master_key( &self, user_id: &UserId, - allowed_signatures: F, + allowed_signatures: &dyn Fn(&UserId) -> bool, ) -> Result<Option<Raw<CrossSigningKey>>> { self.userid_masterkeyid .get(user_id.as_bytes())? @@ -708,10 +708,10 @@ impl service::users::Data for KeyValueDatabase { }) } - fn get_self_signing_key<F: Fn(&UserId) -> bool>( + fn get_self_signing_key( &self, user_id: &UserId, - allowed_signatures: F, + allowed_signatures: &dyn Fn(&UserId) -> bool, ) -> Result<Option<Raw<CrossSigningKey>>> { self.userid_selfsigningkeyid .get(user_id.as_bytes())? diff --git a/src/database/mod.rs b/src/database/mod.rs index aa5c583..35922f0 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -402,10 +402,10 @@ impl KeyValueDatabase { }); - let services_raw = Services::build(Arc::clone(&db)); + let services_raw = Box::new(Services::build(Arc::clone(&db))); // This is the first and only time we initialize the SERVICE static - *SERVICES.write().unwrap() = Some(services_raw); + *SERVICES.write().unwrap() = Some(Box::leak(services_raw)); // Matrix resource ownership is based on the server name; changing it @@ -877,105 +877,6 @@ impl KeyValueDatabase { services().globals.rotate.fire(); } - pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) { - let userid_bytes = user_id.as_bytes().to_vec(); - let mut userid_prefix = userid_bytes.clone(); - userid_prefix.push(0xff); - - let mut userdeviceid_prefix = userid_prefix.clone(); - userdeviceid_prefix.extend_from_slice(device_id.as_bytes()); - userdeviceid_prefix.push(0xff); - - let mut futures = FuturesUnordered::new(); - - // Return when *any* user changed his key - // TODO: only send for user they share a room with - futures.push( - self.todeviceid_events - .watch_prefix(&userdeviceid_prefix), - ); - - futures.push(self.userroomid_joined.watch_prefix(&userid_prefix)); - futures.push( - self.userroomid_invitestate - .watch_prefix(&userid_prefix), - ); - futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix)); - futures.push( - self.userroomid_notificationcount - .watch_prefix(&userid_prefix), - ); - futures.push( - self.userroomid_highlightcount - .watch_prefix(&userid_prefix), - ); - - // Events for rooms we are in - for room_id in services().rooms.state_cache.rooms_joined(user_id).filter_map(|r| r.ok()) { - let short_roomid = services() - .rooms - .short - .get_shortroomid(&room_id) - .ok() - .flatten() - .expect("room exists") - .to_be_bytes() - .to_vec(); - - let roomid_bytes = room_id.as_bytes().to_vec(); - let mut roomid_prefix = roomid_bytes.clone(); - roomid_prefix.push(0xff); - - // PDUs - futures.push(self.pduid_pdu.watch_prefix(&short_roomid)); - - // EDUs - futures.push( - self.roomid_lasttypingupdate - .watch_prefix(&roomid_bytes), - ); - - futures.push( - self.readreceiptid_readreceipt - .watch_prefix(&roomid_prefix), - ); - - // Key changes - futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix)); - - // Room account data - let mut roomuser_prefix = roomid_prefix.clone(); - roomuser_prefix.extend_from_slice(&userid_prefix); - - futures.push( - self.roomusertype_roomuserdataid - .watch_prefix(&roomuser_prefix), - ); - } - - let mut globaluserdata_prefix = vec![0xff]; - globaluserdata_prefix.extend_from_slice(&userid_prefix); - - futures.push( - self.roomusertype_roomuserdataid - .watch_prefix(&globaluserdata_prefix), - ); - - // More key changes (used when user is not joined to any rooms) - futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix)); - - // One time keys - futures.push( - self.userid_lastonetimekeyupdate - .watch_prefix(&userid_bytes), - ); - - futures.push(Box::pin(services().globals.rotate.watch())); - - // Wait until one of them finds something - futures.next().await; - } - #[tracing::instrument(skip(self))] pub fn flush(&self) -> Result<()> { let start = std::time::Instant::now(); @@ -1021,7 +922,7 @@ impl KeyValueDatabase { } let start = Instant::now(); - if let Err(e) = services().globals.db._db.cleanup() { + if let Err(e) = services().globals.cleanup() { error!("cleanup: Errored: {}", e); } else { info!("cleanup: Finished in {:?}", start.elapsed()); @@ -1048,9 +949,9 @@ fn set_emergency_access() -> Result<bool> { None, &conduit_user, GlobalAccountDataEventType::PushRules.to_string().into(), - &GlobalAccountDataEvent { + &serde_json::to_value(&GlobalAccountDataEvent { content: PushRulesEventContent { global: ruleset }, - }, + }).expect("to json value always works"), )?; res |