summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTimo Kösters <timo@koesters.xyz>2022-10-05 18:36:12 +0200
committerNyaaori <+@nyaaori.cat>2022-10-10 14:02:00 +0200
commit44fe6d1554eaa0a15314686974ab01f48c836588 (patch)
tree742d3e844c32acc6fa1c6616d0ff440aff8a6e6c
parentcff52d7ebb5066f3d8e513488b84a431c0093e65 (diff)
downloadconduit-44fe6d1554eaa0a15314686974ab01f48c836588.zip
127 errors left
-rw-r--r--src/api/client_server/membership.rs2
-rw-r--r--src/api/server_server.rs133
-rw-r--r--src/database/key_value/account_data.rs4
-rw-r--r--src/database/key_value/appservice.rs2
-rw-r--r--src/database/key_value/globals.rs4
-rw-r--r--src/database/key_value/key_backups.rs4
-rw-r--r--src/database/key_value/media.rs4
-rw-r--r--src/database/key_value/pusher.rs4
-rw-r--r--src/database/key_value/rooms/alias.rs4
-rw-r--r--src/database/key_value/rooms/auth_chain.rs2
-rw-r--r--src/database/key_value/rooms/directory.rs4
-rw-r--r--src/database/key_value/rooms/edus/mod.rs4
-rw-r--r--src/database/key_value/rooms/edus/presence.rs4
-rw-r--r--src/database/key_value/rooms/edus/read_receipt.rs4
-rw-r--r--src/database/key_value/rooms/edus/typing.rs4
-rw-r--r--src/database/key_value/rooms/lazy_load.rs4
-rw-r--r--src/database/key_value/rooms/metadata.rs18
-rw-r--r--src/database/key_value/rooms/mod.rs4
-rw-r--r--src/database/key_value/rooms/outlier.rs4
-rw-r--r--src/database/key_value/rooms/pdu_metadata.rs2
-rw-r--r--src/database/key_value/rooms/search.rs4
-rw-r--r--src/database/key_value/rooms/short.rs225
-rw-r--r--src/database/key_value/rooms/state.rs2
-rw-r--r--src/database/key_value/rooms/state_accessor.rs2
-rw-r--r--src/database/key_value/rooms/state_cache.rs4
-rw-r--r--src/database/key_value/rooms/state_compressor.rs4
-rw-r--r--src/database/key_value/rooms/timeline.rs22
-rw-r--r--src/database/key_value/rooms/user.rs8
-rw-r--r--src/database/key_value/transaction_ids.rs4
-rw-r--r--src/database/key_value/uiaa.rs4
-rw-r--r--src/database/key_value/users.rs6
-rw-r--r--src/database/mod.rs24
-rw-r--r--src/service/account_data/mod.rs2
-rw-r--r--src/service/admin/mod.rs12
-rw-r--r--src/service/globals/mod.rs4
-rw-r--r--src/service/key_backups/mod.rs2
-rw-r--r--src/service/media/mod.rs2
-rw-r--r--src/service/mod.rs86
-rw-r--r--src/service/rooms/alias/mod.rs4
-rw-r--r--src/service/rooms/auth_chain/mod.rs135
-rw-r--r--src/service/rooms/directory/mod.rs4
-rw-r--r--src/service/rooms/edus/presence/mod.rs4
-rw-r--r--src/service/rooms/edus/read_receipt/mod.rs4
-rw-r--r--src/service/rooms/edus/typing/mod.rs4
-rw-r--r--src/service/rooms/event_handler/mod.rs84
-rw-r--r--src/service/rooms/lazy_loading/mod.rs4
-rw-r--r--src/service/rooms/metadata/data.rs2
-rw-r--r--src/service/rooms/metadata/mod.rs12
-rw-r--r--src/service/rooms/outlier/mod.rs4
-rw-r--r--src/service/rooms/pdu_metadata/mod.rs2
-rw-r--r--src/service/rooms/search/mod.rs4
-rw-r--r--src/service/rooms/short/data.rs38
-rw-r--r--src/service/rooms/short/mod.rs190
-rw-r--r--src/service/rooms/state/mod.rs99
-rw-r--r--src/service/rooms/state_accessor/mod.rs4
-rw-r--r--src/service/rooms/state_cache/mod.rs68
-rw-r--r--src/service/rooms/state_compressor/mod.rs6
-rw-r--r--src/service/rooms/timeline/data.rs1
-rw-r--r--src/service/rooms/timeline/mod.rs26
-rw-r--r--src/service/rooms/user/mod.rs4
-rw-r--r--src/service/sending/mod.rs8
-rw-r--r--src/service/transaction_ids/mod.rs4
-rw-r--r--src/service/uiaa/mod.rs4
-rw-r--r--src/service/users/mod.rs4
-rw-r--r--src/utils/mod.rs12
65 files changed, 810 insertions, 557 deletions
diff --git a/src/api/client_server/membership.rs b/src/api/client_server/membership.rs
index 58ed040..f07f2ad 100644
--- a/src/api/client_server/membership.rs
+++ b/src/api/client_server/membership.rs
@@ -654,7 +654,7 @@ async fn join_room_by_id_helper(
// We set the room state after inserting the pdu, so that we never have a moment in time
// where events in the current room state do not exist
- services().rooms.state.set_room_state(room_id, shortstatehash)?;
+ services().rooms.state.set_room_state(room_id, shortstatehash, &state_lock)?;
let statehashid = services().rooms.state.append_to_state(&parsed_pdu)?;
} else {
diff --git a/src/api/server_server.rs b/src/api/server_server.rs
index 647f457..11f7ec3 100644
--- a/src/api/server_server.rs
+++ b/src/api/server_server.rs
@@ -857,131 +857,6 @@ pub async fn send_transaction_message_route(
Ok(send_transaction_message::v1::Response { pdus: resolved_map.into_iter().map(|(e, r)| (e, r.map_err(|e| e.to_string()))).collect() })
}
-#[tracing::instrument(skip(starting_events))]
-pub(crate) async fn get_auth_chain<'a>(
- room_id: &RoomId,
- starting_events: Vec<Arc<EventId>>,
-) -> Result<impl Iterator<Item = Arc<EventId>> + 'a> {
- const NUM_BUCKETS: usize = 50;
-
- let mut buckets = vec![BTreeSet::new(); NUM_BUCKETS];
-
- let mut i = 0;
- for id in starting_events {
- let short = services().rooms.short.get_or_create_shorteventid(&id)?;
- let bucket_id = (short % NUM_BUCKETS as u64) as usize;
- buckets[bucket_id].insert((short, id.clone()));
- i += 1;
- if i % 100 == 0 {
- tokio::task::yield_now().await;
- }
- }
-
- let mut full_auth_chain = HashSet::new();
-
- let mut hits = 0;
- let mut misses = 0;
- for chunk in buckets {
- if chunk.is_empty() {
- continue;
- }
-
- let chunk_key: Vec<u64> = chunk.iter().map(|(short, _)| short).copied().collect();
- if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&chunk_key)? {
- hits += 1;
- full_auth_chain.extend(cached.iter().copied());
- continue;
- }
- misses += 1;
-
- let mut chunk_cache = HashSet::new();
- let mut hits2 = 0;
- let mut misses2 = 0;
- let mut i = 0;
- for (sevent_id, event_id) in chunk {
- if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&[sevent_id])? {
- hits2 += 1;
- chunk_cache.extend(cached.iter().copied());
- } else {
- misses2 += 1;
- let auth_chain = Arc::new(get_auth_chain_inner(room_id, &event_id)?);
- services().rooms
- .auth_chain
- .cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?;
- println!(
- "cache missed event {} with auth chain len {}",
- event_id,
- auth_chain.len()
- );
- chunk_cache.extend(auth_chain.iter());
-
- i += 1;
- if i % 100 == 0 {
- tokio::task::yield_now().await;
- }
- };
- }
- println!(
- "chunk missed with len {}, event hits2: {}, misses2: {}",
- chunk_cache.len(),
- hits2,
- misses2
- );
- let chunk_cache = Arc::new(chunk_cache);
- services().rooms
- .auth_chain.cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?;
- full_auth_chain.extend(chunk_cache.iter());
- }
-
- println!(
- "total: {}, chunk hits: {}, misses: {}",
- full_auth_chain.len(),
- hits,
- misses
- );
-
- Ok(full_auth_chain
- .into_iter()
- .filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok()))
-}
-
-#[tracing::instrument(skip(event_id))]
-fn get_auth_chain_inner(
- room_id: &RoomId,
- event_id: &EventId,
-) -> Result<HashSet<u64>> {
- let mut todo = vec![Arc::from(event_id)];
- let mut found = HashSet::new();
-
- while let Some(event_id) = todo.pop() {
- match services().rooms.timeline.get_pdu(&event_id) {
- Ok(Some(pdu)) => {
- if pdu.room_id != room_id {
- return Err(Error::BadRequest(ErrorKind::Forbidden, "Evil event in db"));
- }
- for auth_event in &pdu.auth_events {
- let sauthevent = services()
- .rooms.short
- .get_or_create_shorteventid(auth_event)?;
-
- if !found.contains(&sauthevent) {
- found.insert(sauthevent);
- todo.push(auth_event.clone());
- }
- }
- }
- Ok(None) => {
- warn!("Could not find pdu mentioned in auth events: {}", event_id);
- }
- Err(e) => {
- warn!("Could not load event in auth chain: {} {}", event_id, e);
- }
- }
- }
-
- Ok(found)
-}
-
/// # `GET /_matrix/federation/v1/event/{eventId}`
///
/// Retrieves a single event from the server.
@@ -1135,7 +1010,7 @@ pub async fn get_event_authorization_route(
let room_id = <&RoomId>::try_from(room_id_str)
.map_err(|_| Error::bad_database("Invalid room id field in event in database"))?;
- let auth_chain_ids = get_auth_chain(room_id, vec![Arc::from(&*body.event_id)]).await?;
+ let auth_chain_ids = services().rooms.auth_chain.get_auth_chain(room_id, vec![Arc::from(&*body.event_id)]).await?;
Ok(get_event_authorization::v1::Response {
auth_chain: auth_chain_ids
@@ -1190,7 +1065,7 @@ pub async fn get_room_state_route(
.collect();
let auth_chain_ids =
- get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)]).await?;
+ services().rooms.auth_chain.get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)]).await?;
Ok(get_room_state::v1::Response {
auth_chain: auth_chain_ids
@@ -1246,7 +1121,7 @@ pub async fn get_room_state_ids_route(
.collect();
let auth_chain_ids =
- get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)]).await?;
+ services().rooms.auth_chain.get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)]).await?;
Ok(get_room_state_ids::v1::Response {
auth_chain_ids: auth_chain_ids.map(|id| (*id).to_owned()).collect(),
@@ -1449,7 +1324,7 @@ async fn create_join_event(
drop(mutex_lock);
let state_ids = services().rooms.state_accessor.state_full_ids(shortstatehash).await?;
- let auth_chain_ids = get_auth_chain(
+ let auth_chain_ids = services().rooms.auth_chain.get_auth_chain(
room_id,
state_ids.iter().map(|(_, id)| id.clone()).collect(),
)
diff --git a/src/database/key_value/account_data.rs b/src/database/key_value/account_data.rs
index f0325d2..5674ac0 100644
--- a/src/database/key_value/account_data.rs
+++ b/src/database/key_value/account_data.rs
@@ -1,11 +1,11 @@
-use std::{collections::HashMap, sync::Arc};
+use std::collections::HashMap;
use ruma::{UserId, DeviceId, signatures::CanonicalJsonValue, api::client::{uiaa::UiaaInfo, error::ErrorKind}, events::{RoomAccountDataEventType, AnyEphemeralRoomEvent}, serde::Raw, RoomId};
use serde::{Serialize, de::DeserializeOwned};
use crate::{Result, database::KeyValueDatabase, service, Error, utils, services};
-impl service::account_data::Data for Arc<KeyValueDatabase> {
+impl service::account_data::Data for KeyValueDatabase {
/// Places one event in the account data of the user and removes the previous entry.
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
fn update(
diff --git a/src/database/key_value/appservice.rs b/src/database/key_value/appservice.rs
index ee6ae20..f427ba7 100644
--- a/src/database/key_value/appservice.rs
+++ b/src/database/key_value/appservice.rs
@@ -1,5 +1,3 @@
-use std::sync::Arc;
-
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
impl service::appservice::Data for KeyValueDatabase {
diff --git a/src/database/key_value/globals.rs b/src/database/key_value/globals.rs
index 8711920..199cbf6 100644
--- a/src/database/key_value/globals.rs
+++ b/src/database/key_value/globals.rs
@@ -1,4 +1,4 @@
-use std::{collections::BTreeMap, sync::Arc};
+use std::collections::BTreeMap;
use async_trait::async_trait;
use futures_util::{stream::FuturesUnordered, StreamExt};
@@ -9,7 +9,7 @@ use crate::{Result, service, database::KeyValueDatabase, Error, utils, services}
pub const COUNTER: &[u8] = b"c";
#[async_trait]
-impl service::globals::Data for Arc<KeyValueDatabase> {
+impl service::globals::Data for KeyValueDatabase {
fn next_count(&self) -> Result<u64> {
utils::u64_from_bytes(&self.global.increment(COUNTER)?)
.map_err(|_| Error::bad_database("Count has invalid bytes."))
diff --git a/src/database/key_value/key_backups.rs b/src/database/key_value/key_backups.rs
index c59ed36..8171451 100644
--- a/src/database/key_value/key_backups.rs
+++ b/src/database/key_value/key_backups.rs
@@ -1,10 +1,10 @@
-use std::{collections::BTreeMap, sync::Arc};
+use std::collections::BTreeMap;
use ruma::{UserId, serde::Raw, api::client::{backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, error::ErrorKind}, RoomId};
use crate::{Result, service, database::KeyValueDatabase, services, Error, utils};
-impl service::key_backups::Data for Arc<KeyValueDatabase> {
+impl service::key_backups::Data for KeyValueDatabase {
fn create_backup(
&self,
user_id: &UserId,
diff --git a/src/database/key_value/media.rs b/src/database/key_value/media.rs
index 1726755..f024487 100644
--- a/src/database/key_value/media.rs
+++ b/src/database/key_value/media.rs
@@ -1,10 +1,8 @@
-use std::sync::Arc;
-
use ruma::api::client::error::ErrorKind;
use crate::{database::KeyValueDatabase, service, Error, utils, Result};
-impl service::media::Data for Arc<KeyValueDatabase> {
+impl service::media::Data for KeyValueDatabase {
fn create_file_metadata(&self, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>) -> Result<Vec<u8>> {
let mut key = mxc.as_bytes().to_vec();
key.push(0xff);
diff --git a/src/database/key_value/pusher.rs b/src/database/key_value/pusher.rs
index 85d1d86..b05e47b 100644
--- a/src/database/key_value/pusher.rs
+++ b/src/database/key_value/pusher.rs
@@ -1,10 +1,8 @@
-use std::sync::Arc;
-
use ruma::{UserId, api::client::push::{set_pusher, get_pushers}};
use crate::{service, database::KeyValueDatabase, Error, Result};
-impl service::pusher::Data for Arc<KeyValueDatabase> {
+impl service::pusher::Data for KeyValueDatabase {
fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> {
let mut key = sender.as_bytes().to_vec();
key.push(0xff);
diff --git a/src/database/key_value/rooms/alias.rs b/src/database/key_value/rooms/alias.rs
index 437902d..0aa8dd4 100644
--- a/src/database/key_value/rooms/alias.rs
+++ b/src/database/key_value/rooms/alias.rs
@@ -1,10 +1,8 @@
-use std::sync::Arc;
-
use ruma::{RoomId, RoomAliasId, api::client::error::ErrorKind};
use crate::{service, database::KeyValueDatabase, utils, Error, services, Result};
-impl service::rooms::alias::Data for Arc<KeyValueDatabase> {
+impl service::rooms::alias::Data for KeyValueDatabase {
fn set_alias(
&self,
alias: &RoomAliasId,
diff --git a/src/database/key_value/rooms/auth_chain.rs b/src/database/key_value/rooms/auth_chain.rs
index 2dffb04..49d3956 100644
--- a/src/database/key_value/rooms/auth_chain.rs
+++ b/src/database/key_value/rooms/auth_chain.rs
@@ -2,7 +2,7 @@ use std::{collections::HashSet, mem::size_of, sync::Arc};
use crate::{service, database::KeyValueDatabase, Result, utils};
-impl service::rooms::auth_chain::Data for Arc<KeyValueDatabase> {
+impl service::rooms::auth_chain::Data for KeyValueDatabase {
fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<HashSet<u64>>>> {
// Check RAM cache
if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) {
diff --git a/src/database/key_value/rooms/directory.rs b/src/database/key_value/rooms/directory.rs
index 864e75e..727004e 100644
--- a/src/database/key_value/rooms/directory.rs
+++ b/src/database/key_value/rooms/directory.rs
@@ -1,10 +1,8 @@
-use std::sync::Arc;
-
use ruma::RoomId;
use crate::{service, database::KeyValueDatabase, utils, Error, Result};
-impl service::rooms::directory::Data for Arc<KeyValueDatabase> {
+impl service::rooms::directory::Data for KeyValueDatabase {
fn set_public(&self, room_id: &RoomId) -> Result<()> {
self.publicroomids.insert(room_id.as_bytes(), &[])
}
diff --git a/src/database/key_value/rooms/edus/mod.rs b/src/database/key_value/rooms/edus/mod.rs
index 03e4219..b5007f8 100644
--- a/src/database/key_value/rooms/edus/mod.rs
+++ b/src/database/key_value/rooms/edus/mod.rs
@@ -2,8 +2,6 @@ mod presence;
mod typing;
mod read_receipt;
-use std::sync::Arc;
-
use crate::{service, database::KeyValueDatabase};
-impl service::rooms::edus::Data for Arc<KeyValueDatabase> {}
+impl service::rooms::edus::Data for KeyValueDatabase {}
diff --git a/src/database/key_value/rooms/edus/presence.rs b/src/database/key_value/rooms/edus/presence.rs
index 5aeb147..1477c28 100644
--- a/src/database/key_value/rooms/edus/presence.rs
+++ b/src/database/key_value/rooms/edus/presence.rs
@@ -1,10 +1,10 @@
-use std::{collections::HashMap, sync::Arc};
+use std::collections::HashMap;
use ruma::{UserId, RoomId, events::presence::PresenceEvent, presence::PresenceState, UInt};
use crate::{service, database::KeyValueDatabase, utils, Error, services, Result};
-impl service::rooms::edus::presence::Data for Arc<KeyValueDatabase> {
+impl service::rooms::edus::presence::Data for KeyValueDatabase {
fn update_presence(
&self,
user_id: &UserId,
diff --git a/src/database/key_value/rooms/edus/read_receipt.rs b/src/database/key_value/rooms/edus/read_receipt.rs
index 7fcb8ac..a12e265 100644
--- a/src/database/key_value/rooms/edus/read_receipt.rs
+++ b/src/database/key_value/rooms/edus/read_receipt.rs
@@ -1,10 +1,10 @@
-use std::{mem, sync::Arc};
+use std::mem;
use ruma::{UserId, RoomId, events::receipt::ReceiptEvent, serde::Raw, signatures::CanonicalJsonObject};
use crate::{database::KeyValueDatabase, service, utils, Error, services, Result};
-impl service::rooms::edus::read_receipt::Data for Arc<KeyValueDatabase> {
+impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
fn readreceipt_update(
&self,
user_id: &UserId,
diff --git a/src/database/key_value/rooms/edus/typing.rs b/src/database/key_value/rooms/edus/typing.rs
index 7f3526d..b7d3596 100644
--- a/src/database/key_value/rooms/edus/typing.rs
+++ b/src/database/key_value/rooms/edus/typing.rs
@@ -1,10 +1,10 @@
-use std::{collections::HashSet, sync::Arc};
+use std::collections::HashSet;
use ruma::{UserId, RoomId};
use crate::{database::KeyValueDatabase, service, utils, Error, services, Result};
-impl service::rooms::edus::typing::Data for Arc<KeyValueDatabase> {
+impl service::rooms::edus::typing::Data for KeyValueDatabase {
fn typing_add(
&self,
user_id: &UserId,
diff --git a/src/database/key_value/rooms/lazy_load.rs b/src/database/key_value/rooms/lazy_load.rs
index b16657a..133e1d0 100644
--- a/src/database/key_value/rooms/lazy_load.rs
+++ b/src/database/key_value/rooms/lazy_load.rs
@@ -1,10 +1,8 @@
-use std::sync::Arc;
-
use ruma::{UserId, DeviceId, RoomId};
use crate::{service, database::KeyValueDatabase, Result};
-impl service::rooms::lazy_loading::Data for Arc<KeyValueDatabase> {
+impl service::rooms::lazy_loading::Data for KeyValueDatabase {
fn lazy_load_was_sent_before(
&self,
user_id: &UserId,
diff --git a/src/database/key_value/rooms/metadata.rs b/src/database/key_value/rooms/metadata.rs
index 560beb9..72f6251 100644
--- a/src/database/key_value/rooms/metadata.rs
+++ b/src/database/key_value/rooms/metadata.rs
@@ -1,10 +1,8 @@
-use std::sync::Arc;
-
use ruma::RoomId;
use crate::{service, database::KeyValueDatabase, Result, services};
-impl service::rooms::metadata::Data for Arc<KeyValueDatabase> {
+impl service::rooms::metadata::Data for KeyValueDatabase {
fn exists(&self, room_id: &RoomId) -> Result<bool> {
let prefix = match services().rooms.short.get_shortroomid(room_id)? {
Some(b) => b.to_be_bytes().to_vec(),
@@ -19,4 +17,18 @@ impl service::rooms::metadata::Data for Arc<KeyValueDatabase> {
.filter(|(k, _)| k.starts_with(&prefix))
.is_some())
}
+
+ fn is_disabled(&self, room_id: &RoomId) -> Result<bool> {
+ Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some())
+ }
+
+ fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> {
+ if disabled {
+ self.disabledroomids.insert(room_id.as_bytes(), &[])?;
+ } else {
+ self.disabledroomids.remove(room_id.as_bytes())?;
+ }
+
+ Ok(())
+ }
}
diff --git a/src/database/key_value/rooms/mod.rs b/src/database/key_value/rooms/mod.rs
index 97c29e5..406943e 100644
--- a/src/database/key_value/rooms/mod.rs
+++ b/src/database/key_value/rooms/mod.rs
@@ -15,8 +15,6 @@ mod state_compressor;
mod timeline;
mod user;
-use std::sync::Arc;
-
use crate::{database::KeyValueDatabase, service};
-impl service::rooms::Data for Arc<KeyValueDatabase> {}
+impl service::rooms::Data for KeyValueDatabase {}
diff --git a/src/database/key_value/rooms/outlier.rs b/src/database/key_value/rooms/outlier.rs
index b1ae816..aa97544 100644
--- a/src/database/key_value/rooms/outlier.rs
+++ b/src/database/key_value/rooms/outlier.rs
@@ -1,10 +1,8 @@
-use std::sync::Arc;
-
use ruma::{EventId, signatures::CanonicalJsonObject};
use crate::{service, database::KeyValueDatabase, PduEvent, Error, Result};
-impl service::rooms::outlier::Data for Arc<KeyValueDatabase> {
+impl service::rooms::outlier::Data for KeyValueDatabase {
fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
diff --git a/src/database/key_value/rooms/pdu_metadata.rs b/src/database/key_value/rooms/pdu_metadata.rs
index f5e8f76..f3ac414 100644
--- a/src/database/key_value/rooms/pdu_metadata.rs
+++ b/src/database/key_value/rooms/pdu_metadata.rs
@@ -4,7 +4,7 @@ use ruma::{RoomId, EventId};
use crate::{service, database::KeyValueDatabase, Result};
-impl service::rooms::pdu_metadata::Data for Arc<KeyValueDatabase> {
+impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> {
for prev in event_ids {
let mut key = room_id.as_bytes().to_vec();
diff --git a/src/database/key_value/rooms/search.rs b/src/database/key_value/rooms/search.rs
index 7b8d278..dfbdbc6 100644
--- a/src/database/key_value/rooms/search.rs
+++ b/src/database/key_value/rooms/search.rs
@@ -1,10 +1,10 @@
-use std::{mem::size_of, sync::Arc};
+use std::mem::size_of;
use ruma::RoomId;
use crate::{service, database::KeyValueDatabase, utils, Result, services};
-impl service::rooms::search::Data for Arc<KeyValueDatabase> {
+impl service::rooms::search::Data for KeyValueDatabase {
fn index_pdu<'a>(&self, shortroomid: u64, pdu_id: &[u8], message_body: String) -> Result<()> {
let mut batch = message_body
.split_terminator(|c: char| !c.is_alphanumeric())
diff --git a/src/database/key_value/rooms/short.rs b/src/database/key_value/rooms/short.rs
index 9a302b5..ecd12da 100644
--- a/src/database/key_value/rooms/short.rs
+++ b/src/database/key_value/rooms/short.rs
@@ -1,6 +1,227 @@
use std::sync::Arc;
-use crate::{database::KeyValueDatabase, service};
+use ruma::{EventId, events::StateEventType, RoomId};
-impl service::rooms::short::Data for Arc<KeyValueDatabase> {
+use crate::{Result, database::KeyValueDatabase, service, utils, Error, services};
+
+impl service::rooms::short::Data for KeyValueDatabase {
+ fn get_or_create_shorteventid(
+ &self,
+ event_id: &EventId,
+ ) -> Result<u64> {
+ if let Some(short) = self.eventidshort_cache.lock().unwrap().get_mut(event_id) {
+ return Ok(*short);
+ }
+
+ let short = match self.eventid_shorteventid.get(event_id.as_bytes())? {
+ Some(shorteventid) => utils::u64_from_bytes(&shorteventid)
+ .map_err(|_| Error::bad_database("Invalid shorteventid in db."))?,
+ None => {
+ let shorteventid = services().globals.next_count()?;
+ self.eventid_shorteventid
+ .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
+ self.shorteventid_eventid
+ .insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?;
+ shorteventid
+ }
+ };
+
+ self.eventidshort_cache
+ .lock()
+ .unwrap()
+ .insert(event_id.to_owned(), short);
+
+ Ok(short)
+ }
+
+ fn get_shortstatekey(
+ &self,
+ event_type: &StateEventType,
+ state_key: &str,
+ ) -> Result<Option<u64>> {
+ if let Some(short) = self
+ .statekeyshort_cache
+ .lock()
+ .unwrap()
+ .get_mut(&(event_type.clone(), state_key.to_owned()))
+ {
+ return Ok(Some(*short));
+ }
+
+ let mut statekey = event_type.to_string().as_bytes().to_vec();
+ statekey.push(0xff);
+ statekey.extend_from_slice(state_key.as_bytes());
+
+ let short = self
+ .statekey_shortstatekey
+ .get(&statekey)?
+ .map(|shortstatekey| {
+ utils::u64_from_bytes(&shortstatekey)
+ .map_err(|_| Error::bad_database("Invalid shortstatekey in db."))
+ })
+ .transpose()?;
+
+ if let Some(s) = short {
+ self.statekeyshort_cache
+ .lock()
+ .unwrap()
+ .insert((event_type.clone(), state_key.to_owned()), s);
+ }
+
+ Ok(short)
+ }
+
+ fn get_or_create_shortstatekey(
+ &self,
+ event_type: &StateEventType,
+ state_key: &str,
+ ) -> Result<u64> {
+ if let Some(short) = self
+ .statekeyshort_cache
+ .lock()
+ .unwrap()
+ .get_mut(&(event_type.clone(), state_key.to_owned()))
+ {
+ return Ok(*short);
+ }
+
+ let mut statekey = event_type.to_string().as_bytes().to_vec();
+ statekey.push(0xff);
+ statekey.extend_from_slice(state_key.as_bytes());
+
+ let short = match self.statekey_shortstatekey.get(&statekey)? {
+ Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey)
+ .map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?,
+ None => {
+ let shortstatekey = services().globals.next_count()?;
+ self.statekey_shortstatekey
+ .insert(&statekey, &shortstatekey.to_be_bytes())?;
+ self.shortstatekey_statekey
+ .insert(&shortstatekey.to_be_bytes(), &statekey)?;
+ shortstatekey
+ }
+ };
+
+ self.statekeyshort_cache
+ .lock()
+ .unwrap()
+ .insert((event_type.clone(), state_key.to_owned()), short);
+
+ Ok(short)
+ }
+
+ fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
+ if let Some(id) = self
+ .shorteventid_cache
+ .lock()
+ .unwrap()
+ .get_mut(&shorteventid)
+ {
+ return Ok(Arc::clone(id));
+ }
+
+ let bytes = self
+ .shorteventid_eventid
+ .get(&shorteventid.to_be_bytes())?
+ .ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?;
+
+ let event_id = EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| {
+ Error::bad_database("EventID in shorteventid_eventid is invalid unicode.")
+ })?)
+ .map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?;
+
+ self.shorteventid_cache
+ .lock()
+ .unwrap()
+ .insert(shorteventid, Arc::clone(&event_id));
+
+ Ok(event_id)
+ }
+
+ fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> {
+ if let Some(id) = self
+ .shortstatekey_cache
+ .lock()
+ .unwrap()
+ .get_mut(&shortstatekey)
+ {
+ return Ok(id.clone());
+ }
+
+ let bytes = self
+ .shortstatekey_statekey
+ .get(&shortstatekey.to_be_bytes())?
+ .ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?;
+
+ let mut parts = bytes.splitn(2, |&b| b == 0xff);
+ let eventtype_bytes = parts.next().expect("split always returns one entry");
+ let statekey_bytes = parts
+ .next()
+ .ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?;
+
+ let event_type =
+ StateEventType::try_from(utils::string_from_bytes(eventtype_bytes).map_err(|_| {
+ Error::bad_database("Event type in shortstatekey_statekey is invalid unicode.")
+ })?)
+ .map_err(|_| Error::bad_database("Event type in shortstatekey_statekey is invalid."))?;
+
+ let state_key = utils::string_from_bytes(statekey_bytes).map_err(|_| {
+ Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode.")
+ })?;
+
+ let result = (event_type, state_key);
+
+ self.shortstatekey_cache
+ .lock()
+ .unwrap()
+ .insert(shortstatekey, result.clone());
+
+ Ok(result)
+ }
+
+ /// Returns (shortstatehash, already_existed)
+ fn get_or_create_shortstatehash(
+ &self,
+ state_hash: &[u8],
+ ) -> Result<(u64, bool)> {
+ Ok(match self.statehash_shortstatehash.get(state_hash)? {
+ Some(shortstatehash) => (
+ utils::u64_from_bytes(&shortstatehash)
+ .map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?,
+ true,
+ ),
+ None => {
+ let shortstatehash = services().globals.next_count()?;
+ self.statehash_shortstatehash
+ .insert(state_hash, &shortstatehash.to_be_bytes())?;
+ (shortstatehash, false)
+ }
+ })
+ }
+
+ fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> {
+ self.roomid_shortroomid
+ .get(room_id.as_bytes())?
+ .map(|bytes| {
+ utils::u64_from_bytes(&bytes)
+ .map_err(|_| Error::bad_database("Invalid shortroomid in db."))
+ })
+ .transpose()
+ }
+
+ fn get_or_create_shortroomid(
+ &self,
+ room_id: &RoomId,
+ ) -> Result<u64> {
+ Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? {
+ Some(short) => utils::u64_from_bytes(&short)
+ .map_err(|_| Error::bad_database("Invalid shortroomid in db."))?,
+ None => {
+ let short = services().globals.next_count()?;
+ self.roomid_shortroomid
+ .insert(room_id.as_bytes(), &short.to_be_bytes())?;
+ short
+ }
+ })
+ }
}
diff --git a/src/database/key_value/rooms/state.rs b/src/database/key_value/rooms/state.rs
index 527c240..b2822b3 100644
--- a/src/database/key_value/rooms/state.rs
+++ b/src/database/key_value/rooms/state.rs
@@ -6,7 +6,7 @@ use std::fmt::Debug;
use crate::{service, database::KeyValueDatabase, utils, Error, Result};
-impl service::rooms::state::Data for Arc<KeyValueDatabase> {
+impl service::rooms::state::Data for KeyValueDatabase {
fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortstatehash
.get(room_id.as_bytes())?
diff --git a/src/database/key_value/rooms/state_accessor.rs b/src/database/key_value/rooms/state_accessor.rs
index 9af45db..4d5bd4a 100644
--- a/src/database/key_value/rooms/state_accessor.rs
+++ b/src/database/key_value/rooms/state_accessor.rs
@@ -5,7 +5,7 @@ use async_trait::async_trait;
use ruma::{EventId, events::StateEventType, RoomId};
#[async_trait]
-impl service::rooms::state_accessor::Data for Arc<KeyValueDatabase> {
+impl service::rooms::state_accessor::Data for KeyValueDatabase {
async fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeMap<u64, Arc<EventId>>> {
let full_state = services().rooms.state_compressor
.load_shortstatehash_info(shortstatehash)?
diff --git a/src/database/key_value/rooms/state_cache.rs b/src/database/key_value/rooms/state_cache.rs
index bdb8cf8..5f05485 100644
--- a/src/database/key_value/rooms/state_cache.rs
+++ b/src/database/key_value/rooms/state_cache.rs
@@ -1,10 +1,8 @@
-use std::sync::Arc;
-
use ruma::{UserId, RoomId, events::{AnyStrippedStateEvent, AnySyncStateEvent}, serde::Raw};
use crate::{service, database::KeyValueDatabase, services, Result};
-impl service::rooms::state_cache::Data for Arc<KeyValueDatabase> {
+impl service::rooms::state_cache::Data for KeyValueDatabase {
fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff);
diff --git a/src/database/key_value/rooms/state_compressor.rs b/src/database/key_value/rooms/state_compressor.rs
index e1c0280..aee1890 100644
--- a/src/database/key_value/rooms/state_compressor.rs
+++ b/src/database/key_value/rooms/state_compressor.rs
@@ -1,8 +1,8 @@
-use std::{collections::HashSet, mem::size_of, sync::Arc};
+use std::{collections::HashSet, mem::size_of};
use crate::{service::{self, rooms::state_compressor::data::StateDiff}, database::KeyValueDatabase, Error, utils, Result};
-impl service::rooms::state_compressor::Data for Arc<KeyValueDatabase> {
+impl service::rooms::state_compressor::Data for KeyValueDatabase {
fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff> {
let value = self
.shortstatehash_statediff
diff --git a/src/database/key_value/rooms/timeline.rs b/src/database/key_value/rooms/timeline.rs
index 2d334b9..0b7286b 100644
--- a/src/database/key_value/rooms/timeline.rs
+++ b/src/database/key_value/rooms/timeline.rs
@@ -5,7 +5,27 @@ use tracing::error;
use crate::{service, database::KeyValueDatabase, utils, Error, PduEvent, Result, services};
-impl service::rooms::timeline::Data for Arc<KeyValueDatabase> {
+impl service::rooms::timeline::Data for KeyValueDatabase {
+ fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent>>> {
+ let prefix = services().rooms.short
+ .get_shortroomid(room_id)?
+ .expect("room exists")
+ .to_be_bytes()
+ .to_vec();
+
+ // Look for PDUs in that room.
+ self.pduid_pdu
+ .iter_from(&prefix, false)
+ .filter(|(k, _)| k.starts_with(&prefix))
+ .map(|(_, pdu)| {
+ serde_json::from_slice(&pdu)
+ .map_err(|_| Error::bad_database("Invalid first PDU in db."))
+ .map(Arc::new)
+ })
+ .next()
+ .transpose()
+ }
+
fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<u64> {
match self
.lasttimelinecount_cache
diff --git a/src/database/key_value/rooms/user.rs b/src/database/key_value/rooms/user.rs
index 4d20b00..3759bda 100644
--- a/src/database/key_value/rooms/user.rs
+++ b/src/database/key_value/rooms/user.rs
@@ -1,10 +1,8 @@
-use std::sync::Arc;
-
use ruma::{UserId, RoomId};
use crate::{service, database::KeyValueDatabase, utils, Error, Result, services};
-impl service::rooms::user::Data for Arc<KeyValueDatabase> {
+impl service::rooms::user::Data for KeyValueDatabase {
fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff);
@@ -104,13 +102,13 @@ impl service::rooms::user::Data for Arc<KeyValueDatabase> {
});
// We use the default compare function because keys are sorted correctly (not reversed)
- Ok(utils::common_elements(iterators, Ord::cmp)
+ Ok(Box::new(Box::new(utils::common_elements(iterators, Ord::cmp)
.expect("users is not empty")
.map(|bytes| {
RoomId::parse(utils::string_from_bytes(&*bytes).map_err(|_| {
Error::bad_database("Invalid RoomId bytes in userroomid_joined")
})?)
.map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined."))
- }))
+ }))))
}
}
diff --git a/src/database/key_value/transaction_ids.rs b/src/database/key_value/transaction_ids.rs
index 7fa6908..a63b3c5 100644
--- a/src/database/key_value/transaction_ids.rs
+++ b/src/database/key_value/transaction_ids.rs
@@ -1,10 +1,8 @@
-use std::sync::Arc;
-
use ruma::{UserId, DeviceId, TransactionId};
use crate::{service, database::KeyValueDatabase, Result};
-impl service::transaction_ids::Data for Arc<KeyValueDatabase> {
+impl service::transaction_ids::Data for KeyValueDatabase {
fn add_txnid(
&self,
user_id: &UserId,
diff --git a/src/database/key_value/uiaa.rs b/src/database/key_value/uiaa.rs
index 8752e55..cf242de 100644
--- a/src/database/key_value/uiaa.rs
+++ b/src/database/key_value/uiaa.rs
@@ -1,10 +1,8 @@
-use std::sync::Arc;
-
use ruma::{UserId, DeviceId, signatures::CanonicalJsonValue, api::client::{uiaa::UiaaInfo, error::ErrorKind}};
use crate::{database::KeyValueDatabase, service, Error, Result};
-impl service::uiaa::Data for Arc<KeyValueDatabase> {
+impl service::uiaa::Data for KeyValueDatabase {
fn set_uiaa_request(
&self,
user_id: &UserId,
diff --git a/src/database/key_value/users.rs b/src/database/key_value/users.rs
index 1ac85b3..55a518d 100644
--- a/src/database/key_value/users.rs
+++ b/src/database/key_value/users.rs
@@ -1,11 +1,11 @@
-use std::{mem::size_of, collections::BTreeMap, sync::Arc};
+use std::{mem::size_of, collections::BTreeMap};
use ruma::{api::client::{filter::IncomingFilterDefinition, error::ErrorKind, device::Device}, UserId, RoomAliasId, MxcUri, DeviceId, MilliSecondsSinceUnixEpoch, DeviceKeyId, encryption::{OneTimeKey, CrossSigningKey, DeviceKeys}, serde::Raw, events::{AnyToDeviceEvent, StateEventType}, DeviceKeyAlgorithm, UInt};
use tracing::warn;
use crate::{service::{self, users::clean_signatures}, database::KeyValueDatabase, Error, utils, services, Result};
-impl service::users::Data for Arc<KeyValueDatabase> {
+impl service::users::Data for KeyValueDatabase {
/// Check if a user has an account on this homeserver.
fn exists(&self, user_id: &UserId) -> Result<bool> {
Ok(self.userid_password.get(user_id.as_bytes())?.is_some())
@@ -113,7 +113,7 @@ impl service::users::Data for Arc<KeyValueDatabase> {
/// Hash and set the user's password to the Argon2 hash
fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
if let Some(password) = password {
- if let Ok(hash) = utils::calculate_hash(password) {
+ if let Ok(hash) = utils::calculate_password_hash(password) {
self.userid_password
.insert(user_id.as_bytes(), hash.as_bytes())?;
Ok(())
diff --git a/src/database/mod.rs b/src/database/mod.rs
index 35922f0..6868467 100644
--- a/src/database/mod.rs
+++ b/src/database/mod.rs
@@ -238,8 +238,8 @@ impl KeyValueDatabase {
}
/// Load an existing database or create a new one.
- pub async fn load_or_create(config: &Config) -> Result<()> {
- Self::check_db_setup(config)?;
+ pub async fn load_or_create(config: Config) -> Result<()> {
+ Self::check_db_setup(&config)?;
if !Path::new(&config.database_path).exists() {
std::fs::create_dir_all(&config.database_path)
@@ -251,19 +251,19 @@ impl KeyValueDatabase {
#[cfg(not(feature = "sqlite"))]
return Err(Error::BadConfig("Database backend not found."));
#[cfg(feature = "sqlite")]
- Arc::new(Arc::<abstraction::sqlite::Engine>::open(config)?)
+ Arc::new(Arc::<abstraction::sqlite::Engine>::open(&config)?)
}
"rocksdb" => {
#[cfg(not(feature = "rocksdb"))]
return Err(Error::BadConfig("Database backend not found."));
#[cfg(feature = "rocksdb")]
- Arc::new(Arc::<abstraction::rocksdb::Engine>::open(config)?)
+ Arc::new(Arc::<abstraction::rocksdb::Engine>::open(&config)?)
}
"persy" => {
#[cfg(not(feature = "persy"))]
return Err(Error::BadConfig("Database backend not found."));
#[cfg(feature = "persy")]
- Arc::new(Arc::<abstraction::persy::Engine>::open(config)?)
+ Arc::new(Arc::<abstraction::persy::Engine>::open(&config)?)
}
_ => {
return Err(Error::BadConfig("Database backend not found."));
@@ -402,7 +402,7 @@ impl KeyValueDatabase {
});
- let services_raw = Box::new(Services::build(Arc::clone(&db)));
+ let services_raw = Box::new(Services::build(Arc::clone(&db), config)?);
// This is the first and only time we initialize the SERVICE static
*SERVICES.write().unwrap() = Some(Box::leak(services_raw));
@@ -825,7 +825,7 @@ impl KeyValueDatabase {
info!(
"Loaded {} database with version {}",
- config.database_backend, latest_database_version
+ services().globals.config.database_backend, latest_database_version
);
} else {
services()
@@ -837,7 +837,7 @@ impl KeyValueDatabase {
warn!(
"Created new {} database with version {}",
- config.database_backend, latest_database_version
+ services().globals.config.database_backend, latest_database_version
);
}
@@ -866,7 +866,7 @@ impl KeyValueDatabase {
.sending
.start_handler(sending_receiver);
- Self::start_cleanup_task(config).await;
+ Self::start_cleanup_task().await;
Ok(())
}
@@ -888,8 +888,8 @@ impl KeyValueDatabase {
res
}
- #[tracing::instrument(skip(config))]
- pub async fn start_cleanup_task(config: &Config) {
+ #[tracing::instrument]
+ pub async fn start_cleanup_task() {
use tokio::time::interval;
#[cfg(unix)]
@@ -898,7 +898,7 @@ impl KeyValueDatabase {
use std::time::{Duration, Instant};
- let timer_interval = Duration::from_secs(config.cleanup_second_interval as u64);
+ let timer_interval = Duration::from_secs(services().globals.config.cleanup_second_interval as u64);
tokio::spawn(async move {
let mut i = interval(timer_interval);
diff --git a/src/service/account_data/mod.rs b/src/service/account_data/mod.rs
index 9785478..1289f7a 100644
--- a/src/service/account_data/mod.rs
+++ b/src/service/account_data/mod.rs
@@ -18,7 +18,7 @@ use tracing::error;
use crate::{service::*, services, utils, Error, Result};
pub struct Service {
- db: Box<dyn Data>,
+ db: Arc<dyn Data>,
}
impl Service {
diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs
index 32a709c..0b14314 100644
--- a/src/service/admin/mod.rs
+++ b/src/service/admin/mod.rs
@@ -426,7 +426,7 @@ impl Service {
Error::bad_database("Invalid room id field in event in database")
})?;
let start = Instant::now();
- let count = server_server::get_auth_chain(room_id, vec![event_id])
+ let count = services().rooms.auth_chain.get_auth_chain(room_id, vec![event_id])
.await?
.count();
let elapsed = start.elapsed();
@@ -615,14 +615,12 @@ impl Service {
))
}
AdminCommand::DisableRoom { room_id } => {
- todo!();
- //services().rooms.disabledroomids.insert(room_id.as_bytes(), &[])?;
- //RoomMessageEventContent::text_plain("Room disabled.")
+ services().rooms.metadata.disable_room(&room_id, true);
+ RoomMessageEventContent::text_plain("Room disabled.")
}
AdminCommand::EnableRoom { room_id } => {
- todo!();
- //services().rooms.disabledroomids.remove(room_id.as_bytes())?;
- //RoomMessageEventContent::text_plain("Room enabled.")
+ services().rooms.metadata.disable_room(&room_id, false);
+ RoomMessageEventContent::text_plain("Room enabled.")
}
AdminCommand::DeactivateUser {
leave_rooms,
diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs
index 8fd69df..de8d1aa 100644
--- a/src/service/globals/mod.rs
+++ b/src/service/globals/mod.rs
@@ -35,7 +35,7 @@ type SyncHandle = (
);
pub struct Service {
- pub db: Box<dyn Data>,
+ pub db: Arc<dyn Data>,
pub actual_destination_cache: Arc<RwLock<WellKnownMap>>, // actual_destination, host
pub tls_name_override: Arc<RwLock<TlsNameMap>>,
@@ -92,7 +92,7 @@ impl Default for RotationHandler {
impl Service {
pub fn load(
- db: Box<dyn Data>,
+ db: Arc<dyn Data>,
config: Config,
) -> Result<Self> {
let keypair = db.load_keypair();
diff --git a/src/service/key_backups/mod.rs b/src/service/key_backups/mod.rs
index 4bd9efd..a3bed71 100644
--- a/src/service/key_backups/mod.rs
+++ b/src/service/key_backups/mod.rs
@@ -13,7 +13,7 @@ use ruma::{
use std::{collections::BTreeMap, sync::Arc};
pub struct Service {
- db: Box<dyn Data>,
+ db: Arc<dyn Data>,
}
impl Service {
diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs
index f86251f..d3dd2bd 100644
--- a/src/service/media/mod.rs
+++ b/src/service/media/mod.rs
@@ -16,7 +16,7 @@ pub struct FileMeta {
}
pub struct Service {
- db: Box<dyn Data>,
+ db: Arc<dyn Data>,
}
impl Service {
diff --git a/src/service/mod.rs b/src/service/mod.rs
index a1a728c..a772c1d 100644
--- a/src/service/mod.rs
+++ b/src/service/mod.rs
@@ -1,4 +1,9 @@
-use std::sync::Arc;
+use std::{
+ collections::{BTreeMap, HashMap},
+ sync::{Arc, Mutex},
+};
+
+use crate::{Result, Config};
pub mod account_data;
pub mod admin;
@@ -30,20 +35,73 @@ pub struct Services {
}
impl Services {
- pub fn build<D: appservice::Data + pusher::Data + rooms::Data + transaction_ids::Data + uiaa::Data + users::Data + account_data::Data + globals::Data + key_backups::Data + media::Data>(db: Arc<D>) -> Self {
- Self {
+ pub fn build<
+ D: appservice::Data
+ + pusher::Data
+ + rooms::Data
+ + transaction_ids::Data
+ + uiaa::Data
+ + users::Data
+ + account_data::Data
+ + globals::Data
+ + key_backups::Data
+ + media::Data,
+ >(
+ db: Arc<D>, config: Config
+ ) -> Result<Self> {
+ Ok(Self {
appservice: appservice::Service { db: db.clone() },
pusher: pusher::Service { db: db.clone() },
- rooms: rooms::Service { db: Arc::clone(&db) },
- transaction_ids: transaction_ids::Service { db: Arc::clone(&db) },
- uiaa: uiaa::Service { db: Arc::clone(&db) },
- users: users::Service { db: Arc::clone(&db) },
- account_data: account_data::Service { db: Arc::clone(&db) },
- admin: admin::Service { db: Arc::clone(&db) },
- globals: globals::Service { db: Arc::clone(&db) },
- key_backups: key_backups::Service { db: Arc::clone(&db) },
- media: media::Service { db: Arc::clone(&db) },
- sending: sending::Service { db: Arc::clone(&db) },
- }
+ rooms: rooms::Service {
+ alias: rooms::alias::Service { db: db.clone() },
+ auth_chain: rooms::auth_chain::Service { db: db.clone() },
+ directory: rooms::directory::Service { db: db.clone() },
+ edus: rooms::edus::Service {
+ presence: rooms::edus::presence::Service { db: db.clone() },
+ read_receipt: rooms::edus::read_receipt::Service { db: db.clone() },
+ typing: rooms::edus::typing::Service { db: db.clone() },
+ },
+ event_handler: rooms::event_handler::Service,
+ lazy_loading: rooms::lazy_loading::Service {
+ db: db.clone(),
+ lazy_load_waiting: Mutex::new(HashMap::new()),
+ },
+ metadata: rooms::metadata::Service { db: db.clone() },
+ outlier: rooms::outlier::Service { db: db.clone() },
+ pdu_metadata: rooms::pdu_metadata::Service { db: db.clone() },
+ search: rooms::search::Service { db: db.clone() },
+ short: rooms::short::Service { db: db.clone() },
+ state: rooms::state::Service { db: db.clone() },
+ state_accessor: rooms::state_accessor::Service { db: db.clone() },
+ state_cache: rooms::state_cache::Service { db: db.clone() },
+ state_compressor: rooms::state_compressor::Service { db: db.clone() },
+ timeline: rooms::timeline::Service { db: db.clone() },
+ user: rooms::user::Service { db: db.clone() },
+ },
+ transaction_ids: transaction_ids::Service {
+ db: db.clone()
+ },
+ uiaa: uiaa::Service {
+ db: db.clone()
+ },
+ users: users::Service {
+ db: db.clone()
+ },
+ account_data: account_data::Service {
+ db: db.clone()
+ },
+ admin: admin::Service { sender: todo!() },
+ globals: globals::Service::load(db.clone(), config)?,
+ key_backups: key_backups::Service {
+ db: db.clone()
+ },
+ media: media::Service {
+ db: db.clone()
+ },
+ sending: sending::Service {
+ maximum_requests: todo!(),
+ sender: todo!(),
+ },
+ })
}
}
diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs
index ef5888f..65fb367 100644
--- a/src/service/rooms/alias/mod.rs
+++ b/src/service/rooms/alias/mod.rs
@@ -1,11 +1,13 @@
mod data;
+use std::sync::Arc;
+
pub use data::Data;
use ruma::{RoomAliasId, RoomId};
use crate::Result;
pub struct Service {
- db: Box<dyn Data>,
+ db: Arc<dyn Data>,
}
impl Service {
diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs
index 5fe0e3e..e35094b 100644
--- a/src/service/rooms/auth_chain/mod.rs
+++ b/src/service/rooms/auth_chain/mod.rs
@@ -1,12 +1,14 @@
mod data;
-use std::{sync::Arc, collections::HashSet};
+use std::{sync::Arc, collections::{HashSet, BTreeSet}};
pub use data::Data;
+use ruma::{RoomId, EventId, api::client::error::ErrorKind};
+use tracing::log::warn;
-use crate::Result;
+use crate::{Result, services, Error};
pub struct Service {
- db: Box<dyn Data>,
+ db: Arc<dyn Data>,
}
impl Service {
@@ -22,4 +24,131 @@ impl Service {
pub fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<HashSet<u64>>) -> Result<()> {
self.db.cache_auth_chain(key, auth_chain)
}
+
+ #[tracing::instrument(skip(self, starting_events))]
+ pub async fn get_auth_chain<'a>(
+ &self,
+ room_id: &RoomId,
+ starting_events: Vec<Arc<EventId>>,
+ ) -> Result<impl Iterator<Item = Arc<EventId>> + 'a> {
+ const NUM_BUCKETS: usize = 50;
+
+ let mut buckets = vec![BTreeSet::new(); NUM_BUCKETS];
+
+ let mut i = 0;
+ for id in starting_events {
+ let short = services().rooms.short.get_or_create_shorteventid(&id)?;
+ let bucket_id = (short % NUM_BUCKETS as u64) as usize;
+ buckets[bucket_id].insert((short, id.clone()));
+ i += 1;
+ if i % 100 == 0 {
+ tokio::task::yield_now().await;
+ }
+ }
+
+ let mut full_auth_chain = HashSet::new();
+
+ let mut hits = 0;
+ let mut misses = 0;
+ for chunk in buckets {
+ if chunk.is_empty() {
+ continue;
+ }
+
+ let chunk_key: Vec<u64> = chunk.iter().map(|(short, _)| short).copied().collect();
+ if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&chunk_key)? {
+ hits += 1;
+ full_auth_chain.extend(cached.iter().copied());
+ continue;
+ }
+ misses += 1;
+
+ let mut chunk_cache = HashSet::new();
+ let mut hits2 = 0;
+ let mut misses2 = 0;
+ let mut i = 0;
+ for (sevent_id, event_id) in chunk {
+ if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&[sevent_id])? {
+ hits2 += 1;
+ chunk_cache.extend(cached.iter().copied());
+ } else {
+ misses2 += 1;
+ let auth_chain = Arc::new(self.get_auth_chain_inner(room_id, &event_id)?);
+ services().rooms
+ .auth_chain
+ .cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?;
+ println!(
+ "cache missed event {} with auth chain len {}",
+ event_id,
+ auth_chain.len()
+ );
+ chunk_cache.extend(auth_chain.iter());
+
+ i += 1;
+ if i % 100 == 0 {
+ tokio::task::yield_now().await;
+ }
+ };
+ }
+ println!(
+ "chunk missed with len {}, event hits2: {}, misses2: {}",
+ chunk_cache.len(),
+ hits2,
+ misses2
+ );
+ let chunk_cache = Arc::new(chunk_cache);
+ services().rooms
+ .auth_chain.cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?;
+ full_auth_chain.extend(chunk_cache.iter());
+ }
+
+ println!(
+ "total: {}, chunk hits: {}, misses: {}",
+ full_auth_chain.len(),
+ hits,
+ misses
+ );
+
+ Ok(full_auth_chain
+ .into_iter()
+ .filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok()))
+ }
+
+ #[tracing::instrument(skip(self, event_id))]
+ fn get_auth_chain_inner(
+ &self,
+ room_id: &RoomId,
+ event_id: &EventId,
+ ) -> Result<HashSet<u64>> {
+ let mut todo = vec![Arc::from(event_id)];
+ let mut found = HashSet::new();
+
+ while let Some(event_id) = todo.pop() {
+ match services().rooms.timeline.get_pdu(&event_id) {
+ Ok(Some(pdu)) => {
+ if pdu.room_id != room_id {
+ return Err(Error::BadRequest(ErrorKind::Forbidden, "Evil event in db"));
+ }
+ for auth_event in &pdu.auth_events {
+ let sauthevent = services()
+ .rooms.short
+ .get_or_create_shorteventid(auth_event)?;
+
+ if !found.contains(&sauthevent) {
+ found.insert(sauthevent);
+ todo.push(auth_event.clone());
+ }
+ }
+ }
+ Ok(None) => {
+ warn!("Could not find pdu mentioned in auth events: {}", event_id);
+ }
+ Err(e) => {
+ warn!("Could not load event in auth chain: {} {}", event_id, e);
+ }
+ }
+ }
+
+ Ok(found)
+ }
}
diff --git a/src/service/rooms/directory/mod.rs b/src/service/rooms/directory/mod.rs
index fb28994..e85afef 100644
--- a/src/service/rooms/directory/mod.rs
+++ b/src/service/rooms/directory/mod.rs
@@ -1,11 +1,13 @@
mod data;
+use std::sync::Arc;
+
pub use data::Data;
use ruma::RoomId;
use crate::Result;
pub struct Service {
- db: Box<dyn Data>,
+ db: Arc<dyn Data>,
}
impl Service {
diff --git a/src/service/rooms/edus/presence/mod.rs b/src/service/rooms/edus/presence/mod.rs
index 73b7b5a..d657897 100644
--- a/src/service/rooms/edus/presence/mod.rs
+++ b/src/service/rooms/edus/presence/mod.rs
@@ -1,5 +1,5 @@
mod data;
-use std::collections::HashMap;
+use std::{collections::HashMap, sync::Arc};
pub use data::Data;
use ruma::{RoomId, UserId, events::presence::PresenceEvent};
@@ -7,7 +7,7 @@ use ruma::{RoomId, UserId, events::presence::PresenceEvent};
use crate::Result;
pub struct Service {
- db: Box<dyn Data>,
+ db: Arc<dyn Data>,
}
impl Service {
diff --git a/src/service/rooms/edus/read_receipt/mod.rs b/src/service/rooms/edus/read_receipt/mod.rs
index 2a4c0b7..1770877 100644
--- a/src/service/rooms/edus/read_receipt/mod.rs
+++ b/src/service/rooms/edus/read_receipt/mod.rs
@@ -1,11 +1,13 @@
mod data;
+use std::sync::Arc;
+
pub use data::Data;
use ruma::{RoomId, UserId, events::receipt::ReceiptEvent, serde::Raw};
use crate::Result;
pub struct Service {
- db: Box<dyn Data>,
+ db: Arc<dyn Data>,
}
impl Service {
diff --git a/src/service/rooms/edus/typing/mod.rs b/src/service/rooms/edus/typing/mod.rs
index 16a135f..3752056 100644
--- a/src/service/rooms/edus/typing/mod.rs
+++ b/src/service/rooms/edus/typing/mod.rs
@@ -1,11 +1,13 @@
mod data;
+use std::sync::Arc;
+
pub use data::Data;
use ruma::{UserId, RoomId, events::SyncEphemeralRoomEvent};
use crate::Result;
pub struct Service {
- db: Box<dyn Data>,
+ db: Arc<dyn Data>,
}
impl Service {
diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs
index ac3cca6..79f93b5 100644
--- a/src/service/rooms/event_handler/mod.rs
+++ b/src/service/rooms/event_handler/mod.rs
@@ -72,13 +72,15 @@ impl Service {
));
}
- services()
+ if services()
.rooms
- .is_disabled(room_id)?
- .ok_or(Error::BadRequest(
+ .metadata
+ .is_disabled(room_id)? {
+ return Err(Error::BadRequest(
ErrorKind::Forbidden,
"Federation of this room is currently disabled on this server.",
- ))?;
+ ));
+ }
// 1. Skip the PDU if we already have it as a timeline event
if let Some(pdu_id) = services().rooms.timeline.get_pdu_id(event_id)? {
@@ -111,7 +113,7 @@ impl Service {
}
// 9. Fetch any missing prev events doing all checks listed here starting at 1. These are timeline events
- let (sorted_prev_events, eventid_info) = self.fetch_unknown_prev_events(
+ let (sorted_prev_events, mut eventid_info) = self.fetch_unknown_prev_events(
origin,
&create_event,
room_id,
@@ -122,14 +124,15 @@ impl Service {
let mut errors = 0;
for prev_id in dbg!(sorted_prev_events) {
// Check for disabled again because it might have changed
- services()
+ if services()
.rooms
- .is_disabled(room_id)?
- .ok_or(Error::BadRequest(
+ .metadata
+ .is_disabled(room_id)? {
+ return Err(Error::BadRequest(
ErrorKind::Forbidden,
- "Federation of
- this room is currently disabled on this server.",
- ))?;
+ "Federation of this room is currently disabled on this server.",
+ ));
+ }
if let Some((time, tries)) = services()
.globals
@@ -279,14 +282,14 @@ impl Service {
Err(e) => {
// Drop
warn!("Dropping bad event {}: {}", event_id, e);
- return Err("Signature verification failed".to_owned());
+ return Err(Error::BadRequest(ErrorKind::InvalidParam, "Signature verification failed"));
}
Ok(ruma::signatures::Verified::Signatures) => {
// Redact
warn!("Calculated hash does not match: {}", event_id);
match ruma::signatures::redact(&value, room_version_id) {
Ok(obj) => obj,
- Err(_) => return Err("Redaction failed".to_owned()),
+ Err(_) => return Err(Error::BadRequest(ErrorKind::InvalidParam, "Redaction failed")),
}
}
Ok(ruma::signatures::Verified::All) => value,
@@ -480,7 +483,7 @@ impl Service {
let mut okay = true;
for prev_eventid in &incoming_pdu.prev_events {
- let prev_event = if let Ok(Some(pdu)) = services().rooms.get_pdu(prev_eventid) {
+ let prev_event = if let Ok(Some(pdu)) = services().rooms.timeline.get_pdu(prev_eventid) {
pdu
} else {
okay = false;
@@ -488,7 +491,7 @@ impl Service {
};
let sstatehash =
- if let Ok(Some(s)) = services().rooms.pdu_shortstatehash(prev_eventid) {
+ if let Ok(Some(s)) = services().rooms.state_accessor.pdu_shortstatehash(prev_eventid) {
s
} else {
okay = false;
@@ -525,7 +528,7 @@ impl Service {
let mut starting_events = Vec::with_capacity(leaf_state.len());
for (k, id) in leaf_state {
- if let Ok((ty, st_key)) = services().rooms.get_statekey_from_short(k) {
+ if let Ok((ty, st_key)) = services().rooms.short.get_statekey_from_short(k) {
// FIXME: Undo .to_string().into() when StateMap
// is updated to use StateEventType
state.insert((ty.to_string().into(), st_key), id.clone());
@@ -539,7 +542,7 @@ impl Service {
services()
.rooms
.auth_chain
- .get_auth_chain(room_id, starting_events, services())
+ .get_auth_chain(room_id, starting_events)
.await?
.collect(),
);
@@ -551,7 +554,7 @@ impl Service {
let result =
state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| {
- let res = services().rooms.get_pdu(id);
+ let res = services().rooms.timeline.get_pdu(id);
if let Err(e) = &res {
error!("LOOK AT ME Failed to fetch event: {}", e);
}
@@ -677,7 +680,7 @@ impl Service {
.and_then(|event_id| services().rooms.timeline.get_pdu(event_id).ok().flatten())
},
)
- .map_err(|_e| "Auth check failed.".to_owned())?;
+ .map_err(|_e| Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed."))?;
if !check_result {
return Err(Error::bad_database("Event has failed auth check with state at the event."));
@@ -714,7 +717,7 @@ impl Service {
// Only keep those extremities were not referenced yet
extremities
- .retain(|id| !matches!(services().rooms.is_event_referenced(room_id, id), Ok(true)));
+ .retain(|id| !matches!(services().rooms.pdu_metadata.is_event_referenced(room_id, id), Ok(true)));
info!("Compressing state at event");
let state_ids_compressed = state_at_incoming_event
@@ -722,7 +725,8 @@ impl Service {
.map(|(shortstatekey, id)| {
services()
.rooms
- .compress_state_event(*shortstatekey, id)?
+ .state_compressor
+ .compress_state_event(*shortstatekey, id)
})
.collect::<Result<_>>()?;
@@ -731,6 +735,7 @@ impl Service {
let auth_events = services()
.rooms
+ .state
.get_auth_events(
room_id,
&incoming_pdu.kind,
@@ -744,10 +749,10 @@ impl Service {
&incoming_pdu,
None::<PduEvent>,
|k, s| auth_events.get(&(k.clone(), s.to_owned())),
- )?;
+ ).map_err(|_e| Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed."))?;
if soft_fail {
- self.append_incoming_pdu(
+ services().rooms.timeline.append_incoming_pdu(
&incoming_pdu,
val,
extremities.iter().map(std::ops::Deref::deref),
@@ -760,8 +765,9 @@ impl Service {
warn!("Event was soft failed: {:?}", incoming_pdu);
services()
.rooms
+ .pdu_metadata
.mark_event_soft_failed(&incoming_pdu.event_id)?;
- return Err("Event has been soft failed".into());
+ return Err(Error::BadRequest(ErrorKind::InvalidParam, "Event has been soft failed"));
}
if incoming_pdu.state_key.is_some() {
@@ -798,14 +804,14 @@ impl Service {
"Found extremity pdu with no statehash in db: {:?}",
leaf_pdu
);
- "Found pdu with no statehash in db.".to_owned()
+ Error::bad_database("Found pdu with no statehash in db.")
})?,
leaf_pdu,
);
}
_ => {
error!("Missing state snapshot for {:?}", id);
- return Err("Missing state snapshot.".to_owned());
+ return Err(Error::BadDatabase("Missing state snapshot."));
}
}
}
@@ -835,7 +841,7 @@ impl Service {
let mut update_state = false;
// 14. Use state resolution to find new room state
let new_room_state = if fork_states.is_empty() {
- return Err("State is empty.".to_owned());
+ panic!("State is empty");
} else if fork_states.iter().skip(1).all(|f| &fork_states[0] == f) {
info!("State resolution trivial");
// There was only one state, so it has to be the room's current state (because that is
@@ -845,7 +851,8 @@ impl Service {
.map(|(k, id)| {
services()
.rooms
- .compress_state_event(*k, id)?
+ .state_compressor
+ .compress_state_event(*k, id)
})
.collect::<Result<_>>()?
} else {
@@ -877,9 +884,8 @@ impl Service {
.filter_map(|(k, id)| {
services()
.rooms
- .get_statekey_from_short(k)?
- // FIXME: Undo .to_string().into() when StateMap
- // is updated to use StateEventType
+ .short
+ .get_statekey_from_short(k)
.map(|(ty, st_key)| ((ty.to_string().into(), st_key), id))
.ok()
})
@@ -895,7 +901,7 @@ impl Service {
&fork_states,
auth_chain_sets,
|id| {
- let res = services().rooms.get_pdu(id);
+ let res = services().rooms.timeline.get_pdu(id);
if let Err(e) = &res {
error!("LOOK AT ME Failed to fetch event: {}", e);
}
@@ -904,7 +910,7 @@ impl Service {
) {
Ok(new_state) => new_state,
Err(_) => {
- return Err("State resolution failed, either an event could not be found or deserialization".into());
+ return Err(Error::bad_database("State resolution failed, either an event could not be found or deserialization"));
}
};
@@ -921,6 +927,7 @@ impl Service {
.get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?;
services()
.rooms
+ .state_compressor
.compress_state_event(shortstatekey, &event_id)
})
.collect::<Result<_>>()?
@@ -929,9 +936,11 @@ impl Service {
// Set the new room state to the resolved state
if update_state {
info!("Forcing new room state");
+ let (sstatehash, _, _) = services().rooms.state_compressor.save_state(room_id, new_room_state)?;
services()
.rooms
- .force_state(room_id, new_room_state)?;
+ .state
+ .set_room_state(room_id, sstatehash, &state_lock)?;
}
}
@@ -942,7 +951,7 @@ impl Service {
// We use the `state_at_event` instead of `state_after` so we accurately
// represent the state for this event.
- let pdu_id = self
+ let pdu_id = services().rooms.timeline
.append_incoming_pdu(
&incoming_pdu,
val,
@@ -1017,7 +1026,7 @@ impl Service {
// a. Look in the main timeline (pduid_pdu tree)
// b. Look at outlier pdu tree
// (get_pdu_json checks both)
- if let Ok(Some(local_pdu)) = services().rooms.get_pdu(id) {
+ if let Ok(Some(local_pdu)) = services().rooms.timeline.get_pdu(id) {
trace!("Found {} in db", id);
pdus.push((local_pdu, None));
continue;
@@ -1040,7 +1049,7 @@ impl Service {
tokio::task::yield_now().await;
}
- if let Ok(Some(_)) = services().rooms.get_pdu(&next_id) {
+ if let Ok(Some(_)) = services().rooms.timeline.get_pdu(&next_id) {
trace!("Found {} in db", id);
continue;
}
@@ -1140,6 +1149,7 @@ impl Service {
let first_pdu_in_room = services()
.rooms
+ .timeline
.first_pdu_in_room(room_id)?
.ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?;
diff --git a/src/service/rooms/lazy_loading/mod.rs b/src/service/rooms/lazy_loading/mod.rs
index 90dad21..760fffe 100644
--- a/src/service/rooms/lazy_loading/mod.rs
+++ b/src/service/rooms/lazy_loading/mod.rs
@@ -1,5 +1,5 @@
mod data;
-use std::{collections::{HashSet, HashMap}, sync::Mutex};
+use std::{collections::{HashSet, HashMap}, sync::{Mutex, Arc}};
pub use data::Data;
use ruma::{DeviceId, UserId, RoomId};
@@ -7,7 +7,7 @@ use ruma::{DeviceId, UserId, RoomId};
use crate::Result;
pub struct Service {
- db: Box<dyn Data>,
+ db: Arc<dyn Data>,
lazy_load_waiting: Mutex<HashMap<(Box<UserId>, Box<DeviceId>, Box<RoomId>, u64), HashSet<Box<UserId>>>>,
}
diff --git a/src/service/rooms/metadata/data.rs b/src/service/rooms/metadata/data.rs
index 9444db4..bc31ee8 100644
--- a/src/service/rooms/metadata/data.rs
+++ b/src/service/rooms/metadata/data.rs
@@ -3,4 +3,6 @@ use crate::Result;
pub trait Data: Send + Sync {
fn exists(&self, room_id: &RoomId) -> Result<bool>;
+ fn is_disabled(&self, room_id: &RoomId) -> Result<bool>;
+ fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()>;
}
diff --git a/src/service/rooms/metadata/mod.rs b/src/service/rooms/metadata/mod.rs
index 3c21dd1..b6cccd1 100644
--- a/src/service/rooms/metadata/mod.rs
+++ b/src/service/rooms/metadata/mod.rs
@@ -1,11 +1,13 @@
mod data;
+use std::sync::Arc;
+
pub use data::Data;
use ruma::RoomId;
use crate::Result;
pub struct Service {
- db: Box<dyn Data>,
+ db: Arc<dyn Data>,
}
impl Service {
@@ -14,4 +16,12 @@ impl Service {
pub fn exists(&self, room_id: &RoomId) -> Result<bool> {
self.db.exists(room_id)
}
+
+ pub fn is_disabled(&self, room_id: &RoomId) -> Result<bool> {
+ self.db.is_disabled(room_id)
+ }
+
+ pub fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> {
+ self.db.disable_room(room_id, disabled)
+ }
}
diff --git a/src/service/rooms/outlier/mod.rs b/src/service/rooms/outlier/mod.rs
index 5493ce4..d36adc4 100644
--- a/src/service/rooms/outlier/mod.rs
+++ b/src/service/rooms/outlier/mod.rs
@@ -1,11 +1,13 @@
mod data;
+use std::sync::Arc;
+
pub use data::Data;
use ruma::{EventId, signatures::CanonicalJsonObject};
use crate::{Result, PduEvent};
pub struct Service {
- db: Box<dyn Data>,
+ db: Arc<dyn Data>,
}
impl Service {
diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs
index a81d05c..4724f85 100644
--- a/src/service/rooms/pdu_metadata/mod.rs
+++ b/src/service/rooms/pdu_metadata/mod.rs
@@ -7,7 +7,7 @@ use ruma::{RoomId, EventId};
use crate::Result;
pub struct Service {
- db: Box<dyn Data>,
+ db: Arc<dyn Data>,
}
impl Service {
diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs
index dc57191..ec1ad53 100644
--- a/src/service/rooms/search/mod.rs
+++ b/src/service/rooms/search/mod.rs
@@ -1,11 +1,13 @@
mod data;
+use std::sync::Arc;
+
pub use data::Data;
use crate::Result;
use ruma::RoomId;
pub struct Service {
- db: Box<dyn Data>,
+ db: Arc<dyn Data>,
}
impl Service {
diff --git a/src/service/rooms/short/data.rs b/src/service/rooms/short/data.rs
index bc2b28f..07a2712 100644
--- a/src/service/rooms/short/data.rs
+++ b/src/service/rooms/short/data.rs
@@ -1,2 +1,40 @@
+use std::sync::Arc;
+
+use ruma::{EventId, events::StateEventType, RoomId};
+use crate::Result;
+
pub trait Data: Send + Sync {
+ fn get_or_create_shorteventid(
+ &self,
+ event_id: &EventId,
+ ) -> Result<u64>;
+
+ fn get_shortstatekey(
+ &self,
+ event_type: &StateEventType,
+ state_key: &str,
+ ) -> Result<Option<u64>>;
+
+ fn get_or_create_shortstatekey(
+ &self,
+ event_type: &StateEventType,
+ state_key: &str,
+ ) -> Result<u64>;
+
+ fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>>;
+
+ fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)>;
+
+ /// Returns (shortstatehash, already_existed)
+ fn get_or_create_shortstatehash(
+ &self,
+ state_hash: &[u8],
+ ) -> Result<(u64, bool)>;
+
+ fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>>;
+
+ fn get_or_create_shortroomid(
+ &self,
+ room_id: &RoomId,
+ ) -> Result<u64>;
}
diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs
index a024dc6..08ce5c5 100644
--- a/src/service/rooms/short/mod.rs
+++ b/src/service/rooms/short/mod.rs
@@ -7,7 +7,7 @@ use ruma::{EventId, events::StateEventType, RoomId};
use crate::{Result, Error, utils, services};
pub struct Service {
- db: Box<dyn Data>,
+ db: Arc<dyn Data>,
}
impl Service {
@@ -15,29 +15,7 @@ impl Service {
&self,
event_id: &EventId,
) -> Result<u64> {
- if let Some(short) = self.eventidshort_cache.lock().unwrap().get_mut(event_id) {
- return Ok(*short);
- }
-
- let short = match self.eventid_shorteventid.get(event_id.as_bytes())? {
- Some(shorteventid) => utils::u64_from_bytes(&shorteventid)
- .map_err(|_| Error::bad_database("Invalid shorteventid in db."))?,
- None => {
- let shorteventid = services().globals.next_count()?;
- self.eventid_shorteventid
- .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
- self.shorteventid_eventid
- .insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?;
- shorteventid
- }
- };
-
- self.eventidshort_cache
- .lock()
- .unwrap()
- .insert(event_id.to_owned(), short);
-
- Ok(short)
+ self.db.get_or_create_shorteventid(event_id)
}
pub fn get_shortstatekey(
@@ -45,36 +23,7 @@ impl Service {
event_type: &StateEventType,
state_key: &str,
) -> Result<Option<u64>> {
- if let Some(short) = self
- .statekeyshort_cache
- .lock()
- .unwrap()
- .get_mut(&(event_type.clone(), state_key.to_owned()))
- {
- return Ok(Some(*short));
- }
-
- let mut statekey = event_type.to_string().as_bytes().to_vec();
- statekey.push(0xff);
- statekey.extend_from_slice(state_key.as_bytes());
-
- let short = self
- .statekey_shortstatekey
- .get(&statekey)?
- .map(|shortstatekey| {
- utils::u64_from_bytes(&shortstatekey)
- .map_err(|_| Error::bad_database("Invalid shortstatekey in db."))
- })
- .transpose()?;
-
- if let Some(s) = short {
- self.statekeyshort_cache
- .lock()
- .unwrap()
- .insert((event_type.clone(), state_key.to_owned()), s);
- }
-
- Ok(short)
+ self.db.get_shortstatekey(event_type, state_key)
}
pub fn get_or_create_shortstatekey(
@@ -82,152 +31,33 @@ impl Service {
event_type: &StateEventType,
state_key: &str,
) -> Result<u64> {
- if let Some(short) = self
- .statekeyshort_cache
- .lock()
- .unwrap()
- .get_mut(&(event_type.clone(), state_key.to_owned()))
- {
- return Ok(*short);
- }
-
- let mut statekey = event_type.to_string().as_bytes().to_vec();
- statekey.push(0xff);
- statekey.extend_from_slice(state_key.as_bytes());
-
- let short = match self.statekey_shortstatekey.get(&statekey)? {
- Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey)
- .map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?,
- None => {
- let shortstatekey = services().globals.next_count()?;
- self.statekey_shortstatekey
- .insert(&statekey, &shortstatekey.to_be_bytes())?;
- self.shortstatekey_statekey
- .insert(&shortstatekey.to_be_bytes(), &statekey)?;
- shortstatekey
- }
- };
-
- self.statekeyshort_cache
- .lock()
- .unwrap()
- .insert((event_type.clone(), state_key.to_owned()), short);
-
- Ok(short)
+ self.db.get_or_create_shortstatekey(event_type, state_key)
}
pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
- if let Some(id) = self
- .shorteventid_cache
- .lock()
- .unwrap()
- .get_mut(&shorteventid)
- {
- return Ok(Arc::clone(id));
- }
-
- let bytes = self
- .shorteventid_eventid
- .get(&shorteventid.to_be_bytes())?
- .ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?;
-
- let event_id = EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| {
- Error::bad_database("EventID in shorteventid_eventid is invalid unicode.")
- })?)
- .map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?;
-
- self.shorteventid_cache
- .lock()
- .unwrap()
- .insert(shorteventid, Arc::clone(&event_id));
-
- Ok(event_id)
+ self.db.get_eventid_from_short(shorteventid)
}
pub fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> {
- if let Some(id) = self
- .shortstatekey_cache
- .lock()
- .unwrap()
- .get_mut(&shortstatekey)
- {
- return Ok(id.clone());
- }
-
- let bytes = self
- .shortstatekey_statekey
- .get(&shortstatekey.to_be_bytes())?
- .ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?;
-
- let mut parts = bytes.splitn(2, |&b| b == 0xff);
- let eventtype_bytes = parts.next().expect("split always returns one entry");
- let statekey_bytes = parts
- .next()
- .ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?;
-
- let event_type =
- StateEventType::try_from(utils::string_from_bytes(eventtype_bytes).map_err(|_| {
- Error::bad_database("Event type in shortstatekey_statekey is invalid unicode.")
- })?)
- .map_err(|_| Error::bad_database("Event type in shortstatekey_statekey is invalid."))?;
-
- let state_key = utils::string_from_bytes(statekey_bytes).map_err(|_| {
- Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode.")
- })?;
-
- let result = (event_type, state_key);
-
- self.shortstatekey_cache
- .lock()
- .unwrap()
- .insert(shortstatekey, result.clone());
-
- Ok(result)
+ self.db.get_statekey_from_short(shortstatekey)
}
/// Returns (shortstatehash, already_existed)
- fn get_or_create_shortstatehash(
+ pub fn get_or_create_shortstatehash(
&self,
state_hash: &[u8],
) -> Result<(u64, bool)> {
- Ok(match self.statehash_shortstatehash.get(state_hash)? {
- Some(shortstatehash) => (
- utils::u64_from_bytes(&shortstatehash)
- .map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?,
- true,
- ),
- None => {
- let shortstatehash = services().globals.next_count()?;
- self.statehash_shortstatehash
- .insert(state_hash, &shortstatehash.to_be_bytes())?;
- (shortstatehash, false)
- }
- })
+ self.db.get_or_create_shortstatehash(state_hash)
}
pub fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> {
- self.roomid_shortroomid
- .get(room_id.as_bytes())?
- .map(|bytes| {
- utils::u64_from_bytes(&bytes)
- .map_err(|_| Error::bad_database("Invalid shortroomid in db."))
- })
- .transpose()
+ self.db.get_shortroomid(room_id)
}
pub fn get_or_create_shortroomid(
&self,
room_id: &RoomId,
) -> Result<u64> {
- Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? {
- Some(short) => utils::u64_from_bytes(&short)
- .map_err(|_| Error::bad_database("Invalid shortroomid in db."))?,
- None => {
- let short = services().globals.next_count()?;
- self.roomid_shortroomid
- .insert(room_id.as_bytes(), &short.to_be_bytes())?;
- short
- }
- })
+ self.db.get_or_create_shortroomid(room_id)
}
}
diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs
index 5385978..79807c5 100644
--- a/src/service/rooms/state/mod.rs
+++ b/src/service/rooms/state/mod.rs
@@ -1,9 +1,10 @@
mod data;
-use std::{collections::HashSet, sync::Arc};
+use std::{collections::{HashSet, HashMap}, sync::Arc};
pub use data::Data;
-use ruma::{RoomId, events::{room::{member::MembershipState, create::RoomCreateEventContent}, AnyStrippedStateEvent, StateEventType}, UserId, EventId, serde::Raw, RoomVersionId};
+use ruma::{RoomId, events::{room::{member::MembershipState, create::RoomCreateEventContent}, AnyStrippedStateEvent, StateEventType, RoomEventType}, UserId, EventId, serde::Raw, RoomVersionId, state_res::{StateMap, self}};
use serde::Deserialize;
+use tokio::sync::MutexGuard;
use tracing::warn;
use crate::{Result, services, PduEvent, Error, utils::calculate_hash};
@@ -11,7 +12,7 @@ use crate::{Result, services, PduEvent, Error, utils::calculate_hash};
use super::state_compressor::CompressedStateEvent;
pub struct Service {
- db: Box<dyn Data>,
+ db: Arc<dyn Data>,
}
impl Service {
@@ -97,7 +98,7 @@ impl Service {
room_id: &RoomId,
state_ids_compressed: HashSet<CompressedStateEvent>,
) -> Result<u64> {
- let shorteventid = services().short.get_or_create_shorteventid(event_id)?;
+ let shorteventid = services().rooms.short.get_or_create_shorteventid(event_id)?;
let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?;
@@ -109,11 +110,11 @@ impl Service {
);
let (shortstatehash, already_existed) =
- services().short.get_or_create_shortstatehash(&state_hash)?;
+ services().rooms.short.get_or_create_shortstatehash(&state_hash)?;
if !already_existed {
let states_parents = previous_shortstatehash
- .map_or_else(|| Ok(Vec::new()), |p| services().room.state_compressor.load_shortstatehash_info(p))?;
+ .map_or_else(|| Ok(Vec::new()), |p| services().rooms.state_compressor.load_shortstatehash_info(p))?;
let (statediffnew, statediffremoved) =
if let Some(parent_stateinfo) = states_parents.last() {
@@ -132,7 +133,7 @@ impl Service {
} else {
(state_ids_compressed, HashSet::new())
};
- services().room.state_compressor.save_state_from_diff(
+ services().rooms.state_compressor.save_state_from_diff(
shortstatehash,
statediffnew,
statediffremoved,
@@ -141,7 +142,7 @@ impl Service {
)?;
}
- self.db.set_event_state(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?;
+ self.db.set_event_state(shorteventid, shortstatehash)?;
Ok(shortstatehash)
}
@@ -155,25 +156,24 @@ impl Service {
&self,
new_pdu: &PduEvent,
) -> Result<u64> {
- let shorteventid = self.get_or_create_shorteventid(&new_pdu.event_id)?;
+ let shorteventid = services().rooms.short.get_or_create_shorteventid(&new_pdu.event_id)?;
let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id)?;
if let Some(p) = previous_shortstatehash {
- self.shorteventid_shortstatehash
- .insert(&shorteventid.to_be_bytes(), &p.to_be_bytes())?;
+ self.db.set_event_state(shorteventid, p)?;
}
if let Some(state_key) = &new_pdu.state_key {
let states_parents = previous_shortstatehash
- .map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?;
+ .map_or_else(|| Ok(Vec::new()), |p| services().rooms.state_compressor.load_shortstatehash_info(p))?;
- let shortstatekey = self.get_or_create_shortstatekey(
+ let shortstatekey = services().rooms.short.get_or_create_shortstatekey(
&new_pdu.kind.to_string().into(),
state_key,
)?;
- let new = self.compress_state_event(shortstatekey, &new_pdu.event_id)?;
+ let new = services().rooms.state_compressor.compress_state_event(shortstatekey, &new_pdu.event_id)?;
let replaces = states_parents
.last()
@@ -199,7 +199,7 @@ impl Service {
statediffremoved.insert(*replaces);
}
- self.save_state_from_diff(
+ services().rooms.state_compressor.save_state_from_diff(
shortstatehash,
statediffnew,
statediffremoved,
@@ -221,16 +221,16 @@ impl Service {
let mut state = Vec::new();
// Add recommended events
if let Some(e) =
- self.room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "")?
+ services().rooms.state_accessor.room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "")?
{
state.push(e.to_stripped_state_event());
}
if let Some(e) =
- self.room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "")?
+ services().rooms.state_accessor.room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "")?
{
state.push(e.to_stripped_state_event());
}
- if let Some(e) = self.room_state_get(
+ if let Some(e) = services().rooms.state_accessor.room_state_get(
&invite_event.room_id,
&StateEventType::RoomCanonicalAlias,
"",
@@ -238,16 +238,16 @@ impl Service {
state.push(e.to_stripped_state_event());
}
if let Some(e) =
- self.room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "")?
+ services().rooms.state_accessor.room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "")?
{
state.push(e.to_stripped_state_event());
}
if let Some(e) =
- self.room_state_get(&invite_event.room_id, &StateEventType::RoomName, "")?
+ services().rooms.state_accessor.room_state_get(&invite_event.room_id, &StateEventType::RoomName, "")?
{
state.push(e.to_stripped_state_event());
}
- if let Some(e) = self.room_state_get(
+ if let Some(e) = services().rooms.state_accessor.room_state_get(
&invite_event.room_id,
&StateEventType::RoomMember,
invite_event.sender.as_str(),
@@ -260,17 +260,16 @@ impl Service {
}
#[tracing::instrument(skip(self))]
- pub fn set_room_state(&self, room_id: &RoomId, shortstatehash: u64) -> Result<()> {
- self.roomid_shortstatehash
- .insert(room_id.as_bytes(), &shortstatehash.to_be_bytes())?;
-
- Ok(())
+ pub fn set_room_state(&self, room_id: &RoomId, shortstatehash: u64,
+ mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
+ ) -> Result<()> {
+ self.db.set_room_state(room_id, shortstatehash, mutex_lock)
}
/// Returns the room's version.
#[tracing::instrument(skip(self))]
pub fn get_room_version(&self, room_id: &RoomId) -> Result<RoomVersionId> {
- let create_event = self.room_state_get(room_id, &StateEventType::RoomCreate, "")?;
+ let create_event = services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomCreate, "")?;
let create_event_content: Option<RoomCreateEventContent> = create_event
.as_ref()
@@ -294,4 +293,50 @@ impl Service {
pub fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> {
self.db.get_forward_extremities(room_id)
}
+
+ /// This fetches auth events from the current state.
+ #[tracing::instrument(skip(self))]
+ pub fn get_auth_events(
+ &self,
+ room_id: &RoomId,
+ kind: &RoomEventType,
+ sender: &UserId,
+ state_key: Option<&str>,
+ content: &serde_json::value::RawValue,
+ ) -> Result<StateMap<Arc<PduEvent>>> {
+ let shortstatehash =
+ if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
+ current_shortstatehash
+ } else {
+ return Ok(HashMap::new());
+ };
+
+ let auth_events = state_res::auth_types_for_event(kind, sender, state_key, content)
+ .expect("content is a valid JSON object");
+
+ let mut sauthevents = auth_events
+ .into_iter()
+ .filter_map(|(event_type, state_key)| {
+ services().rooms.short.get_shortstatekey(&event_type.to_string().into(), &state_key)
+ .ok()
+ .flatten()
+ .map(|s| (s, (event_type, state_key)))
+ })
+ .collect::<HashMap<_, _>>();
+
+ let full_state = services().rooms.state_compressor
+ .load_shortstatehash_info(shortstatehash)?
+ .pop()
+ .expect("there is always one layer")
+ .1;
+
+ Ok(full_state
+ .into_iter()
+ .filter_map(|compressed| services().rooms.state_compressor.parse_compressed_state_event(compressed).ok())
+ .filter_map(|(shortstatekey, event_id)| {
+ sauthevents.remove(&shortstatekey).map(|k| (k, event_id))
+ })
+ .filter_map(|(k, event_id)| services().rooms.timeline.get_pdu(&event_id).ok().flatten().map(|pdu| (k, pdu)))
+ .collect())
+ }
}
diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs
index 1911e52..fd29948 100644
--- a/src/service/rooms/state_accessor/mod.rs
+++ b/src/service/rooms/state_accessor/mod.rs
@@ -7,7 +7,7 @@ use ruma::{events::StateEventType, RoomId, EventId};
use crate::{Result, PduEvent};
pub struct Service {
- db: Box<dyn Data>,
+ db: Arc<dyn Data>,
}
impl Service {
@@ -45,7 +45,7 @@ impl Service {
event_type: &StateEventType,
state_key: &str,
) -> Result<Option<Arc<PduEvent>>> {
- self.db.pdu_state_get(shortstatehash, event_type, state_key)
+ self.db.state_get(shortstatehash, event_type, state_key)
}
/// Returns the state hash for this pdu.
diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs
index 18d1123..ab6a0d6 100644
--- a/src/service/rooms/state_cache/mod.rs
+++ b/src/service/rooms/state_cache/mod.rs
@@ -3,12 +3,23 @@ use std::{collections::HashSet, sync::Arc};
pub use data::Data;
use regex::Regex;
-use ruma::{RoomId, UserId, events::{room::{member::MembershipState, create::RoomCreateEventContent}, AnyStrippedStateEvent, StateEventType, tag::TagEvent, RoomAccountDataEventType, GlobalAccountDataEventType, direct::DirectEvent, ignored_user_list::IgnoredUserListEvent, AnySyncStateEvent}, serde::Raw, ServerName};
-
-use crate::{Result, services, utils, Error};
+use ruma::{
+ events::{
+ direct::{DirectEvent, DirectEventContent},
+ ignored_user_list::IgnoredUserListEvent,
+ room::{create::RoomCreateEventContent, member::MembershipState},
+ tag::{TagEvent, TagEventContent},
+ AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType,
+ RoomAccountDataEventType, StateEventType, RoomAccountDataEvent, RoomAccountDataEventContent,
+ },
+ serde::Raw,
+ RoomId, ServerName, UserId,
+};
+
+use crate::{services, utils, Error, Result};
pub struct Service {
- db: Box<dyn Data>,
+ db: Arc<dyn Data>,
}
impl Service {
@@ -45,7 +56,9 @@ impl Service {
self.db.mark_as_once_joined(user_id, room_id)?;
// Check if the room has a predecessor
- if let Some(predecessor) = self
+ if let Some(predecessor) = services()
+ .rooms
+ .state_accessor
.room_state_get(room_id, &StateEventType::RoomCreate, "")?
.and_then(|create| serde_json::from_str(create.content.get()).ok())
.and_then(|content: RoomCreateEventContent| content.predecessor)
@@ -76,27 +89,41 @@ impl Service {
// .ok();
// Copy old tags to new room
- if let Some(tag_event) = services().account_data.get::<TagEvent>(
- Some(&predecessor.room_id),
- user_id,
- RoomAccountDataEventType::Tag,
- )? {
- services().account_data
+ if let Some(tag_event) = services()
+ .account_data
+ .get(
+ Some(&predecessor.room_id),
+ user_id,
+ RoomAccountDataEventType::Tag,
+ )?
+ .map(|event| {
+ serde_json::from_str(event.get())
+ .map_err(|_| Error::bad_database("Invalid account data event in db."))
+ })
+ {
+ services()
+ .account_data
.update(
Some(room_id),
user_id,
RoomAccountDataEventType::Tag,
- &tag_event,
+ &tag_event?,
)
.ok();
};
// Copy direct chat flag
- if let Some(mut direct_event) = services().account_data.get::<DirectEvent>(
+ if let Some(mut direct_event) = services().account_data.get(
None,
user_id,
GlobalAccountDataEventType::Direct.to_string().into(),
- )? {
+ )?
+ .map(|event| {
+ serde_json::from_str::<DirectEvent>(event.get())
+ .map_err(|_| Error::bad_database("Invalid account data event in db."))
+ })
+ {
+ let direct_event = direct_event?;
let mut room_ids_updated = false;
for room_ids in direct_event.content.0.values_mut() {
@@ -111,7 +138,7 @@ impl Service {
None,
user_id,
GlobalAccountDataEventType::Direct.to_string().into(),
- &direct_event,
+ &serde_json::to_value(&direct_event).expect("to json always works"),
)?;
}
};
@@ -124,13 +151,17 @@ impl Service {
// We want to know if the sender is ignored by the receiver
let is_ignored = services()
.account_data
- .get::<IgnoredUserListEvent>(
+ .get(
None, // Ignored users are in global account data
user_id, // Receiver
GlobalAccountDataEventType::IgnoredUserList
.to_string()
.into(),
)?
+ .map(|event| {
+ serde_json::from_str::<IgnoredUserListEvent>(event.get())
+ .map_err(|_| Error::bad_database("Invalid account data event in db."))
+ }).transpose()?
.map_or(false, |ignored| {
ignored
.content
@@ -200,10 +231,7 @@ impl Service {
}
#[tracing::instrument(skip(self, room_id))]
- pub fn get_our_real_users(
- &self,
- room_id: &RoomId,
- ) -> Result<Arc<HashSet<Box<UserId>>>> {
+ pub fn get_our_real_users(&self, room_id: &RoomId) -> Result<Arc<HashSet<Box<UserId>>>> {
let maybe = self
.our_real_users_cache
.read()
diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs
index ab9f427..0c32c4b 100644
--- a/src/service/rooms/state_compressor/mod.rs
+++ b/src/service/rooms/state_compressor/mod.rs
@@ -9,7 +9,7 @@ use crate::{Result, utils, services};
use self::data::StateDiff;
pub struct Service {
- db: Box<dyn Data>,
+ db: Arc<dyn Data>,
}
pub type CompressedStateEvent = [u8; 2 * size_of::<u64>()];
@@ -67,7 +67,7 @@ impl Service {
) -> Result<CompressedStateEvent> {
let mut v = shortstatekey.to_be_bytes().to_vec();
v.extend_from_slice(
- &self
+ &services().rooms.short
.get_or_create_shorteventid(event_id)?
.to_be_bytes(),
);
@@ -218,7 +218,7 @@ impl Service {
HashSet<CompressedStateEvent>, // added
HashSet<CompressedStateEvent>)> // removed
{
- let previous_shortstatehash = self.db.current_shortstatehash(room_id)?;
+ let previous_shortstatehash = services().rooms.state.get_room_shortstatehash(room_id)?;
let state_hash = utils::calculate_hash(
&new_state_ids_compressed
diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs
index d073e86..2220b5f 100644
--- a/src/service/rooms/timeline/data.rs
+++ b/src/service/rooms/timeline/data.rs
@@ -5,6 +5,7 @@ use ruma::{signatures::CanonicalJsonObject, EventId, UserId, RoomId};
use crate::{Result, PduEvent};
pub trait Data: Send + Sync {
+ fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent>>>;
fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<u64>;
/// Returns the `count` of this pdu's id.
diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs
index e8f4205..7817225 100644
--- a/src/service/rooms/timeline/mod.rs
+++ b/src/service/rooms/timeline/mod.rs
@@ -21,33 +21,14 @@ use crate::{services, Result, service::pdu::{PduBuilder, EventHash}, Error, PduE
use super::state_compressor::CompressedStateEvent;
pub struct Service {
- db: Box<dyn Data>,
+ db: Arc<dyn Data>,
}
impl Service {
- /*
- /// Checks if a room exists.
#[tracing::instrument(skip(self))]
pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result<Option<Arc<PduEvent>>> {
- let prefix = self
- .get_shortroomid(room_id)?
- .expect("room exists")
- .to_be_bytes()
- .to_vec();
-
- // Look for PDUs in that room.
- self.pduid_pdu
- .iter_from(&prefix, false)
- .filter(|(k, _)| k.starts_with(&prefix))
- .map(|(_, pdu)| {
- serde_json::from_slice(&pdu)
- .map_err(|_| Error::bad_database("Invalid first PDU in db."))
- .map(Arc::new)
- })
- .next()
- .transpose()
+ self.db.first_pdu_in_room(room_id)
}
- */
#[tracing::instrument(skip(self))]
pub fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<u64> {
@@ -681,7 +662,8 @@ impl Service {
/// Append the incoming event setting the state snapshot to the state from the
/// server that sent the event.
#[tracing::instrument(skip_all)]
- fn append_incoming_pdu<'a>(
+ pub fn append_incoming_pdu<'a>(
+ &self,
pdu: &PduEvent,
pdu_json: CanonicalJsonObject,
new_room_leaves: impl IntoIterator<Item = &'a EventId> + Clone + Debug,
diff --git a/src/service/rooms/user/mod.rs b/src/service/rooms/user/mod.rs
index 7c7dfae..394a550 100644
--- a/src/service/rooms/user/mod.rs
+++ b/src/service/rooms/user/mod.rs
@@ -1,11 +1,13 @@
mod data;
+use std::sync::Arc;
+
pub use data::Data;
use ruma::{RoomId, UserId};
use crate::Result;
pub struct Service {
- db: Box<dyn Data>,
+ db: Arc<dyn Data>,
}
impl Service {
diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs
index 8ab557f..fde251b 100644
--- a/src/service/sending/mod.rs
+++ b/src/service/sending/mod.rs
@@ -448,14 +448,6 @@ impl Service {
Ok(())
}
- #[tracing::instrument(skip(keys))]
- fn calculate_hash(keys: &[&[u8]]) -> Vec<u8> {
- // We only hash the pdu's event ids, not the whole pdu
- let bytes = keys.join(&0xff);
- let hash = digest::digest(&digest::SHA256, &bytes);
- hash.as_ref().to_owned()
- }
-
/// Cleanup event data
/// Used for instance after we remove an appservice registration
///
diff --git a/src/service/transaction_ids/mod.rs b/src/service/transaction_ids/mod.rs
index a9c516c..d7066e2 100644
--- a/src/service/transaction_ids/mod.rs
+++ b/src/service/transaction_ids/mod.rs
@@ -1,11 +1,13 @@
mod data;
+use std::sync::Arc;
+
pub use data::Data;
use ruma::{UserId, DeviceId, TransactionId};
use crate::Result;
pub struct Service {
- db: Box<dyn Data>,
+ db: Arc<dyn Data>,
}
impl Service {
diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs
index 01c0d2f..73b2273 100644
--- a/src/service/uiaa/mod.rs
+++ b/src/service/uiaa/mod.rs
@@ -1,4 +1,6 @@
mod data;
+use std::sync::Arc;
+
pub use data::Data;
use ruma::{api::client::{uiaa::{UiaaInfo, IncomingAuthData, IncomingPassword, AuthType, IncomingUserIdentifier}, error::ErrorKind}, DeviceId, UserId, signatures::CanonicalJsonValue};
@@ -7,7 +9,7 @@ use tracing::error;
use crate::{Result, utils, Error, services, api::client_server::SESSION_ID_LENGTH};
pub struct Service {
- db: Box<dyn Data>,
+ db: Arc<dyn Data>,
}
impl Service {
diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs
index b13ae1f..2cf1876 100644
--- a/src/service/users/mod.rs
+++ b/src/service/users/mod.rs
@@ -1,5 +1,5 @@
mod data;
-use std::{collections::BTreeMap, mem};
+use std::{collections::BTreeMap, mem, sync::Arc};
pub use data::Data;
use ruma::{UserId, MxcUri, DeviceId, DeviceKeyId, serde::Raw, encryption::{OneTimeKey, CrossSigningKey, DeviceKeys}, DeviceKeyAlgorithm, UInt, events::AnyToDeviceEvent, api::client::{device::Device, filter::IncomingFilterDefinition, error::ErrorKind}, RoomAliasId};
@@ -7,7 +7,7 @@ use ruma::{UserId, MxcUri, DeviceId, DeviceKeyId, serde::Raw, encryption::{OneTi
use crate::{Result, Error, services};
pub struct Service {
- db: Box<dyn Data>,
+ db: Arc<dyn Data>,
}
impl Service {
diff --git a/src/utils/mod.rs b/src/utils/mod.rs
index 734da2a..0ee3ae8 100644
--- a/src/utils/mod.rs
+++ b/src/utils/mod.rs
@@ -3,6 +3,7 @@ pub mod error;
use argon2::{Config, Variant};
use cmp::Ordering;
use rand::prelude::*;
+use ring::digest;
use ruma::serde::{try_from_json_map, CanonicalJsonError, CanonicalJsonObject};
use std::{
cmp, fmt,
@@ -59,7 +60,7 @@ pub fn random_string(length: usize) -> String {
}
/// Calculate a new hash for the given password
-pub fn calculate_hash(password: &str) -> Result<String, argon2::Error> {
+pub fn calculate_password_hash(password: &str) -> Result<String, argon2::Error> {
let hashing_config = Config {
variant: Variant::Argon2id,
..Default::default()
@@ -69,6 +70,15 @@ pub fn calculate_hash(password: &str) -> Result<String, argon2::Error> {
argon2::hash_encoded(password.as_bytes(), salt.as_bytes(), &hashing_config)
}
+#[tracing::instrument(skip(keys))]
+pub fn calculate_hash(keys: &[&[u8]]) -> Vec<u8> {
+ // We only hash the pdu's event ids, not the whole pdu
+ let bytes = keys.join(&0xff);
+ let hash = digest::digest(&digest::SHA256, &bytes);
+ hash.as_ref().to_owned()
+}
+
+
pub fn common_elements(
mut iterators: impl Iterator<Item = impl Iterator<Item = Vec<u8>>>,
check_order: impl Fn(&[u8], &[u8]) -> Ordering,