summaryrefslogtreecommitdiff
path: root/src/service/users/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/service/users/mod.rs')
-rw-r--r--src/service/users/mod.rs256
1 files changed, 247 insertions, 9 deletions
diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs
index 6be5c89..c345e56 100644
--- a/src/service/users/mod.rs
+++ b/src/service/users/mod.rs
@@ -1,20 +1,41 @@
mod data;
-use std::{collections::BTreeMap, mem};
+use std::{
+ collections::BTreeMap,
+ mem,
+ sync::{Arc, Mutex},
+};
pub use data::Data;
use ruma::{
- api::client::{device::Device, error::ErrorKind, filter::FilterDefinition},
+ api::client::{
+ device::Device,
+ error::ErrorKind,
+ filter::FilterDefinition,
+ sync::sync_events::{
+ self,
+ v4::{ExtensionsConfig, SyncRequestList},
+ },
+ },
encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
events::AnyToDeviceEvent,
serde::Raw,
DeviceId, DeviceKeyAlgorithm, DeviceKeyId, OwnedDeviceId, OwnedDeviceKeyId, OwnedMxcUri,
- OwnedUserId, RoomAliasId, UInt, UserId,
+ OwnedRoomId, OwnedUserId, RoomAliasId, UInt, UserId,
};
use crate::{services, Error, Result};
+pub struct SlidingSyncCache {
+ lists: BTreeMap<String, SyncRequestList>,
+ subscriptions: BTreeMap<OwnedRoomId, sync_events::v4::RoomSubscription>,
+ known_rooms: BTreeMap<String, BTreeMap<OwnedRoomId, bool>>,
+ extensions: ExtensionsConfig,
+}
+
pub struct Service {
pub db: &'static dyn Data,
+ pub connections:
+ Mutex<BTreeMap<(OwnedUserId, OwnedDeviceId, String), Arc<Mutex<SlidingSyncCache>>>>,
}
impl Service {
@@ -23,6 +44,193 @@ impl Service {
self.db.exists(user_id)
}
+ pub fn forget_sync_request_connection(
+ &self,
+ user_id: OwnedUserId,
+ device_id: OwnedDeviceId,
+ conn_id: String,
+ ) {
+ self.connections
+ .lock()
+ .unwrap()
+ .remove(&(user_id, device_id, conn_id));
+ }
+
+ pub fn update_sync_request_with_cache(
+ &self,
+ user_id: OwnedUserId,
+ device_id: OwnedDeviceId,
+ request: &mut sync_events::v4::Request,
+ ) -> BTreeMap<String, BTreeMap<OwnedRoomId, bool>> {
+ let Some(conn_id) = request.conn_id.clone() else {
+ return BTreeMap::new();
+ };
+
+ let mut cache = self.connections.lock().unwrap();
+ let cached = Arc::clone(
+ cache
+ .entry((user_id, device_id, conn_id))
+ .or_insert_with(|| {
+ Arc::new(Mutex::new(SlidingSyncCache {
+ lists: BTreeMap::new(),
+ subscriptions: BTreeMap::new(),
+ known_rooms: BTreeMap::new(),
+ extensions: ExtensionsConfig::default(),
+ }))
+ }),
+ );
+ let cached = &mut cached.lock().unwrap();
+ drop(cache);
+
+ for (list_id, list) in &mut request.lists {
+ if let Some(cached_list) = cached.lists.get(list_id) {
+ if list.sort.is_empty() {
+ list.sort = cached_list.sort.clone();
+ };
+ if list.room_details.required_state.is_empty() {
+ list.room_details.required_state =
+ cached_list.room_details.required_state.clone();
+ };
+ list.room_details.timeline_limit = list
+ .room_details
+ .timeline_limit
+ .or(cached_list.room_details.timeline_limit);
+ list.include_old_rooms = list
+ .include_old_rooms
+ .clone()
+ .or(cached_list.include_old_rooms.clone());
+ match (&mut list.filters, cached_list.filters.clone()) {
+ (Some(list_filters), Some(cached_filters)) => {
+ list_filters.is_dm = list_filters.is_dm.or(cached_filters.is_dm);
+ if list_filters.spaces.is_empty() {
+ list_filters.spaces = cached_filters.spaces;
+ }
+ list_filters.is_encrypted =
+ list_filters.is_encrypted.or(cached_filters.is_encrypted);
+ list_filters.is_invite =
+ list_filters.is_invite.or(cached_filters.is_invite);
+ if list_filters.room_types.is_empty() {
+ list_filters.room_types = cached_filters.room_types;
+ }
+ if list_filters.not_room_types.is_empty() {
+ list_filters.not_room_types = cached_filters.not_room_types;
+ }
+ list_filters.room_name_like = list_filters
+ .room_name_like
+ .clone()
+ .or(cached_filters.room_name_like);
+ if list_filters.tags.is_empty() {
+ list_filters.tags = cached_filters.tags;
+ }
+ if list_filters.not_tags.is_empty() {
+ list_filters.not_tags = cached_filters.not_tags;
+ }
+ }
+ (_, Some(cached_filters)) => list.filters = Some(cached_filters),
+ (_, _) => {}
+ }
+ if list.bump_event_types.is_empty() {
+ list.bump_event_types = cached_list.bump_event_types.clone();
+ };
+ }
+ cached.lists.insert(list_id.clone(), list.clone());
+ }
+
+ cached
+ .subscriptions
+ .extend(request.room_subscriptions.clone().into_iter());
+ request
+ .room_subscriptions
+ .extend(cached.subscriptions.clone().into_iter());
+
+ request.extensions.e2ee.enabled = request
+ .extensions
+ .e2ee
+ .enabled
+ .or(cached.extensions.e2ee.enabled);
+
+ request.extensions.to_device.enabled = request
+ .extensions
+ .to_device
+ .enabled
+ .or(cached.extensions.to_device.enabled);
+
+ request.extensions.account_data.enabled = request
+ .extensions
+ .account_data
+ .enabled
+ .or(cached.extensions.account_data.enabled);
+ request.extensions.account_data.lists = request
+ .extensions
+ .account_data
+ .lists
+ .clone()
+ .or(cached.extensions.account_data.lists.clone());
+ request.extensions.account_data.rooms = request
+ .extensions
+ .account_data
+ .rooms
+ .clone()
+ .or(cached.extensions.account_data.rooms.clone());
+
+ cached.extensions = request.extensions.clone();
+
+ cached.known_rooms.clone()
+ }
+
+ pub fn update_sync_subscriptions(
+ &self,
+ user_id: OwnedUserId,
+ device_id: OwnedDeviceId,
+ conn_id: String,
+ subscriptions: BTreeMap<OwnedRoomId, sync_events::v4::RoomSubscription>,
+ ) {
+ let mut cache = self.connections.lock().unwrap();
+ let cached = Arc::clone(
+ cache
+ .entry((user_id, device_id, conn_id))
+ .or_insert_with(|| {
+ Arc::new(Mutex::new(SlidingSyncCache {
+ lists: BTreeMap::new(),
+ subscriptions: BTreeMap::new(),
+ known_rooms: BTreeMap::new(),
+ extensions: ExtensionsConfig::default(),
+ }))
+ }),
+ );
+ let cached = &mut cached.lock().unwrap();
+ drop(cache);
+
+ cached.subscriptions = subscriptions;
+ }
+
+ pub fn update_sync_known_rooms(
+ &self,
+ user_id: OwnedUserId,
+ device_id: OwnedDeviceId,
+ conn_id: String,
+ list_id: String,
+ new_cached_rooms: BTreeMap<OwnedRoomId, bool>,
+ ) {
+ let mut cache = self.connections.lock().unwrap();
+ let cached = Arc::clone(
+ cache
+ .entry((user_id, device_id, conn_id))
+ .or_insert_with(|| {
+ Arc::new(Mutex::new(SlidingSyncCache {
+ lists: BTreeMap::new(),
+ subscriptions: BTreeMap::new(),
+ known_rooms: BTreeMap::new(),
+ extensions: ExtensionsConfig::default(),
+ }))
+ }),
+ );
+ let cached = &mut cached.lock().unwrap();
+ drop(cache);
+
+ cached.known_rooms.insert(list_id, new_cached_rooms);
+ }
+
/// Check if account is deactivated
pub fn is_deactivated(&self, user_id: &UserId) -> Result<bool> {
self.db.is_deactivated(user_id)
@@ -190,9 +398,15 @@ impl Service {
master_key: &Raw<CrossSigningKey>,
self_signing_key: &Option<Raw<CrossSigningKey>>,
user_signing_key: &Option<Raw<CrossSigningKey>>,
+ notify: bool,
) -> Result<()> {
- self.db
- .add_cross_signing_keys(user_id, master_key, self_signing_key, user_signing_key)
+ self.db.add_cross_signing_keys(
+ user_id,
+ master_key,
+ self_signing_key,
+ user_signing_key,
+ notify,
+ )
}
pub fn sign_key(
@@ -226,20 +440,43 @@ impl Service {
self.db.get_device_keys(user_id, device_id)
}
+ pub fn parse_master_key(
+ &self,
+ user_id: &UserId,
+ master_key: &Raw<CrossSigningKey>,
+ ) -> Result<(Vec<u8>, CrossSigningKey)> {
+ self.db.parse_master_key(user_id, master_key)
+ }
+
+ pub fn get_key(
+ &self,
+ key: &[u8],
+ sender_user: Option<&UserId>,
+ user_id: &UserId,
+ allowed_signatures: &dyn Fn(&UserId) -> bool,
+ ) -> Result<Option<Raw<CrossSigningKey>>> {
+ self.db
+ .get_key(key, sender_user, user_id, allowed_signatures)
+ }
+
pub fn get_master_key(
&self,
+ sender_user: Option<&UserId>,
user_id: &UserId,
allowed_signatures: &dyn Fn(&UserId) -> bool,
) -> Result<Option<Raw<CrossSigningKey>>> {
- self.db.get_master_key(user_id, allowed_signatures)
+ self.db
+ .get_master_key(sender_user, user_id, allowed_signatures)
}
pub fn get_self_signing_key(
&self,
+ sender_user: Option<&UserId>,
user_id: &UserId,
allowed_signatures: &dyn Fn(&UserId) -> bool,
) -> Result<Option<Raw<CrossSigningKey>>> {
- self.db.get_self_signing_key(user_id, allowed_signatures)
+ self.db
+ .get_self_signing_key(sender_user, user_id, allowed_signatures)
}
pub fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> {
@@ -342,6 +579,7 @@ impl Service {
/// Ensure that a user only sees signatures from themselves and the target user
pub fn clean_signatures<F: Fn(&UserId) -> bool>(
cross_signing_key: &mut serde_json::Value,
+ sender_user: Option<&UserId>,
user_id: &UserId,
allowed_signatures: F,
) -> Result<(), Error> {
@@ -355,9 +593,9 @@ pub fn clean_signatures<F: Fn(&UserId) -> bool>(
for (user, signature) in
mem::replace(signatures, serde_json::Map::with_capacity(new_capacity))
{
- let id = <&UserId>::try_from(user.as_str())
+ let sid = <&UserId>::try_from(user.as_str())
.map_err(|_| Error::bad_database("Invalid user ID in database."))?;
- if id == user_id || allowed_signatures(id) {
+ if sender_user == Some(user_id) || sid == user_id || allowed_signatures(sid) {
signatures.insert(user, signature);
}
}