From 9d49d599f3a1d5da535b71f2f8e4986c25b997e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timo=20K=C3=B6sters?= Date: Sun, 2 Jul 2023 16:06:54 +0200 Subject: feat: space hierarchies --- src/api/client_server/context.rs | 6 +- src/api/client_server/message.rs | 7 +- src/api/client_server/mod.rs | 2 + src/api/client_server/search.rs | 3 +- src/api/client_server/space.rs | 34 +++ src/api/server_server.rs | 2 +- src/main.rs | 4 +- src/service/mod.rs | 5 +- src/service/pdu.rs | 19 +- src/service/rooms/mod.rs | 2 + src/service/rooms/spaces/mod.rs | 436 +++++++++++++++++++++++++++++++++++++++ 11 files changed, 503 insertions(+), 17 deletions(-) create mode 100644 src/api/client_server/space.rs create mode 100644 src/service/rooms/spaces/mod.rs (limited to 'src') diff --git a/src/api/client_server/context.rs b/src/api/client_server/context.rs index e70f9f1..8e193e6 100644 --- a/src/api/client_server/context.rs +++ b/src/api/client_server/context.rs @@ -3,7 +3,7 @@ use ruma::{ api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions}, events::StateEventType, }; -use std::{collections::HashSet, convert::TryFrom}; +use std::collections::HashSet; use tracing::error; /// # `GET /_matrix/client/r0/rooms/{roomId}/context` @@ -70,9 +70,7 @@ pub async fn get_context_route( } // Use limit with maximum 100 - let limit = usize::try_from(body.limit) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Limit value is invalid."))? - .min(100); + let limit = u64::from(body.limit).min(100) as usize; let base_event = base_event.to_room_event(); diff --git a/src/api/client_server/message.rs b/src/api/client_server/message.rs index dc2d994..750e030 100644 --- a/src/api/client_server/message.rs +++ b/src/api/client_server/message.rs @@ -133,12 +133,7 @@ pub async fn get_message_events_route( from, )?; - // Use limit or else 10, with maximum 100 - let limit = body - .limit - .try_into() - .map_or(10_usize, |l: u32| l as usize) - .min(100); + let limit = u64::from(body.limit).min(100) as usize; let next_token; diff --git a/src/api/client_server/mod.rs b/src/api/client_server/mod.rs index 2ab3a98..54c99aa 100644 --- a/src/api/client_server/mod.rs +++ b/src/api/client_server/mod.rs @@ -21,6 +21,7 @@ mod report; mod room; mod search; mod session; +mod space; mod state; mod sync; mod tag; @@ -55,6 +56,7 @@ pub use report::*; pub use room::*; pub use search::*; pub use session::*; +pub use space::*; pub use state::*; pub use sync::*; pub use tag::*; diff --git a/src/api/client_server/search.rs b/src/api/client_server/search.rs index fe69e7c..e9fac36 100644 --- a/src/api/client_server/search.rs +++ b/src/api/client_server/search.rs @@ -31,7 +31,8 @@ pub async fn search_events_route( .collect() }); - let limit = filter.limit.map_or(10, |l| u64::from(l) as usize); + // Use limit or else 10, with maximum 100 + let limit = filter.limit.map_or(10, u64::from).min(100) as usize; let mut searches = Vec::new(); diff --git a/src/api/client_server/space.rs b/src/api/client_server/space.rs new file mode 100644 index 0000000..e2ea8c3 --- /dev/null +++ b/src/api/client_server/space.rs @@ -0,0 +1,34 @@ +use crate::{services, Result, Ruma}; +use ruma::api::client::space::get_hierarchy; + +/// # `GET /_matrix/client/v1/rooms/{room_id}/hierarchy`` +/// +/// Paginates over the space tree in a depth-first manner to locate child rooms of a given space. +pub async fn get_hierarchy_route( + body: Ruma, +) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + + let skip = body + .from + .as_ref() + .and_then(|s| s.parse::().ok()) + .unwrap_or(0); + + let limit = body.limit.map_or(10, u64::from).min(100) as usize; + + let max_depth = body.max_depth.map_or(3, u64::from).min(10) as usize + 1; // +1 to skip the space room itself + + services() + .rooms + .spaces + .get_hierarchy( + sender_user, + &body.room_id, + limit, + skip, + max_depth, + body.suggested_only, + ) + .await +} diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 5e218be..adb5f1f 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -151,7 +151,7 @@ where .try_into_http_request::>( &actual_destination_str, SendAccessToken::IfRequired(""), - &[MatrixVersion::V1_0], + &[MatrixVersion::V1_4], ) .map_err(|e| { warn!( diff --git a/src/main.rs b/src/main.rs index f9f88f4..3f14ca8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,7 +2,8 @@ rust_2018_idioms, unused_qualifications, clippy::cloned_instead_of_copied, - clippy::str_to_string + clippy::str_to_string, + clippy::future_not_send )] #![allow(clippy::suspicious_else_formatting)] #![deny(clippy::dbg_macro)] @@ -386,6 +387,7 @@ fn routes() -> Router { .ruma_route(client_server::get_relating_events_with_rel_type_and_event_type_route) .ruma_route(client_server::get_relating_events_with_rel_type_route) .ruma_route(client_server::get_relating_events_route) + .ruma_route(client_server::get_hierarchy_route) .ruma_route(server_server::get_server_version_route) .route( "/_matrix/key/v2/server", diff --git a/src/service/mod.rs b/src/service/mod.rs index 7a2bb64..dfdc5a6 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -90,7 +90,7 @@ impl Services { state_compressor: rooms::state_compressor::Service { db, stateinfo_cache: Mutex::new(LruCache::new( - (1000.0 * config.conduit_cache_capacity_modifier) as usize, + (300.0 * config.conduit_cache_capacity_modifier) as usize, )), }, timeline: rooms::timeline::Service { @@ -98,6 +98,9 @@ impl Services { lasttimelinecount_cache: Mutex::new(HashMap::new()), }, threads: rooms::threads::Service { db }, + spaces: rooms::spaces::Service { + roomid_spacechunk_cache: Mutex::new(LruCache::new(200)), + }, user: rooms::user::Service { db }, }, transaction_ids: transaction_ids::Service { db }, diff --git a/src/service/pdu.rs b/src/service/pdu.rs index 9d284c0..d24e174 100644 --- a/src/service/pdu.rs +++ b/src/service/pdu.rs @@ -1,9 +1,9 @@ use crate::Error; use ruma::{ events::{ - room::member::RoomMemberEventContent, AnyEphemeralRoomEvent, AnyMessageLikeEvent, - AnyStateEvent, AnyStrippedStateEvent, AnySyncStateEvent, AnySyncTimelineEvent, - AnyTimelineEvent, StateEvent, TimelineEventType, + room::member::RoomMemberEventContent, space::child::HierarchySpaceChildEvent, + AnyEphemeralRoomEvent, AnyMessageLikeEvent, AnyStateEvent, AnyStrippedStateEvent, + AnySyncStateEvent, AnySyncTimelineEvent, AnyTimelineEvent, StateEvent, TimelineEventType, }, serde::Raw, state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, @@ -248,6 +248,19 @@ impl PduEvent { serde_json::from_value(json).expect("Raw::from_value always works") } + #[tracing::instrument(skip(self))] + pub fn to_stripped_spacechild_state_event(&self) -> Raw { + let json = json!({ + "content": self.content, + "type": self.kind, + "sender": self.sender, + "state_key": self.state_key, + "origin_server_ts": self.origin_server_ts, + }); + + serde_json::from_value(json).expect("Raw::from_value always works") + } + #[tracing::instrument(skip(self))] pub fn to_member_event(&self) -> Raw> { let mut json = json!({ diff --git a/src/service/rooms/mod.rs b/src/service/rooms/mod.rs index 61304d1..f073984 100644 --- a/src/service/rooms/mod.rs +++ b/src/service/rooms/mod.rs @@ -9,6 +9,7 @@ pub mod outlier; pub mod pdu_metadata; pub mod search; pub mod short; +pub mod spaces; pub mod state; pub mod state_accessor; pub mod state_cache; @@ -56,5 +57,6 @@ pub struct Service { pub state_compressor: state_compressor::Service, pub timeline: timeline::Service, pub threads: threads::Service, + pub spaces: spaces::Service, pub user: user::Service, } diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs new file mode 100644 index 0000000..76ba6c5 --- /dev/null +++ b/src/service/rooms/spaces/mod.rs @@ -0,0 +1,436 @@ +use std::sync::{Arc, Mutex}; + +use lru_cache::LruCache; +use ruma::{ + api::{ + client::{ + error::ErrorKind, + space::{get_hierarchy, SpaceHierarchyRoomsChunk, SpaceRoomJoinRule}, + }, + federation, + }, + directory::PublicRoomJoinRule, + events::{ + room::{ + avatar::RoomAvatarEventContent, + canonical_alias::RoomCanonicalAliasEventContent, + create::RoomCreateEventContent, + guest_access::{GuestAccess, RoomGuestAccessEventContent}, + history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, + join_rules::{JoinRule, RoomJoinRulesEventContent}, + name::RoomNameEventContent, + topic::RoomTopicEventContent, + }, + StateEventType, + }, + OwnedRoomId, RoomId, UserId, +}; + +use tracing::{debug, error, warn}; + +use crate::{services, Error, PduEvent, Result}; + +pub struct CachedSpaceChunk { + chunk: SpaceHierarchyRoomsChunk, + children: Vec, + join_rule: JoinRule, +} + +pub struct Service { + pub roomid_spacechunk_cache: Mutex>>, +} + +impl Service { + pub async fn get_hierarchy( + &self, + sender_user: &UserId, + room_id: &RoomId, + limit: usize, + skip: usize, + max_depth: usize, + suggested_only: bool, + ) -> Result { + let mut left_to_skip = skip; + + let mut rooms_in_path = Vec::new(); + let mut stack = vec![vec![room_id.to_owned()]]; + let mut results = Vec::new(); + + while let Some(current_room) = { + while stack.last().map_or(false, |s| s.is_empty()) { + stack.pop(); + } + if !stack.is_empty() { + stack.last_mut().and_then(|s| s.pop()) + } else { + None + } + } { + rooms_in_path.push(current_room.clone()); + if results.len() >= limit { + break; + } + + if let Some(cached) = self + .roomid_spacechunk_cache + .lock() + .unwrap() + .get_mut(¤t_room.to_owned()) + .as_ref() + { + if let Some(cached) = cached { + if let Some(_join_rule) = + self.handle_join_rule(&cached.join_rule, sender_user, ¤t_room)? + { + if left_to_skip > 0 { + left_to_skip -= 1; + } else { + results.push(cached.chunk.clone()); + } + if rooms_in_path.len() < max_depth { + stack.push(cached.children.clone()); + } + } + } + continue; + } + + if let Some(current_shortstatehash) = services() + .rooms + .state + .get_room_shortstatehash(¤t_room)? + { + let state = services() + .rooms + .state_accessor + .state_full_ids(current_shortstatehash) + .await?; + + let mut children_ids = Vec::new(); + let mut children_pdus = Vec::new(); + for (key, id) in state { + let (event_type, state_key) = + services().rooms.short.get_statekey_from_short(key)?; + if event_type != StateEventType::SpaceChild { + continue; + } + if let Ok(room_id) = OwnedRoomId::try_from(state_key) { + children_ids.push(room_id); + children_pdus.push(services().rooms.timeline.get_pdu(&id)?.ok_or_else( + || Error::bad_database("Event in space state not found"), + )?); + } + } + + // TODO: Sort children + children_ids.reverse(); + + let chunk = self.get_room_chunk(sender_user, ¤t_room, children_pdus); + if let Ok(chunk) = chunk { + if left_to_skip > 0 { + left_to_skip -= 1; + } else { + results.push(chunk.clone()); + } + let join_rule = services() + .rooms + .state_accessor + .room_state_get(¤t_room, &StateEventType::RoomJoinRules, "")? + .map(|s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomJoinRulesEventContent| c.join_rule) + .map_err(|e| { + error!("Invalid room join rule event in database: {}", e); + Error::BadDatabase("Invalid room join rule event in database.") + }) + }) + .transpose()? + .unwrap_or(JoinRule::Invite); + + self.roomid_spacechunk_cache.lock().unwrap().insert( + current_room.clone(), + Some(CachedSpaceChunk { + chunk, + children: children_ids.clone(), + join_rule, + }), + ); + } + + if rooms_in_path.len() < max_depth { + stack.push(children_ids); + } + } else { + let server = current_room.server_name(); + if server == services().globals.server_name() { + continue; + } + if !results.is_empty() { + // Early return so the client can see some data already + break; + } + warn!("Asking {server} for /hierarchy"); + if let Ok(response) = services() + .sending + .send_federation_request( + &server, + federation::space::get_hierarchy::v1::Request { + room_id: current_room.to_owned(), + suggested_only, + }, + ) + .await + { + warn!("Got response from {server} for /hierarchy\n{response:?}"); + let join_rule = self.translate_pjoinrule(&response.room.join_rule)?; + let chunk = SpaceHierarchyRoomsChunk { + canonical_alias: response.room.canonical_alias, + name: response.room.name, + num_joined_members: response.room.num_joined_members, + room_id: response.room.room_id, + topic: response.room.topic, + world_readable: response.room.world_readable, + guest_can_join: response.room.guest_can_join, + avatar_url: response.room.avatar_url, + join_rule: self.translate_sjoinrule(&response.room.join_rule)?, + room_type: response.room.room_type, + children_state: response.room.children_state, + }; + let children = response + .children + .iter() + .map(|c| c.room_id.clone()) + .collect::>(); + + if let Some(_join_rule) = + self.handle_join_rule(&join_rule, sender_user, ¤t_room)? + { + if left_to_skip > 0 { + left_to_skip -= 1; + } else { + results.push(chunk.clone()); + } + if rooms_in_path.len() < max_depth { + stack.push(children.clone()); + } + } + + self.roomid_spacechunk_cache.lock().unwrap().insert( + current_room.clone(), + Some(CachedSpaceChunk { + chunk, + children, + join_rule, + }), + ); + + /* TODO: + for child in response.children { + roomid_spacechunk_cache.insert( + current_room.clone(), + CachedSpaceChunk { + chunk: child.chunk, + children, + join_rule, + }, + ); + } + */ + } else { + self.roomid_spacechunk_cache + .lock() + .unwrap() + .insert(current_room.clone(), None); + } + } + } + + Ok(get_hierarchy::v1::Response { + next_batch: if results.is_empty() { + None + } else { + Some((skip + results.len()).to_string()) + }, + rooms: results, + }) + } + + fn get_room_chunk( + &self, + sender_user: &UserId, + room_id: &RoomId, + children: Vec>, + ) -> Result { + Ok(SpaceHierarchyRoomsChunk { + canonical_alias: services() + .rooms + .state_accessor + .room_state_get(&room_id, &StateEventType::RoomCanonicalAlias, "")? + .map_or(Ok(None), |s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomCanonicalAliasEventContent| c.alias) + .map_err(|_| { + Error::bad_database("Invalid canonical alias event in database.") + }) + })?, + name: services() + .rooms + .state_accessor + .room_state_get(&room_id, &StateEventType::RoomName, "")? + .map_or(Ok(None), |s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomNameEventContent| c.name) + .map_err(|_| Error::bad_database("Invalid room name event in database.")) + })?, + num_joined_members: services() + .rooms + .state_cache + .room_joined_count(&room_id)? + .unwrap_or_else(|| { + warn!("Room {} has no member count", room_id); + 0 + }) + .try_into() + .expect("user count should not be that big"), + room_id: room_id.to_owned(), + topic: services() + .rooms + .state_accessor + .room_state_get(&room_id, &StateEventType::RoomTopic, "")? + .map_or(Ok(None), |s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomTopicEventContent| Some(c.topic)) + .map_err(|_| Error::bad_database("Invalid room topic event in database.")) + })?, + world_readable: services() + .rooms + .state_accessor + .room_state_get(&room_id, &StateEventType::RoomHistoryVisibility, "")? + .map_or(Ok(false), |s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomHistoryVisibilityEventContent| { + c.history_visibility == HistoryVisibility::WorldReadable + }) + .map_err(|_| { + Error::bad_database( + "Invalid room history visibility event in database.", + ) + }) + })?, + guest_can_join: services() + .rooms + .state_accessor + .room_state_get(&room_id, &StateEventType::RoomGuestAccess, "")? + .map_or(Ok(false), |s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomGuestAccessEventContent| { + c.guest_access == GuestAccess::CanJoin + }) + .map_err(|_| { + Error::bad_database("Invalid room guest access event in database.") + }) + })?, + avatar_url: services() + .rooms + .state_accessor + .room_state_get(&room_id, &StateEventType::RoomAvatar, "")? + .map(|s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomAvatarEventContent| c.url) + .map_err(|_| Error::bad_database("Invalid room avatar event in database.")) + }) + .transpose()? + // url is now an Option so we must flatten + .flatten(), + join_rule: { + let join_rule = services() + .rooms + .state_accessor + .room_state_get(&room_id, &StateEventType::RoomJoinRules, "")? + .map(|s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomJoinRulesEventContent| c.join_rule) + .map_err(|e| { + error!("Invalid room join rule event in database: {}", e); + Error::BadDatabase("Invalid room join rule event in database.") + }) + }) + .transpose()? + .unwrap_or(JoinRule::Invite); + self.handle_join_rule(&join_rule, sender_user, room_id)? + .ok_or_else(|| { + debug!("User is not allowed to see room {room_id}"); + // This error will be caught later + Error::BadRequest( + ErrorKind::Forbidden, + "User is not allowed to see the room", + ) + })? + }, + room_type: services() + .rooms + .state_accessor + .room_state_get(&room_id, &StateEventType::RoomCreate, "")? + .map(|s| { + serde_json::from_str::(s.content.get()).map_err(|e| { + error!("Invalid room create event in database: {}", e); + Error::BadDatabase("Invalid room create event in database.") + }) + }) + .transpose()? + .and_then(|e| e.room_type), + children_state: children + .into_iter() + .map(|pdu| pdu.to_stripped_spacechild_state_event()) + .collect(), + }) + } + + fn translate_pjoinrule(&self, join_rule: &PublicRoomJoinRule) -> Result { + match join_rule { + PublicRoomJoinRule::Knock => Ok(JoinRule::Knock), + PublicRoomJoinRule::Public => Ok(JoinRule::Public), + _ => Err(Error::BadServerResponse("Unknown join rule")), + } + } + + fn translate_sjoinrule(&self, join_rule: &PublicRoomJoinRule) -> Result { + match join_rule { + PublicRoomJoinRule::Knock => Ok(SpaceRoomJoinRule::Knock), + PublicRoomJoinRule::Public => Ok(SpaceRoomJoinRule::Public), + _ => Err(Error::BadServerResponse("Unknown join rule")), + } + } + + fn handle_join_rule( + &self, + join_rule: &JoinRule, + sender_user: &UserId, + room_id: &RoomId, + ) -> Result> { + match join_rule { + JoinRule::Public => Ok::<_, Error>(Some(SpaceRoomJoinRule::Public)), + JoinRule::Knock => Ok(Some(SpaceRoomJoinRule::Knock)), + JoinRule::Invite => { + if services() + .rooms + .state_cache + .is_joined(sender_user, &room_id)? + { + Ok(Some(SpaceRoomJoinRule::Invite)) + } else { + Ok(None) + } + } + JoinRule::Restricted(_r) => { + // TODO: Check rules + Ok(None) + } + JoinRule::KnockRestricted(_r) => { + // TODO: Check rules + Ok(None) + } + _ => Ok(None), + } + } +} -- cgit v1.2.3