summaryrefslogtreecommitdiff
path: root/src/database
diff options
context:
space:
mode:
authorTimo Kösters <timo@koesters.xyz>2022-10-05 18:36:12 +0200
committerNyaaori <+@nyaaori.cat>2022-10-10 14:02:00 +0200
commit44fe6d1554eaa0a15314686974ab01f48c836588 (patch)
tree742d3e844c32acc6fa1c6616d0ff440aff8a6e6c /src/database
parentcff52d7ebb5066f3d8e513488b84a431c0093e65 (diff)
downloadconduit-44fe6d1554eaa0a15314686974ab01f48c836588.zip
127 errors left
Diffstat (limited to 'src/database')
-rw-r--r--src/database/key_value/account_data.rs4
-rw-r--r--src/database/key_value/appservice.rs2
-rw-r--r--src/database/key_value/globals.rs4
-rw-r--r--src/database/key_value/key_backups.rs4
-rw-r--r--src/database/key_value/media.rs4
-rw-r--r--src/database/key_value/pusher.rs4
-rw-r--r--src/database/key_value/rooms/alias.rs4
-rw-r--r--src/database/key_value/rooms/auth_chain.rs2
-rw-r--r--src/database/key_value/rooms/directory.rs4
-rw-r--r--src/database/key_value/rooms/edus/mod.rs4
-rw-r--r--src/database/key_value/rooms/edus/presence.rs4
-rw-r--r--src/database/key_value/rooms/edus/read_receipt.rs4
-rw-r--r--src/database/key_value/rooms/edus/typing.rs4
-rw-r--r--src/database/key_value/rooms/lazy_load.rs4
-rw-r--r--src/database/key_value/rooms/metadata.rs18
-rw-r--r--src/database/key_value/rooms/mod.rs4
-rw-r--r--src/database/key_value/rooms/outlier.rs4
-rw-r--r--src/database/key_value/rooms/pdu_metadata.rs2
-rw-r--r--src/database/key_value/rooms/search.rs4
-rw-r--r--src/database/key_value/rooms/short.rs225
-rw-r--r--src/database/key_value/rooms/state.rs2
-rw-r--r--src/database/key_value/rooms/state_accessor.rs2
-rw-r--r--src/database/key_value/rooms/state_cache.rs4
-rw-r--r--src/database/key_value/rooms/state_compressor.rs4
-rw-r--r--src/database/key_value/rooms/timeline.rs22
-rw-r--r--src/database/key_value/rooms/user.rs8
-rw-r--r--src/database/key_value/transaction_ids.rs4
-rw-r--r--src/database/key_value/uiaa.rs4
-rw-r--r--src/database/key_value/users.rs6
-rw-r--r--src/database/mod.rs24
30 files changed, 308 insertions, 81 deletions
diff --git a/src/database/key_value/account_data.rs b/src/database/key_value/account_data.rs
index f0325d2..5674ac0 100644
--- a/src/database/key_value/account_data.rs
+++ b/src/database/key_value/account_data.rs
@@ -1,11 +1,11 @@
-use std::{collections::HashMap, sync::Arc};
+use std::collections::HashMap;
use ruma::{UserId, DeviceId, signatures::CanonicalJsonValue, api::client::{uiaa::UiaaInfo, error::ErrorKind}, events::{RoomAccountDataEventType, AnyEphemeralRoomEvent}, serde::Raw, RoomId};
use serde::{Serialize, de::DeserializeOwned};
use crate::{Result, database::KeyValueDatabase, service, Error, utils, services};
-impl service::account_data::Data for Arc<KeyValueDatabase> {
+impl service::account_data::Data for KeyValueDatabase {
/// Places one event in the account data of the user and removes the previous entry.
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
fn update(
diff --git a/src/database/key_value/appservice.rs b/src/database/key_value/appservice.rs
index ee6ae20..f427ba7 100644
--- a/src/database/key_value/appservice.rs
+++ b/src/database/key_value/appservice.rs
@@ -1,5 +1,3 @@
-use std::sync::Arc;
-
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
impl service::appservice::Data for KeyValueDatabase {
diff --git a/src/database/key_value/globals.rs b/src/database/key_value/globals.rs
index 8711920..199cbf6 100644
--- a/src/database/key_value/globals.rs
+++ b/src/database/key_value/globals.rs
@@ -1,4 +1,4 @@
-use std::{collections::BTreeMap, sync::Arc};
+use std::collections::BTreeMap;
use async_trait::async_trait;
use futures_util::{stream::FuturesUnordered, StreamExt};
@@ -9,7 +9,7 @@ use crate::{Result, service, database::KeyValueDatabase, Error, utils, services}
pub const COUNTER: &[u8] = b"c";
#[async_trait]
-impl service::globals::Data for Arc<KeyValueDatabase> {
+impl service::globals::Data for KeyValueDatabase {
fn next_count(&self) -> Result<u64> {
utils::u64_from_bytes(&self.global.increment(COUNTER)?)
.map_err(|_| Error::bad_database("Count has invalid bytes."))
diff --git a/src/database/key_value/key_backups.rs b/src/database/key_value/key_backups.rs
index c59ed36..8171451 100644
--- a/src/database/key_value/key_backups.rs
+++ b/src/database/key_value/key_backups.rs
@@ -1,10 +1,10 @@
-use std::{collections::BTreeMap, sync::Arc};
+use std::collections::BTreeMap;
use ruma::{UserId, serde::Raw, api::client::{backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, error::ErrorKind}, RoomId};
use crate::{Result, service, database::KeyValueDatabase, services, Error, utils};
-impl service::key_backups::Data for Arc<KeyValueDatabase> {
+impl service::key_backups::Data for KeyValueDatabase {
fn create_backup(
&self,
user_id: &UserId,
diff --git a/src/database/key_value/media.rs b/src/database/key_value/media.rs
index 1726755..f024487 100644
--- a/src/database/key_value/media.rs
+++ b/src/database/key_value/media.rs
@@ -1,10 +1,8 @@
-use std::sync::Arc;
-
use ruma::api::client::error::ErrorKind;
use crate::{database::KeyValueDatabase, service, Error, utils, Result};
-impl service::media::Data for Arc<KeyValueDatabase> {
+impl service::media::Data for KeyValueDatabase {
fn create_file_metadata(&self, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>) -> Result<Vec<u8>> {
let mut key = mxc.as_bytes().to_vec();
key.push(0xff);
diff --git a/src/database/key_value/pusher.rs b/src/database/key_value/pusher.rs
index 85d1d86..b05e47b 100644
--- a/src/database/key_value/pusher.rs
+++ b/src/database/key_value/pusher.rs
@@ -1,10 +1,8 @@
-use std::sync::Arc;
-
use ruma::{UserId, api::client::push::{set_pusher, get_pushers}};
use crate::{service, database::KeyValueDatabase, Error, Result};
-impl service::pusher::Data for Arc<KeyValueDatabase> {
+impl service::pusher::Data for KeyValueDatabase {
fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> {
let mut key = sender.as_bytes().to_vec();
key.push(0xff);
diff --git a/src/database/key_value/rooms/alias.rs b/src/database/key_value/rooms/alias.rs
index 437902d..0aa8dd4 100644
--- a/src/database/key_value/rooms/alias.rs
+++ b/src/database/key_value/rooms/alias.rs
@@ -1,10 +1,8 @@
-use std::sync::Arc;
-
use ruma::{RoomId, RoomAliasId, api::client::error::ErrorKind};
use crate::{service, database::KeyValueDatabase, utils, Error, services, Result};
-impl service::rooms::alias::Data for Arc<KeyValueDatabase> {
+impl service::rooms::alias::Data for KeyValueDatabase {
fn set_alias(
&self,
alias: &RoomAliasId,
diff --git a/src/database/key_value/rooms/auth_chain.rs b/src/database/key_value/rooms/auth_chain.rs
index 2dffb04..49d3956 100644
--- a/src/database/key_value/rooms/auth_chain.rs
+++ b/src/database/key_value/rooms/auth_chain.rs
@@ -2,7 +2,7 @@ use std::{collections::HashSet, mem::size_of, sync::Arc};
use crate::{service, database::KeyValueDatabase, Result, utils};
-impl service::rooms::auth_chain::Data for Arc<KeyValueDatabase> {
+impl service::rooms::auth_chain::Data for KeyValueDatabase {
fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<HashSet<u64>>>> {
// Check RAM cache
if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) {
diff --git a/src/database/key_value/rooms/directory.rs b/src/database/key_value/rooms/directory.rs
index 864e75e..727004e 100644
--- a/src/database/key_value/rooms/directory.rs
+++ b/src/database/key_value/rooms/directory.rs
@@ -1,10 +1,8 @@
-use std::sync::Arc;
-
use ruma::RoomId;
use crate::{service, database::KeyValueDatabase, utils, Error, Result};
-impl service::rooms::directory::Data for Arc<KeyValueDatabase> {
+impl service::rooms::directory::Data for KeyValueDatabase {
fn set_public(&self, room_id: &RoomId) -> Result<()> {
self.publicroomids.insert(room_id.as_bytes(), &[])
}
diff --git a/src/database/key_value/rooms/edus/mod.rs b/src/database/key_value/rooms/edus/mod.rs
index 03e4219..b5007f8 100644
--- a/src/database/key_value/rooms/edus/mod.rs
+++ b/src/database/key_value/rooms/edus/mod.rs
@@ -2,8 +2,6 @@ mod presence;
mod typing;
mod read_receipt;
-use std::sync::Arc;
-
use crate::{service, database::KeyValueDatabase};
-impl service::rooms::edus::Data for Arc<KeyValueDatabase> {}
+impl service::rooms::edus::Data for KeyValueDatabase {}
diff --git a/src/database/key_value/rooms/edus/presence.rs b/src/database/key_value/rooms/edus/presence.rs
index 5aeb147..1477c28 100644
--- a/src/database/key_value/rooms/edus/presence.rs
+++ b/src/database/key_value/rooms/edus/presence.rs
@@ -1,10 +1,10 @@
-use std::{collections::HashMap, sync::Arc};
+use std::collections::HashMap;
use ruma::{UserId, RoomId, events::presence::PresenceEvent, presence::PresenceState, UInt};
use crate::{service, database::KeyValueDatabase, utils, Error, services, Result};
-impl service::rooms::edus::presence::Data for Arc<KeyValueDatabase> {
+impl service::rooms::edus::presence::Data for KeyValueDatabase {
fn update_presence(
&self,
user_id: &UserId,
diff --git a/src/database/key_value/rooms/edus/read_receipt.rs b/src/database/key_value/rooms/edus/read_receipt.rs
index 7fcb8ac..a12e265 100644
--- a/src/database/key_value/rooms/edus/read_receipt.rs
+++ b/src/database/key_value/rooms/edus/read_receipt.rs
@@ -1,10 +1,10 @@
-use std::{mem, sync::Arc};
+use std::mem;
use ruma::{UserId, RoomId, events::receipt::ReceiptEvent, serde::Raw, signatures::CanonicalJsonObject};
use crate::{database::KeyValueDatabase, service, utils, Error, services, Result};
-impl service::rooms::edus::read_receipt::Data for Arc<KeyValueDatabase> {
+impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
fn readreceipt_update(
&self,
user_id: &UserId,
diff --git a/src/database/key_value/rooms/edus/typing.rs b/src/database/key_value/rooms/edus/typing.rs
index 7f3526d..b7d3596 100644
--- a/src/database/key_value/rooms/edus/typing.rs
+++ b/src/database/key_value/rooms/edus/typing.rs
@@ -1,10 +1,10 @@
-use std::{collections::HashSet, sync::Arc};
+use std::collections::HashSet;
use ruma::{UserId, RoomId};
use crate::{database::KeyValueDatabase, service, utils, Error, services, Result};
-impl service::rooms::edus::typing::Data for Arc<KeyValueDatabase> {
+impl service::rooms::edus::typing::Data for KeyValueDatabase {
fn typing_add(
&self,
user_id: &UserId,
diff --git a/src/database/key_value/rooms/lazy_load.rs b/src/database/key_value/rooms/lazy_load.rs
index b16657a..133e1d0 100644
--- a/src/database/key_value/rooms/lazy_load.rs
+++ b/src/database/key_value/rooms/lazy_load.rs
@@ -1,10 +1,8 @@
-use std::sync::Arc;
-
use ruma::{UserId, DeviceId, RoomId};
use crate::{service, database::KeyValueDatabase, Result};
-impl service::rooms::lazy_loading::Data for Arc<KeyValueDatabase> {
+impl service::rooms::lazy_loading::Data for KeyValueDatabase {
fn lazy_load_was_sent_before(
&self,
user_id: &UserId,
diff --git a/src/database/key_value/rooms/metadata.rs b/src/database/key_value/rooms/metadata.rs
index 560beb9..72f6251 100644
--- a/src/database/key_value/rooms/metadata.rs
+++ b/src/database/key_value/rooms/metadata.rs
@@ -1,10 +1,8 @@
-use std::sync::Arc;
-
use ruma::RoomId;
use crate::{service, database::KeyValueDatabase, Result, services};
-impl service::rooms::metadata::Data for Arc<KeyValueDatabase> {
+impl service::rooms::metadata::Data for KeyValueDatabase {
fn exists(&self, room_id: &RoomId) -> Result<bool> {
let prefix = match services().rooms.short.get_shortroomid(room_id)? {
Some(b) => b.to_be_bytes().to_vec(),
@@ -19,4 +17,18 @@ impl service::rooms::metadata::Data for Arc<KeyValueDatabase> {
.filter(|(k, _)| k.starts_with(&prefix))
.is_some())
}
+
+ fn is_disabled(&self, room_id: &RoomId) -> Result<bool> {
+ Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some())
+ }
+
+ fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> {
+ if disabled {
+ self.disabledroomids.insert(room_id.as_bytes(), &[])?;
+ } else {
+ self.disabledroomids.remove(room_id.as_bytes())?;
+ }
+
+ Ok(())
+ }
}
diff --git a/src/database/key_value/rooms/mod.rs b/src/database/key_value/rooms/mod.rs
index 97c29e5..406943e 100644
--- a/src/database/key_value/rooms/mod.rs
+++ b/src/database/key_value/rooms/mod.rs
@@ -15,8 +15,6 @@ mod state_compressor;
mod timeline;
mod user;
-use std::sync::Arc;
-
use crate::{database::KeyValueDatabase, service};
-impl service::rooms::Data for Arc<KeyValueDatabase> {}
+impl service::rooms::Data for KeyValueDatabase {}
diff --git a/src/database/key_value/rooms/outlier.rs b/src/database/key_value/rooms/outlier.rs
index b1ae816..aa97544 100644
--- a/src/database/key_value/rooms/outlier.rs
+++ b/src/database/key_value/rooms/outlier.rs
@@ -1,10 +1,8 @@
-use std::sync::Arc;
-
use ruma::{EventId, signatures::CanonicalJsonObject};
use crate::{service, database::KeyValueDatabase, PduEvent, Error, Result};
-impl service::rooms::outlier::Data for Arc<KeyValueDatabase> {
+impl service::rooms::outlier::Data for KeyValueDatabase {
fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
diff --git a/src/database/key_value/rooms/pdu_metadata.rs b/src/database/key_value/rooms/pdu_metadata.rs
index f5e8f76..f3ac414 100644
--- a/src/database/key_value/rooms/pdu_metadata.rs
+++ b/src/database/key_value/rooms/pdu_metadata.rs
@@ -4,7 +4,7 @@ use ruma::{RoomId, EventId};
use crate::{service, database::KeyValueDatabase, Result};
-impl service::rooms::pdu_metadata::Data for Arc<KeyValueDatabase> {
+impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> {
for prev in event_ids {
let mut key = room_id.as_bytes().to_vec();
diff --git a/src/database/key_value/rooms/search.rs b/src/database/key_value/rooms/search.rs
index 7b8d278..dfbdbc6 100644
--- a/src/database/key_value/rooms/search.rs
+++ b/src/database/key_value/rooms/search.rs
@@ -1,10 +1,10 @@
-use std::{mem::size_of, sync::Arc};
+use std::mem::size_of;
use ruma::RoomId;
use crate::{service, database::KeyValueDatabase, utils, Result, services};
-impl service::rooms::search::Data for Arc<KeyValueDatabase> {
+impl service::rooms::search::Data for KeyValueDatabase {
fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: &[u8], message_body: String) -> Result<()> {
let mut batch = message_body
.split_terminator(|c: char| !c.is_alphanumeric())
diff --git a/src/database/key_value/rooms/short.rs b/src/database/key_value/rooms/short.rs
index 9a302b5..ecd12da 100644
--- a/src/database/key_value/rooms/short.rs
+++ b/src/database/key_value/rooms/short.rs
@@ -1,6 +1,227 @@
use std::sync::Arc;
-use crate::{database::KeyValueDatabase, service};
+use ruma::{EventId, events::StateEventType, RoomId};
-impl service::rooms::short::Data for Arc<KeyValueDatabase> {
+use crate::{Result, database::KeyValueDatabase, service, utils, Error, services};
+
+impl service::rooms::short::Data for KeyValueDatabase {
+ fn get_or_create_shorteventid(
+ &self,
+ event_id: &EventId,
+ ) -> Result<u64> {
+ if let Some(short) = self.eventidshort_cache.lock().unwrap().get_mut(event_id) {
+ return Ok(*short);
+ }
+
+ let short = match self.eventid_shorteventid.get(event_id.as_bytes())? {
+ Some(shorteventid) => utils::u64_from_bytes(&shorteventid)
+ .map_err(|_| Error::bad_database("Invalid shorteventid in db."))?,
+ None => {
+ let shorteventid = services().globals.next_count()?;
+ self.eventid_shorteventid
+ .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
+ self.shorteventid_eventid
+ .insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?;
+ shorteventid
+ }
+ };
+
+ self.eventidshort_cache
+ .lock()
+ .unwrap()
+ .insert(event_id.to_owned(), short);
+
+ Ok(short)
+ }
+
+ fn get_shortstatekey(
+ &self,
+ event_type: &StateEventType,
+ state_key: &str,
+ ) -> Result<Option<u64>> {
+ if let Some(short) = self
+ .statekeyshort_cache
+ .lock()
+ .unwrap()
+ .get_mut(&(event_type.clone(), state_key.to_owned()))
+ {
+ return Ok(Some(*short));
+ }
+
+ let mut statekey = event_type.to_string().as_bytes().to_vec();
+ statekey.push(0xff);
+ statekey.extend_from_slice(state_key.as_bytes());
+
+ let short = self
+ .statekey_shortstatekey
+ .get(&statekey)?
+ .map(|shortstatekey| {
+ utils::u64_from_bytes(&shortstatekey)
+ .map_err(|_| Error::bad_database("Invalid shortstatekey in db."))
+ })
+ .transpose()?;
+
+ if let Some(s) = short {
+ self.statekeyshort_cache
+ .lock()
+ .unwrap()
+ .insert((event_type.clone(), state_key.to_owned()), s);
+ }
+
+ Ok(short)
+ }
+
+ fn get_or_create_shortstatekey(
+ &self,
+ event_type: &StateEventType,
+ state_key: &str,
+ ) -> Result<u64> {
+ if let Some(short) = self
+ .statekeyshort_cache
+ .lock()
+ .unwrap()
+ .get_mut(&(event_type.clone(), state_key.to_owned()))
+ {
+ return Ok(*short);
+ }
+
+ let mut statekey = event_type.to_string().as_bytes().to_vec();
+ statekey.push(0xff);
+ statekey.extend_from_slice(state_key.as_bytes());
+
+ let short = match self.statekey_shortstatekey.get(&statekey)? {
+ Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey)
+ .map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?,
+ None => {
+ let shortstatekey = services().globals.next_count()?;
+ self.statekey_shortstatekey
+ .insert(&statekey, &shortstatekey.to_be_bytes())?;
+ self.shortstatekey_statekey
+ .insert(&shortstatekey.to_be_bytes(), &statekey)?;
+ shortstatekey
+ }
+ };
+
+ self.statekeyshort_cache
+ .lock()
+ .unwrap()
+ .insert((event_type.clone(), state_key.to_owned()), short);
+
+ Ok(short)
+ }
+
+ fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
+ if let Some(id) = self
+ .shorteventid_cache
+ .lock()
+ .unwrap()
+ .get_mut(&shorteventid)
+ {
+ return Ok(Arc::clone(id));
+ }
+
+ let bytes = self
+ .shorteventid_eventid
+ .get(&shorteventid.to_be_bytes())?
+ .ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?;
+
+ let event_id = EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| {
+ Error::bad_database("EventID in shorteventid_eventid is invalid unicode.")
+ })?)
+ .map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?;
+
+ self.shorteventid_cache
+ .lock()
+ .unwrap()
+ .insert(shorteventid, Arc::clone(&event_id));
+
+ Ok(event_id)
+ }
+
+ fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> {
+ if let Some(id) = self
+ .shortstatekey_cache
+ .lock()
+ .unwrap()
+ .get_mut(&shortstatekey)
+ {
+ return Ok(id.clone());
+ }
+
+ let bytes = self
+ .shortstatekey_statekey
+ .get(&shortstatekey.to_be_bytes())?
+ .ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?;
+
+ let mut parts = bytes.splitn(2, |&b| b == 0xff);
+ let eventtype_bytes = parts.next().expect("split always returns one entry");
+ let statekey_bytes = parts
+ .next()
+ .ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?;
+
+ let event_type =
+ StateEventType::try_from(utils::string_from_bytes(eventtype_bytes).map_err(|_| {
+ Error::bad_database("Event type in shortstatekey_statekey is invalid unicode.")
+ })?)
+ .map_err(|_| Error::bad_database("Event type in shortstatekey_statekey is invalid."))?;
+
+ let state_key = utils::string_from_bytes(statekey_bytes).map_err(|_| {
+ Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode.")
+ })?;
+
+ let result = (event_type, state_key);
+
+ self.shortstatekey_cache
+ .lock()
+ .unwrap()
+ .insert(shortstatekey, result.clone());
+
+ Ok(result)
+ }
+
+ /// Returns (shortstatehash, already_existed)
+ fn get_or_create_shortstatehash(
+ &self,
+ state_hash: &[u8],
+ ) -> Result<(u64, bool)> {
+ Ok(match self.statehash_shortstatehash.get(state_hash)? {
+ Some(shortstatehash) => (
+ utils::u64_from_bytes(&shortstatehash)
+ .map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?,
+ true,
+ ),
+ None => {
+ let shortstatehash = services().globals.next_count()?;
+ self.statehash_shortstatehash
+ .insert(state_hash, &shortstatehash.to_be_bytes())?;
+ (shortstatehash, false)
+ }
+ })
+ }
+
+ fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> {
+ self.roomid_shortroomid
+ .get(room_id.as_bytes())?
+ .map(|bytes| {
+ utils::u64_from_bytes(&bytes)
+ .map_err(|_| Error::bad_database("Invalid shortroomid in db."))
+ })
+ .transpose()
+ }
+
+ fn get_or_create_shortroomid(
+ &self,
+ room_id: &RoomId,
+ ) -> Result<u64> {
+ Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? {
+ Some(short) => utils::u64_from_bytes(&short)
+ .map_err(|_| Error::bad_database("Invalid shortroomid in db."))?,
+ None => {
+ let short = services().globals.next_count()?;
+ self.roomid_shortroomid
+ .insert(room_id.as_bytes(), &short.to_be_bytes())?;
+ short
+ }
+ })
+ }
}
diff --git a/src/database/key_value/rooms/state.rs b/src/database/key_value/rooms/state.rs
index 527c240..b2822b3 100644
--- a/src/database/key_value/rooms/state.rs
+++ b/src/database/key_value/rooms/state.rs
@@ -6,7 +6,7 @@ use std::fmt::Debug;
use crate::{service, database::KeyValueDatabase, utils, Error, Result};
-impl service::rooms::state::Data for Arc<KeyValueDatabase> {
+impl service::rooms::state::Data for KeyValueDatabase {
fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortstatehash
.get(room_id.as_bytes())?
diff --git a/src/database/key_value/rooms/state_accessor.rs b/src/database/key_value/rooms/state_accessor.rs
index 9af45db..4d5bd4a 100644
--- a/src/database/key_value/rooms/state_accessor.rs
+++ b/src/database/key_value/rooms/state_accessor.rs
@@ -5,7 +5,7 @@ use async_trait::async_trait;
use ruma::{EventId, events::StateEventType, RoomId};
#[async_trait]
-impl service::rooms::state_accessor::Data for Arc<KeyValueDatabase> {
+impl service::rooms::state_accessor::Data for KeyValueDatabase {
async fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeMap<u64, Arc<EventId>>> {
let full_state = services().rooms.state_compressor
.load_shortstatehash_info(shortstatehash)?
diff --git a/src/database/key_value/rooms/state_cache.rs b/src/database/key_value/rooms/state_cache.rs
index bdb8cf8..5f05485 100644
--- a/src/database/key_value/rooms/state_cache.rs
+++ b/src/database/key_value/rooms/state_cache.rs
@@ -1,10 +1,8 @@
-use std::sync::Arc;
-
use ruma::{UserId, RoomId, events::{AnyStrippedStateEvent, AnySyncStateEvent}, serde::Raw};
use crate::{service, database::KeyValueDatabase, services, Result};
-impl service::rooms::state_cache::Data for Arc<KeyValueDatabase> {
+impl service::rooms::state_cache::Data for KeyValueDatabase {
fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff);
diff --git a/src/database/key_value/rooms/state_compressor.rs b/src/database/key_value/rooms/state_compressor.rs
index e1c0280..aee1890 100644
--- a/src/database/key_value/rooms/state_compressor.rs
+++ b/src/database/key_value/rooms/state_compressor.rs
@@ -1,8 +1,8 @@
-use std::{collections::HashSet, mem::size_of, sync::Arc};
+use std::{collections::HashSet, mem::size_of};
use crate::{service::{self, rooms::state_compressor::data::StateDiff}, database::KeyValueDatabase, Error, utils, Result};
-impl service::rooms::state_compressor::Data for Arc<KeyValueDatabase> {
+impl service::rooms::state_compressor::Data for KeyValueDatabase {
fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff> {
let value = self
.shortstatehash_statediff
diff --git a/src/database/key_value/rooms/timeline.rs b/src/database/key_value/rooms/timeline.rs
index 2d334b9..0b7286b 100644
--- a/src/database/key_value/rooms/timeline.rs
+++ b/src/database/key_value/rooms/timeline.rs
@@ -5,7 +5,27 @@ use tracing::error;
use crate::{service, database::KeyValueDatabase, utils, Error, PduEvent, Result, services};
-impl service::rooms::timeline::Data for Arc<KeyValueDatabase> {
+impl service::rooms::timeline::Data for KeyValueDatabase {
+ fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent>>> {
+ let prefix = services().rooms.short
+ .get_shortroomid(room_id)?
+ .expect("room exists")
+ .to_be_bytes()
+ .to_vec();
+
+ // Look for PDUs in that room.
+ self.pduid_pdu
+ .iter_from(&prefix, false)
+ .filter(|(k, _)| k.starts_with(&prefix))
+ .map(|(_, pdu)| {
+ serde_json::from_slice(&pdu)
+ .map_err(|_| Error::bad_database("Invalid first PDU in db."))
+ .map(Arc::new)
+ })
+ .next()
+ .transpose()
+ }
+
fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<u64> {
match self
.lasttimelinecount_cache
diff --git a/src/database/key_value/rooms/user.rs b/src/database/key_value/rooms/user.rs
index 4d20b00..3759bda 100644
--- a/src/database/key_value/rooms/user.rs
+++ b/src/database/key_value/rooms/user.rs
@@ -1,10 +1,8 @@
-use std::sync::Arc;
-
use ruma::{UserId, RoomId};
use crate::{service, database::KeyValueDatabase, utils, Error, Result, services};
-impl service::rooms::user::Data for Arc<KeyValueDatabase> {
+impl service::rooms::user::Data for KeyValueDatabase {
fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff);
@@ -104,13 +102,13 @@ impl service::rooms::user::Data for Arc<KeyValueDatabase> {
});
// We use the default compare function because keys are sorted correctly (not reversed)
- Ok(utils::common_elements(iterators, Ord::cmp)
+ Ok(Box::new(Box::new(utils::common_elements(iterators, Ord::cmp)
.expect("users is not empty")
.map(|bytes| {
RoomId::parse(utils::string_from_bytes(&*bytes).map_err(|_| {
Error::bad_database("Invalid RoomId bytes in userroomid_joined")
})?)
.map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined."))
- }))
+ }))))
}
}
diff --git a/src/database/key_value/transaction_ids.rs b/src/database/key_value/transaction_ids.rs
index 7fa6908..a63b3c5 100644
--- a/src/database/key_value/transaction_ids.rs
+++ b/src/database/key_value/transaction_ids.rs
@@ -1,10 +1,8 @@
-use std::sync::Arc;
-
use ruma::{UserId, DeviceId, TransactionId};
use crate::{service, database::KeyValueDatabase, Result};
-impl service::transaction_ids::Data for Arc<KeyValueDatabase> {
+impl service::transaction_ids::Data for KeyValueDatabase {
fn add_txnid(
&self,
user_id: &UserId,
diff --git a/src/database/key_value/uiaa.rs b/src/database/key_value/uiaa.rs
index 8752e55..cf242de 100644
--- a/src/database/key_value/uiaa.rs
+++ b/src/database/key_value/uiaa.rs
@@ -1,10 +1,8 @@
-use std::sync::Arc;
-
use ruma::{UserId, DeviceId, signatures::CanonicalJsonValue, api::client::{uiaa::UiaaInfo, error::ErrorKind}};
use crate::{database::KeyValueDatabase, service, Error, Result};
-impl service::uiaa::Data for Arc<KeyValueDatabase> {
+impl service::uiaa::Data for KeyValueDatabase {
fn set_uiaa_request(
&self,
user_id: &UserId,
diff --git a/src/database/key_value/users.rs b/src/database/key_value/users.rs
index 1ac85b3..55a518d 100644
--- a/src/database/key_value/users.rs
+++ b/src/database/key_value/users.rs
@@ -1,11 +1,11 @@
-use std::{mem::size_of, collections::BTreeMap, sync::Arc};
+use std::{mem::size_of, collections::BTreeMap};
use ruma::{api::client::{filter::IncomingFilterDefinition, error::ErrorKind, device::Device}, UserId, RoomAliasId, MxcUri, DeviceId, MilliSecondsSinceUnixEpoch, DeviceKeyId, encryption::{OneTimeKey, CrossSigningKey, DeviceKeys}, serde::Raw, events::{AnyToDeviceEvent, StateEventType}, DeviceKeyAlgorithm, UInt};
use tracing::warn;
use crate::{service::{self, users::clean_signatures}, database::KeyValueDatabase, Error, utils, services, Result};
-impl service::users::Data for Arc<KeyValueDatabase> {
+impl service::users::Data for KeyValueDatabase {
/// Check if a user has an account on this homeserver.
fn exists(&self, user_id: &UserId) -> Result<bool> {
Ok(self.userid_password.get(user_id.as_bytes())?.is_some())
@@ -113,7 +113,7 @@ impl service::users::Data for Arc<KeyValueDatabase> {
/// Hash and set the user's password to the Argon2 hash
fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
if let Some(password) = password {
- if let Ok(hash) = utils::calculate_hash(password) {
+ if let Ok(hash) = utils::calculate_password_hash(password) {
self.userid_password
.insert(user_id.as_bytes(), hash.as_bytes())?;
Ok(())
diff --git a/src/database/mod.rs b/src/database/mod.rs
index 35922f0..6868467 100644
--- a/src/database/mod.rs
+++ b/src/database/mod.rs
@@ -238,8 +238,8 @@ impl KeyValueDatabase {
}
/// Load an existing database or create a new one.
- pub async fn load_or_create(config: &Config) -> Result<()> {
- Self::check_db_setup(config)?;
+ pub async fn load_or_create(config: Config) -> Result<()> {
+ Self::check_db_setup(&config)?;
if !Path::new(&config.database_path).exists() {
std::fs::create_dir_all(&config.database_path)
@@ -251,19 +251,19 @@ impl KeyValueDatabase {
#[cfg(not(feature = "sqlite"))]
return Err(Error::BadConfig("Database backend not found."));
#[cfg(feature = "sqlite")]
- Arc::new(Arc::<abstraction::sqlite::Engine>::open(config)?)
+ Arc::new(Arc::<abstraction::sqlite::Engine>::open(&config)?)
}
"rocksdb" => {
#[cfg(not(feature = "rocksdb"))]
return Err(Error::BadConfig("Database backend not found."));
#[cfg(feature = "rocksdb")]
- Arc::new(Arc::<abstraction::rocksdb::Engine>::open(config)?)
+ Arc::new(Arc::<abstraction::rocksdb::Engine>::open(&config)?)
}
"persy" => {
#[cfg(not(feature = "persy"))]
return Err(Error::BadConfig("Database backend not found."));
#[cfg(feature = "persy")]
- Arc::new(Arc::<abstraction::persy::Engine>::open(config)?)
+ Arc::new(Arc::<abstraction::persy::Engine>::open(&config)?)
}
_ => {
return Err(Error::BadConfig("Database backend not found."));
@@ -402,7 +402,7 @@ impl KeyValueDatabase {
});
- let services_raw = Box::new(Services::build(Arc::clone(&db)));
+ let services_raw = Box::new(Services::build(Arc::clone(&db), config)?);
// This is the first and only time we initialize the SERVICE static
*SERVICES.write().unwrap() = Some(Box::leak(services_raw));
@@ -825,7 +825,7 @@ impl KeyValueDatabase {
info!(
"Loaded {} database with version {}",
- config.database_backend, latest_database_version
+ services().globals.config.database_backend, latest_database_version
);
} else {
services()
@@ -837,7 +837,7 @@ impl KeyValueDatabase {
warn!(
"Created new {} database with version {}",
- config.database_backend, latest_database_version
+ services().globals.config.database_backend, latest_database_version
);
}
@@ -866,7 +866,7 @@ impl KeyValueDatabase {
.sending
.start_handler(sending_receiver);
- Self::start_cleanup_task(config).await;
+ Self::start_cleanup_task().await;
Ok(())
}
@@ -888,8 +888,8 @@ impl KeyValueDatabase {
res
}
- #[tracing::instrument(skip(config))]
- pub async fn start_cleanup_task(config: &Config) {
+ #[tracing::instrument]
+ pub async fn start_cleanup_task() {
use tokio::time::interval;
#[cfg(unix)]
@@ -898,7 +898,7 @@ impl KeyValueDatabase {
use std::time::{Duration, Instant};
- let timer_interval = Duration::from_secs(config.cleanup_second_interval as u64);
+ let timer_interval = Duration::from_secs(services().globals.config.cleanup_second_interval as u64);
tokio::spawn(async move {
let mut i = interval(timer_interval);