diff options
author | Timo Kösters <timo@koesters.xyz> | 2022-10-05 18:36:12 +0200 |
---|---|---|
committer | Nyaaori <+@nyaaori.cat> | 2022-10-10 14:02:00 +0200 |
commit | 44fe6d1554eaa0a15314686974ab01f48c836588 (patch) | |
tree | 742d3e844c32acc6fa1c6616d0ff440aff8a6e6c /src/database/key_value | |
parent | cff52d7ebb5066f3d8e513488b84a431c0093e65 (diff) | |
download | conduit-44fe6d1554eaa0a15314686974ab01f48c836588.zip |
127 errors left
Diffstat (limited to 'src/database/key_value')
29 files changed, 296 insertions, 69 deletions
diff --git a/src/database/key_value/account_data.rs b/src/database/key_value/account_data.rs index f0325d2..5674ac0 100644 --- a/src/database/key_value/account_data.rs +++ b/src/database/key_value/account_data.rs @@ -1,11 +1,11 @@ -use std::{collections::HashMap, sync::Arc}; +use std::collections::HashMap; use ruma::{UserId, DeviceId, signatures::CanonicalJsonValue, api::client::{uiaa::UiaaInfo, error::ErrorKind}, events::{RoomAccountDataEventType, AnyEphemeralRoomEvent}, serde::Raw, RoomId}; use serde::{Serialize, de::DeserializeOwned}; use crate::{Result, database::KeyValueDatabase, service, Error, utils, services}; -impl service::account_data::Data for Arc<KeyValueDatabase> { +impl service::account_data::Data for KeyValueDatabase { /// Places one event in the account data of the user and removes the previous entry. #[tracing::instrument(skip(self, room_id, user_id, event_type, data))] fn update( diff --git a/src/database/key_value/appservice.rs b/src/database/key_value/appservice.rs index ee6ae20..f427ba7 100644 --- a/src/database/key_value/appservice.rs +++ b/src/database/key_value/appservice.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use crate::{database::KeyValueDatabase, service, utils, Error, Result}; impl service::appservice::Data for KeyValueDatabase { diff --git a/src/database/key_value/globals.rs b/src/database/key_value/globals.rs index 8711920..199cbf6 100644 --- a/src/database/key_value/globals.rs +++ b/src/database/key_value/globals.rs @@ -1,4 +1,4 @@ -use std::{collections::BTreeMap, sync::Arc}; +use std::collections::BTreeMap; use async_trait::async_trait; use futures_util::{stream::FuturesUnordered, StreamExt}; @@ -9,7 +9,7 @@ use crate::{Result, service, database::KeyValueDatabase, Error, utils, services} pub const COUNTER: &[u8] = b"c"; #[async_trait] -impl service::globals::Data for Arc<KeyValueDatabase> { +impl service::globals::Data for KeyValueDatabase { fn next_count(&self) -> Result<u64> { utils::u64_from_bytes(&self.global.increment(COUNTER)?) .map_err(|_| Error::bad_database("Count has invalid bytes.")) diff --git a/src/database/key_value/key_backups.rs b/src/database/key_value/key_backups.rs index c59ed36..8171451 100644 --- a/src/database/key_value/key_backups.rs +++ b/src/database/key_value/key_backups.rs @@ -1,10 +1,10 @@ -use std::{collections::BTreeMap, sync::Arc}; +use std::collections::BTreeMap; use ruma::{UserId, serde::Raw, api::client::{backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, error::ErrorKind}, RoomId}; use crate::{Result, service, database::KeyValueDatabase, services, Error, utils}; -impl service::key_backups::Data for Arc<KeyValueDatabase> { +impl service::key_backups::Data for KeyValueDatabase { fn create_backup( &self, user_id: &UserId, diff --git a/src/database/key_value/media.rs b/src/database/key_value/media.rs index 1726755..f024487 100644 --- a/src/database/key_value/media.rs +++ b/src/database/key_value/media.rs @@ -1,10 +1,8 @@ -use std::sync::Arc; - use ruma::api::client::error::ErrorKind; use crate::{database::KeyValueDatabase, service, Error, utils, Result}; -impl service::media::Data for Arc<KeyValueDatabase> { +impl service::media::Data for KeyValueDatabase { fn create_file_metadata(&self, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>) -> Result<Vec<u8>> { let mut key = mxc.as_bytes().to_vec(); key.push(0xff); diff --git a/src/database/key_value/pusher.rs b/src/database/key_value/pusher.rs index 85d1d86..b05e47b 100644 --- a/src/database/key_value/pusher.rs +++ b/src/database/key_value/pusher.rs @@ -1,10 +1,8 @@ -use std::sync::Arc; - use ruma::{UserId, api::client::push::{set_pusher, get_pushers}}; use crate::{service, database::KeyValueDatabase, Error, Result}; -impl service::pusher::Data for Arc<KeyValueDatabase> { +impl service::pusher::Data for KeyValueDatabase { fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> { let mut key = sender.as_bytes().to_vec(); key.push(0xff); diff --git a/src/database/key_value/rooms/alias.rs b/src/database/key_value/rooms/alias.rs index 437902d..0aa8dd4 100644 --- a/src/database/key_value/rooms/alias.rs +++ b/src/database/key_value/rooms/alias.rs @@ -1,10 +1,8 @@ -use std::sync::Arc; - use ruma::{RoomId, RoomAliasId, api::client::error::ErrorKind}; use crate::{service, database::KeyValueDatabase, utils, Error, services, Result}; -impl service::rooms::alias::Data for Arc<KeyValueDatabase> { +impl service::rooms::alias::Data for KeyValueDatabase { fn set_alias( &self, alias: &RoomAliasId, diff --git a/src/database/key_value/rooms/auth_chain.rs b/src/database/key_value/rooms/auth_chain.rs index 2dffb04..49d3956 100644 --- a/src/database/key_value/rooms/auth_chain.rs +++ b/src/database/key_value/rooms/auth_chain.rs @@ -2,7 +2,7 @@ use std::{collections::HashSet, mem::size_of, sync::Arc}; use crate::{service, database::KeyValueDatabase, Result, utils}; -impl service::rooms::auth_chain::Data for Arc<KeyValueDatabase> { +impl service::rooms::auth_chain::Data for KeyValueDatabase { fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<HashSet<u64>>>> { // Check RAM cache if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) { diff --git a/src/database/key_value/rooms/directory.rs b/src/database/key_value/rooms/directory.rs index 864e75e..727004e 100644 --- a/src/database/key_value/rooms/directory.rs +++ b/src/database/key_value/rooms/directory.rs @@ -1,10 +1,8 @@ -use std::sync::Arc; - use ruma::RoomId; use crate::{service, database::KeyValueDatabase, utils, Error, Result}; -impl service::rooms::directory::Data for Arc<KeyValueDatabase> { +impl service::rooms::directory::Data for KeyValueDatabase { fn set_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.insert(room_id.as_bytes(), &[]) } diff --git a/src/database/key_value/rooms/edus/mod.rs b/src/database/key_value/rooms/edus/mod.rs index 03e4219..b5007f8 100644 --- a/src/database/key_value/rooms/edus/mod.rs +++ b/src/database/key_value/rooms/edus/mod.rs @@ -2,8 +2,6 @@ mod presence; mod typing; mod read_receipt; -use std::sync::Arc; - use crate::{service, database::KeyValueDatabase}; -impl service::rooms::edus::Data for Arc<KeyValueDatabase> {} +impl service::rooms::edus::Data for KeyValueDatabase {} diff --git a/src/database/key_value/rooms/edus/presence.rs b/src/database/key_value/rooms/edus/presence.rs index 5aeb147..1477c28 100644 --- a/src/database/key_value/rooms/edus/presence.rs +++ b/src/database/key_value/rooms/edus/presence.rs @@ -1,10 +1,10 @@ -use std::{collections::HashMap, sync::Arc}; +use std::collections::HashMap; use ruma::{UserId, RoomId, events::presence::PresenceEvent, presence::PresenceState, UInt}; use crate::{service, database::KeyValueDatabase, utils, Error, services, Result}; -impl service::rooms::edus::presence::Data for Arc<KeyValueDatabase> { +impl service::rooms::edus::presence::Data for KeyValueDatabase { fn update_presence( &self, user_id: &UserId, diff --git a/src/database/key_value/rooms/edus/read_receipt.rs b/src/database/key_value/rooms/edus/read_receipt.rs index 7fcb8ac..a12e265 100644 --- a/src/database/key_value/rooms/edus/read_receipt.rs +++ b/src/database/key_value/rooms/edus/read_receipt.rs @@ -1,10 +1,10 @@ -use std::{mem, sync::Arc}; +use std::mem; use ruma::{UserId, RoomId, events::receipt::ReceiptEvent, serde::Raw, signatures::CanonicalJsonObject}; use crate::{database::KeyValueDatabase, service, utils, Error, services, Result}; -impl service::rooms::edus::read_receipt::Data for Arc<KeyValueDatabase> { +impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { fn readreceipt_update( &self, user_id: &UserId, diff --git a/src/database/key_value/rooms/edus/typing.rs b/src/database/key_value/rooms/edus/typing.rs index 7f3526d..b7d3596 100644 --- a/src/database/key_value/rooms/edus/typing.rs +++ b/src/database/key_value/rooms/edus/typing.rs @@ -1,10 +1,10 @@ -use std::{collections::HashSet, sync::Arc}; +use std::collections::HashSet; use ruma::{UserId, RoomId}; use crate::{database::KeyValueDatabase, service, utils, Error, services, Result}; -impl service::rooms::edus::typing::Data for Arc<KeyValueDatabase> { +impl service::rooms::edus::typing::Data for KeyValueDatabase { fn typing_add( &self, user_id: &UserId, diff --git a/src/database/key_value/rooms/lazy_load.rs b/src/database/key_value/rooms/lazy_load.rs index b16657a..133e1d0 100644 --- a/src/database/key_value/rooms/lazy_load.rs +++ b/src/database/key_value/rooms/lazy_load.rs @@ -1,10 +1,8 @@ -use std::sync::Arc; - use ruma::{UserId, DeviceId, RoomId}; use crate::{service, database::KeyValueDatabase, Result}; -impl service::rooms::lazy_loading::Data for Arc<KeyValueDatabase> { +impl service::rooms::lazy_loading::Data for KeyValueDatabase { fn lazy_load_was_sent_before( &self, user_id: &UserId, diff --git a/src/database/key_value/rooms/metadata.rs b/src/database/key_value/rooms/metadata.rs index 560beb9..72f6251 100644 --- a/src/database/key_value/rooms/metadata.rs +++ b/src/database/key_value/rooms/metadata.rs @@ -1,10 +1,8 @@ -use std::sync::Arc; - use ruma::RoomId; use crate::{service, database::KeyValueDatabase, Result, services}; -impl service::rooms::metadata::Data for Arc<KeyValueDatabase> { +impl service::rooms::metadata::Data for KeyValueDatabase { fn exists(&self, room_id: &RoomId) -> Result<bool> { let prefix = match services().rooms.short.get_shortroomid(room_id)? { Some(b) => b.to_be_bytes().to_vec(), @@ -19,4 +17,18 @@ impl service::rooms::metadata::Data for Arc<KeyValueDatabase> { .filter(|(k, _)| k.starts_with(&prefix)) .is_some()) } + + fn is_disabled(&self, room_id: &RoomId) -> Result<bool> { + Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some()) + } + + fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> { + if disabled { + self.disabledroomids.insert(room_id.as_bytes(), &[])?; + } else { + self.disabledroomids.remove(room_id.as_bytes())?; + } + + Ok(()) + } } diff --git a/src/database/key_value/rooms/mod.rs b/src/database/key_value/rooms/mod.rs index 97c29e5..406943e 100644 --- a/src/database/key_value/rooms/mod.rs +++ b/src/database/key_value/rooms/mod.rs @@ -15,8 +15,6 @@ mod state_compressor; mod timeline; mod user; -use std::sync::Arc; - use crate::{database::KeyValueDatabase, service}; -impl service::rooms::Data for Arc<KeyValueDatabase> {} +impl service::rooms::Data for KeyValueDatabase {} diff --git a/src/database/key_value/rooms/outlier.rs b/src/database/key_value/rooms/outlier.rs index b1ae816..aa97544 100644 --- a/src/database/key_value/rooms/outlier.rs +++ b/src/database/key_value/rooms/outlier.rs @@ -1,10 +1,8 @@ -use std::sync::Arc; - use ruma::{EventId, signatures::CanonicalJsonObject}; use crate::{service, database::KeyValueDatabase, PduEvent, Error, Result}; -impl service::rooms::outlier::Data for Arc<KeyValueDatabase> { +impl service::rooms::outlier::Data for KeyValueDatabase { fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> { self.eventid_outlierpdu .get(event_id.as_bytes())? diff --git a/src/database/key_value/rooms/pdu_metadata.rs b/src/database/key_value/rooms/pdu_metadata.rs index f5e8f76..f3ac414 100644 --- a/src/database/key_value/rooms/pdu_metadata.rs +++ b/src/database/key_value/rooms/pdu_metadata.rs @@ -4,7 +4,7 @@ use ruma::{RoomId, EventId}; use crate::{service, database::KeyValueDatabase, Result}; -impl service::rooms::pdu_metadata::Data for Arc<KeyValueDatabase> { +impl service::rooms::pdu_metadata::Data for KeyValueDatabase { fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> { for prev in event_ids { let mut key = room_id.as_bytes().to_vec(); diff --git a/src/database/key_value/rooms/search.rs b/src/database/key_value/rooms/search.rs index 7b8d278..dfbdbc6 100644 --- a/src/database/key_value/rooms/search.rs +++ b/src/database/key_value/rooms/search.rs @@ -1,10 +1,10 @@ -use std::{mem::size_of, sync::Arc}; +use std::mem::size_of; use ruma::RoomId; use crate::{service, database::KeyValueDatabase, utils, Result, services}; -impl service::rooms::search::Data for Arc<KeyValueDatabase> { +impl service::rooms::search::Data for KeyValueDatabase { fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: &[u8], message_body: String) -> Result<()> { let mut batch = message_body .split_terminator(|c: char| !c.is_alphanumeric()) diff --git a/src/database/key_value/rooms/short.rs b/src/database/key_value/rooms/short.rs index 9a302b5..ecd12da 100644 --- a/src/database/key_value/rooms/short.rs +++ b/src/database/key_value/rooms/short.rs @@ -1,6 +1,227 @@ use std::sync::Arc; -use crate::{database::KeyValueDatabase, service}; +use ruma::{EventId, events::StateEventType, RoomId}; -impl service::rooms::short::Data for Arc<KeyValueDatabase> { +use crate::{Result, database::KeyValueDatabase, service, utils, Error, services}; + +impl service::rooms::short::Data for KeyValueDatabase { + fn get_or_create_shorteventid( + &self, + event_id: &EventId, + ) -> Result<u64> { + if let Some(short) = self.eventidshort_cache.lock().unwrap().get_mut(event_id) { + return Ok(*short); + } + + let short = match self.eventid_shorteventid.get(event_id.as_bytes())? { + Some(shorteventid) => utils::u64_from_bytes(&shorteventid) + .map_err(|_| Error::bad_database("Invalid shorteventid in db."))?, + None => { + let shorteventid = services().globals.next_count()?; + self.eventid_shorteventid + .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; + self.shorteventid_eventid + .insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?; + shorteventid + } + }; + + self.eventidshort_cache + .lock() + .unwrap() + .insert(event_id.to_owned(), short); + + Ok(short) + } + + fn get_shortstatekey( + &self, + event_type: &StateEventType, + state_key: &str, + ) -> Result<Option<u64>> { + if let Some(short) = self + .statekeyshort_cache + .lock() + .unwrap() + .get_mut(&(event_type.clone(), state_key.to_owned())) + { + return Ok(Some(*short)); + } + + let mut statekey = event_type.to_string().as_bytes().to_vec(); + statekey.push(0xff); + statekey.extend_from_slice(state_key.as_bytes()); + + let short = self + .statekey_shortstatekey + .get(&statekey)? + .map(|shortstatekey| { + utils::u64_from_bytes(&shortstatekey) + .map_err(|_| Error::bad_database("Invalid shortstatekey in db.")) + }) + .transpose()?; + + if let Some(s) = short { + self.statekeyshort_cache + .lock() + .unwrap() + .insert((event_type.clone(), state_key.to_owned()), s); + } + + Ok(short) + } + + fn get_or_create_shortstatekey( + &self, + event_type: &StateEventType, + state_key: &str, + ) -> Result<u64> { + if let Some(short) = self + .statekeyshort_cache + .lock() + .unwrap() + .get_mut(&(event_type.clone(), state_key.to_owned())) + { + return Ok(*short); + } + + let mut statekey = event_type.to_string().as_bytes().to_vec(); + statekey.push(0xff); + statekey.extend_from_slice(state_key.as_bytes()); + + let short = match self.statekey_shortstatekey.get(&statekey)? { + Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey) + .map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?, + None => { + let shortstatekey = services().globals.next_count()?; + self.statekey_shortstatekey + .insert(&statekey, &shortstatekey.to_be_bytes())?; + self.shortstatekey_statekey + .insert(&shortstatekey.to_be_bytes(), &statekey)?; + shortstatekey + } + }; + + self.statekeyshort_cache + .lock() + .unwrap() + .insert((event_type.clone(), state_key.to_owned()), short); + + Ok(short) + } + + fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> { + if let Some(id) = self + .shorteventid_cache + .lock() + .unwrap() + .get_mut(&shorteventid) + { + return Ok(Arc::clone(id)); + } + + let bytes = self + .shorteventid_eventid + .get(&shorteventid.to_be_bytes())? + .ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?; + + let event_id = EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("EventID in shorteventid_eventid is invalid unicode.") + })?) + .map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?; + + self.shorteventid_cache + .lock() + .unwrap() + .insert(shorteventid, Arc::clone(&event_id)); + + Ok(event_id) + } + + fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { + if let Some(id) = self + .shortstatekey_cache + .lock() + .unwrap() + .get_mut(&shortstatekey) + { + return Ok(id.clone()); + } + + let bytes = self + .shortstatekey_statekey + .get(&shortstatekey.to_be_bytes())? + .ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?; + + let mut parts = bytes.splitn(2, |&b| b == 0xff); + let eventtype_bytes = parts.next().expect("split always returns one entry"); + let statekey_bytes = parts + .next() + .ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?; + + let event_type = + StateEventType::try_from(utils::string_from_bytes(eventtype_bytes).map_err(|_| { + Error::bad_database("Event type in shortstatekey_statekey is invalid unicode.") + })?) + .map_err(|_| Error::bad_database("Event type in shortstatekey_statekey is invalid."))?; + + let state_key = utils::string_from_bytes(statekey_bytes).map_err(|_| { + Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode.") + })?; + + let result = (event_type, state_key); + + self.shortstatekey_cache + .lock() + .unwrap() + .insert(shortstatekey, result.clone()); + + Ok(result) + } + + /// Returns (shortstatehash, already_existed) + fn get_or_create_shortstatehash( + &self, + state_hash: &[u8], + ) -> Result<(u64, bool)> { + Ok(match self.statehash_shortstatehash.get(state_hash)? { + Some(shortstatehash) => ( + utils::u64_from_bytes(&shortstatehash) + .map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?, + true, + ), + None => { + let shortstatehash = services().globals.next_count()?; + self.statehash_shortstatehash + .insert(state_hash, &shortstatehash.to_be_bytes())?; + (shortstatehash, false) + } + }) + } + + fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> { + self.roomid_shortroomid + .get(room_id.as_bytes())? + .map(|bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid shortroomid in db.")) + }) + .transpose() + } + + fn get_or_create_shortroomid( + &self, + room_id: &RoomId, + ) -> Result<u64> { + Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? { + Some(short) => utils::u64_from_bytes(&short) + .map_err(|_| Error::bad_database("Invalid shortroomid in db."))?, + None => { + let short = services().globals.next_count()?; + self.roomid_shortroomid + .insert(room_id.as_bytes(), &short.to_be_bytes())?; + short + } + }) + } } diff --git a/src/database/key_value/rooms/state.rs b/src/database/key_value/rooms/state.rs index 527c240..b2822b3 100644 --- a/src/database/key_value/rooms/state.rs +++ b/src/database/key_value/rooms/state.rs @@ -6,7 +6,7 @@ use std::fmt::Debug; use crate::{service, database::KeyValueDatabase, utils, Error, Result}; -impl service::rooms::state::Data for Arc<KeyValueDatabase> { +impl service::rooms::state::Data for KeyValueDatabase { fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> { self.roomid_shortstatehash .get(room_id.as_bytes())? diff --git a/src/database/key_value/rooms/state_accessor.rs b/src/database/key_value/rooms/state_accessor.rs index 9af45db..4d5bd4a 100644 --- a/src/database/key_value/rooms/state_accessor.rs +++ b/src/database/key_value/rooms/state_accessor.rs @@ -5,7 +5,7 @@ use async_trait::async_trait; use ruma::{EventId, events::StateEventType, RoomId}; #[async_trait] -impl service::rooms::state_accessor::Data for Arc<KeyValueDatabase> { +impl service::rooms::state_accessor::Data for KeyValueDatabase { async fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeMap<u64, Arc<EventId>>> { let full_state = services().rooms.state_compressor .load_shortstatehash_info(shortstatehash)? diff --git a/src/database/key_value/rooms/state_cache.rs b/src/database/key_value/rooms/state_cache.rs index bdb8cf8..5f05485 100644 --- a/src/database/key_value/rooms/state_cache.rs +++ b/src/database/key_value/rooms/state_cache.rs @@ -1,10 +1,8 @@ -use std::sync::Arc; - use ruma::{UserId, RoomId, events::{AnyStrippedStateEvent, AnySyncStateEvent}, serde::Raw}; use crate::{service, database::KeyValueDatabase, services, Result}; -impl service::rooms::state_cache::Data for Arc<KeyValueDatabase> { +impl service::rooms::state_cache::Data for KeyValueDatabase { fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xff); diff --git a/src/database/key_value/rooms/state_compressor.rs b/src/database/key_value/rooms/state_compressor.rs index e1c0280..aee1890 100644 --- a/src/database/key_value/rooms/state_compressor.rs +++ b/src/database/key_value/rooms/state_compressor.rs @@ -1,8 +1,8 @@ -use std::{collections::HashSet, mem::size_of, sync::Arc}; +use std::{collections::HashSet, mem::size_of}; use crate::{service::{self, rooms::state_compressor::data::StateDiff}, database::KeyValueDatabase, Error, utils, Result}; -impl service::rooms::state_compressor::Data for Arc<KeyValueDatabase> { +impl service::rooms::state_compressor::Data for KeyValueDatabase { fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff> { let value = self .shortstatehash_statediff diff --git a/src/database/key_value/rooms/timeline.rs b/src/database/key_value/rooms/timeline.rs index 2d334b9..0b7286b 100644 --- a/src/database/key_value/rooms/timeline.rs +++ b/src/database/key_value/rooms/timeline.rs @@ -5,7 +5,27 @@ use tracing::error; use crate::{service, database::KeyValueDatabase, utils, Error, PduEvent, Result, services}; -impl service::rooms::timeline::Data for Arc<KeyValueDatabase> { +impl service::rooms::timeline::Data for KeyValueDatabase { + fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent>>> { + let prefix = services().rooms.short + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); + + // Look for PDUs in that room. + self.pduid_pdu + .iter_from(&prefix, false) + .filter(|(k, _)| k.starts_with(&prefix)) + .map(|(_, pdu)| { + serde_json::from_slice(&pdu) + .map_err(|_| Error::bad_database("Invalid first PDU in db.")) + .map(Arc::new) + }) + .next() + .transpose() + } + fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<u64> { match self .lasttimelinecount_cache diff --git a/src/database/key_value/rooms/user.rs b/src/database/key_value/rooms/user.rs index 4d20b00..3759bda 100644 --- a/src/database/key_value/rooms/user.rs +++ b/src/database/key_value/rooms/user.rs @@ -1,10 +1,8 @@ -use std::sync::Arc; - use ruma::{UserId, RoomId}; use crate::{service, database::KeyValueDatabase, utils, Error, Result, services}; -impl service::rooms::user::Data for Arc<KeyValueDatabase> { +impl service::rooms::user::Data for KeyValueDatabase { fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { let mut userroom_id = user_id.as_bytes().to_vec(); userroom_id.push(0xff); @@ -104,13 +102,13 @@ impl service::rooms::user::Data for Arc<KeyValueDatabase> { }); // We use the default compare function because keys are sorted correctly (not reversed) - Ok(utils::common_elements(iterators, Ord::cmp) + Ok(Box::new(Box::new(utils::common_elements(iterators, Ord::cmp) .expect("users is not empty") .map(|bytes| { RoomId::parse(utils::string_from_bytes(&*bytes).map_err(|_| { Error::bad_database("Invalid RoomId bytes in userroomid_joined") })?) .map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) - })) + })))) } } diff --git a/src/database/key_value/transaction_ids.rs b/src/database/key_value/transaction_ids.rs index 7fa6908..a63b3c5 100644 --- a/src/database/key_value/transaction_ids.rs +++ b/src/database/key_value/transaction_ids.rs @@ -1,10 +1,8 @@ -use std::sync::Arc; - use ruma::{UserId, DeviceId, TransactionId}; use crate::{service, database::KeyValueDatabase, Result}; -impl service::transaction_ids::Data for Arc<KeyValueDatabase> { +impl service::transaction_ids::Data for KeyValueDatabase { fn add_txnid( &self, user_id: &UserId, diff --git a/src/database/key_value/uiaa.rs b/src/database/key_value/uiaa.rs index 8752e55..cf242de 100644 --- a/src/database/key_value/uiaa.rs +++ b/src/database/key_value/uiaa.rs @@ -1,10 +1,8 @@ -use std::sync::Arc; - use ruma::{UserId, DeviceId, signatures::CanonicalJsonValue, api::client::{uiaa::UiaaInfo, error::ErrorKind}}; use crate::{database::KeyValueDatabase, service, Error, Result}; -impl service::uiaa::Data for Arc<KeyValueDatabase> { +impl service::uiaa::Data for KeyValueDatabase { fn set_uiaa_request( &self, user_id: &UserId, diff --git a/src/database/key_value/users.rs b/src/database/key_value/users.rs index 1ac85b3..55a518d 100644 --- a/src/database/key_value/users.rs +++ b/src/database/key_value/users.rs @@ -1,11 +1,11 @@ -use std::{mem::size_of, collections::BTreeMap, sync::Arc}; +use std::{mem::size_of, collections::BTreeMap}; use ruma::{api::client::{filter::IncomingFilterDefinition, error::ErrorKind, device::Device}, UserId, RoomAliasId, MxcUri, DeviceId, MilliSecondsSinceUnixEpoch, DeviceKeyId, encryption::{OneTimeKey, CrossSigningKey, DeviceKeys}, serde::Raw, events::{AnyToDeviceEvent, StateEventType}, DeviceKeyAlgorithm, UInt}; use tracing::warn; use crate::{service::{self, users::clean_signatures}, database::KeyValueDatabase, Error, utils, services, Result}; -impl service::users::Data for Arc<KeyValueDatabase> { +impl service::users::Data for KeyValueDatabase { /// Check if a user has an account on this homeserver. fn exists(&self, user_id: &UserId) -> Result<bool> { Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) @@ -113,7 +113,7 @@ impl service::users::Data for Arc<KeyValueDatabase> { /// Hash and set the user's password to the Argon2 hash fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { if let Some(password) = password { - if let Ok(hash) = utils::calculate_hash(password) { + if let Ok(hash) = utils::calculate_password_hash(password) { self.userid_password .insert(user_id.as_bytes(), hash.as_bytes())?; Ok(()) |