diff options
Diffstat (limited to 'src/service/users/mod.rs')
-rw-r--r-- | src/service/users/mod.rs | 256 |
1 files changed, 247 insertions, 9 deletions
diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index 6be5c89..c345e56 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -1,20 +1,41 @@ mod data; -use std::{collections::BTreeMap, mem}; +use std::{ + collections::BTreeMap, + mem, + sync::{Arc, Mutex}, +}; pub use data::Data; use ruma::{ - api::client::{device::Device, error::ErrorKind, filter::FilterDefinition}, + api::client::{ + device::Device, + error::ErrorKind, + filter::FilterDefinition, + sync::sync_events::{ + self, + v4::{ExtensionsConfig, SyncRequestList}, + }, + }, encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, events::AnyToDeviceEvent, serde::Raw, DeviceId, DeviceKeyAlgorithm, DeviceKeyId, OwnedDeviceId, OwnedDeviceKeyId, OwnedMxcUri, - OwnedUserId, RoomAliasId, UInt, UserId, + OwnedRoomId, OwnedUserId, RoomAliasId, UInt, UserId, }; use crate::{services, Error, Result}; +pub struct SlidingSyncCache { + lists: BTreeMap<String, SyncRequestList>, + subscriptions: BTreeMap<OwnedRoomId, sync_events::v4::RoomSubscription>, + known_rooms: BTreeMap<String, BTreeMap<OwnedRoomId, bool>>, + extensions: ExtensionsConfig, +} + pub struct Service { pub db: &'static dyn Data, + pub connections: + Mutex<BTreeMap<(OwnedUserId, OwnedDeviceId, String), Arc<Mutex<SlidingSyncCache>>>>, } impl Service { @@ -23,6 +44,193 @@ impl Service { self.db.exists(user_id) } + pub fn forget_sync_request_connection( + &self, + user_id: OwnedUserId, + device_id: OwnedDeviceId, + conn_id: String, + ) { + self.connections + .lock() + .unwrap() + .remove(&(user_id, device_id, conn_id)); + } + + pub fn update_sync_request_with_cache( + &self, + user_id: OwnedUserId, + device_id: OwnedDeviceId, + request: &mut sync_events::v4::Request, + ) -> BTreeMap<String, BTreeMap<OwnedRoomId, bool>> { + let Some(conn_id) = request.conn_id.clone() else { + return BTreeMap::new(); + }; + + let mut cache = self.connections.lock().unwrap(); + let cached = Arc::clone( + cache + .entry((user_id, device_id, conn_id)) + .or_insert_with(|| { + Arc::new(Mutex::new(SlidingSyncCache { + lists: BTreeMap::new(), + subscriptions: BTreeMap::new(), + known_rooms: BTreeMap::new(), + extensions: ExtensionsConfig::default(), + })) + }), + ); + let cached = &mut cached.lock().unwrap(); + drop(cache); + + for (list_id, list) in &mut request.lists { + if let Some(cached_list) = cached.lists.get(list_id) { + if list.sort.is_empty() { + list.sort = cached_list.sort.clone(); + }; + if list.room_details.required_state.is_empty() { + list.room_details.required_state = + cached_list.room_details.required_state.clone(); + }; + list.room_details.timeline_limit = list + .room_details + .timeline_limit + .or(cached_list.room_details.timeline_limit); + list.include_old_rooms = list + .include_old_rooms + .clone() + .or(cached_list.include_old_rooms.clone()); + match (&mut list.filters, cached_list.filters.clone()) { + (Some(list_filters), Some(cached_filters)) => { + list_filters.is_dm = list_filters.is_dm.or(cached_filters.is_dm); + if list_filters.spaces.is_empty() { + list_filters.spaces = cached_filters.spaces; + } + list_filters.is_encrypted = + list_filters.is_encrypted.or(cached_filters.is_encrypted); + list_filters.is_invite = + list_filters.is_invite.or(cached_filters.is_invite); + if list_filters.room_types.is_empty() { + list_filters.room_types = cached_filters.room_types; + } + if list_filters.not_room_types.is_empty() { + list_filters.not_room_types = cached_filters.not_room_types; + } + list_filters.room_name_like = list_filters + .room_name_like + .clone() + .or(cached_filters.room_name_like); + if list_filters.tags.is_empty() { + list_filters.tags = cached_filters.tags; + } + if list_filters.not_tags.is_empty() { + list_filters.not_tags = cached_filters.not_tags; + } + } + (_, Some(cached_filters)) => list.filters = Some(cached_filters), + (_, _) => {} + } + if list.bump_event_types.is_empty() { + list.bump_event_types = cached_list.bump_event_types.clone(); + }; + } + cached.lists.insert(list_id.clone(), list.clone()); + } + + cached + .subscriptions + .extend(request.room_subscriptions.clone().into_iter()); + request + .room_subscriptions + .extend(cached.subscriptions.clone().into_iter()); + + request.extensions.e2ee.enabled = request + .extensions + .e2ee + .enabled + .or(cached.extensions.e2ee.enabled); + + request.extensions.to_device.enabled = request + .extensions + .to_device + .enabled + .or(cached.extensions.to_device.enabled); + + request.extensions.account_data.enabled = request + .extensions + .account_data + .enabled + .or(cached.extensions.account_data.enabled); + request.extensions.account_data.lists = request + .extensions + .account_data + .lists + .clone() + .or(cached.extensions.account_data.lists.clone()); + request.extensions.account_data.rooms = request + .extensions + .account_data + .rooms + .clone() + .or(cached.extensions.account_data.rooms.clone()); + + cached.extensions = request.extensions.clone(); + + cached.known_rooms.clone() + } + + pub fn update_sync_subscriptions( + &self, + user_id: OwnedUserId, + device_id: OwnedDeviceId, + conn_id: String, + subscriptions: BTreeMap<OwnedRoomId, sync_events::v4::RoomSubscription>, + ) { + let mut cache = self.connections.lock().unwrap(); + let cached = Arc::clone( + cache + .entry((user_id, device_id, conn_id)) + .or_insert_with(|| { + Arc::new(Mutex::new(SlidingSyncCache { + lists: BTreeMap::new(), + subscriptions: BTreeMap::new(), + known_rooms: BTreeMap::new(), + extensions: ExtensionsConfig::default(), + })) + }), + ); + let cached = &mut cached.lock().unwrap(); + drop(cache); + + cached.subscriptions = subscriptions; + } + + pub fn update_sync_known_rooms( + &self, + user_id: OwnedUserId, + device_id: OwnedDeviceId, + conn_id: String, + list_id: String, + new_cached_rooms: BTreeMap<OwnedRoomId, bool>, + ) { + let mut cache = self.connections.lock().unwrap(); + let cached = Arc::clone( + cache + .entry((user_id, device_id, conn_id)) + .or_insert_with(|| { + Arc::new(Mutex::new(SlidingSyncCache { + lists: BTreeMap::new(), + subscriptions: BTreeMap::new(), + known_rooms: BTreeMap::new(), + extensions: ExtensionsConfig::default(), + })) + }), + ); + let cached = &mut cached.lock().unwrap(); + drop(cache); + + cached.known_rooms.insert(list_id, new_cached_rooms); + } + /// Check if account is deactivated pub fn is_deactivated(&self, user_id: &UserId) -> Result<bool> { self.db.is_deactivated(user_id) @@ -190,9 +398,15 @@ impl Service { master_key: &Raw<CrossSigningKey>, self_signing_key: &Option<Raw<CrossSigningKey>>, user_signing_key: &Option<Raw<CrossSigningKey>>, + notify: bool, ) -> Result<()> { - self.db - .add_cross_signing_keys(user_id, master_key, self_signing_key, user_signing_key) + self.db.add_cross_signing_keys( + user_id, + master_key, + self_signing_key, + user_signing_key, + notify, + ) } pub fn sign_key( @@ -226,20 +440,43 @@ impl Service { self.db.get_device_keys(user_id, device_id) } + pub fn parse_master_key( + &self, + user_id: &UserId, + master_key: &Raw<CrossSigningKey>, + ) -> Result<(Vec<u8>, CrossSigningKey)> { + self.db.parse_master_key(user_id, master_key) + } + + pub fn get_key( + &self, + key: &[u8], + sender_user: Option<&UserId>, + user_id: &UserId, + allowed_signatures: &dyn Fn(&UserId) -> bool, + ) -> Result<Option<Raw<CrossSigningKey>>> { + self.db + .get_key(key, sender_user, user_id, allowed_signatures) + } + pub fn get_master_key( &self, + sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, ) -> Result<Option<Raw<CrossSigningKey>>> { - self.db.get_master_key(user_id, allowed_signatures) + self.db + .get_master_key(sender_user, user_id, allowed_signatures) } pub fn get_self_signing_key( &self, + sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, ) -> Result<Option<Raw<CrossSigningKey>>> { - self.db.get_self_signing_key(user_id, allowed_signatures) + self.db + .get_self_signing_key(sender_user, user_id, allowed_signatures) } pub fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> { @@ -342,6 +579,7 @@ impl Service { /// Ensure that a user only sees signatures from themselves and the target user pub fn clean_signatures<F: Fn(&UserId) -> bool>( cross_signing_key: &mut serde_json::Value, + sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: F, ) -> Result<(), Error> { @@ -355,9 +593,9 @@ pub fn clean_signatures<F: Fn(&UserId) -> bool>( for (user, signature) in mem::replace(signatures, serde_json::Map::with_capacity(new_capacity)) { - let id = <&UserId>::try_from(user.as_str()) + let sid = <&UserId>::try_from(user.as_str()) .map_err(|_| Error::bad_database("Invalid user ID in database."))?; - if id == user_id || allowed_signatures(id) { + if sender_user == Some(user_id) || sid == user_id || allowed_signatures(sid) { signatures.insert(user, signature); } } |