diff options
author | Timo Kösters <timo@koesters.xyz> | 2022-09-06 23:15:09 +0200 |
---|---|---|
committer | Nyaaori <+@nyaaori.cat> | 2022-10-10 13:25:01 +0200 |
commit | 057f8364cc317dc8646043abd6c8ff3ef759625f (patch) | |
tree | f80bf450aa962947ab2651376768e021113a7ef6 /src/api | |
parent | 82e7f57b389d011bc8d80f9142f723b3cd1e1ad2 (diff) | |
download | conduit-057f8364cc317dc8646043abd6c8ff3ef759625f.zip |
fix: some compile time errors
Only 174 errors left!
Diffstat (limited to 'src/api')
33 files changed, 908 insertions, 1318 deletions
diff --git a/src/api/appservice_server.rs b/src/api/appservice_server.rs index ce122da..1f6e2c9 100644 --- a/src/api/appservice_server.rs +++ b/src/api/appservice_server.rs @@ -1,12 +1,11 @@ -use crate::{utils, Error, Result}; +use crate::{utils, Error, Result, services}; use bytes::BytesMut; use ruma::api::{IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken}; use std::{fmt::Debug, mem, time::Duration}; use tracing::warn; -#[tracing::instrument(skip(globals, request))] +#[tracing::instrument(skip(request))] pub(crate) async fn send_request<T: OutgoingRequest>( - globals: &crate::database::globals::Globals, registration: serde_yaml::Value, request: T, ) -> Result<T::IncomingResponse> @@ -46,7 +45,7 @@ where *reqwest_request.timeout_mut() = Some(Duration::from_secs(30)); let url = reqwest_request.url().clone(); - let mut response = globals.default_client().execute(reqwest_request).await?; + let mut response = services().globals.default_client().execute(reqwest_request).await?; // reqwest::Response -> http::Response conversion let status = response.status(); diff --git a/src/api/client_server/account.rs b/src/api/client_server/account.rs index dc0782d..848bfaa 100644 --- a/src/api/client_server/account.rs +++ b/src/api/client_server/account.rs @@ -2,9 +2,7 @@ use std::sync::Arc; use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; use crate::{ - database::{admin::make_user_admin, DatabaseGuard}, - pdu::PduBuilder, - utils, Database, Error, Result, Ruma, + utils, Error, Result, Ruma, services, }; use ruma::{ api::client::{ @@ -42,15 +40,14 @@ const RANDOM_USER_ID_LENGTH: usize = 10; /// /// Note: This will not reserve the username, so the username might become invalid when trying to register pub async fn get_register_available_route( - db: DatabaseGuard, body: Ruma<get_username_availability::v3::IncomingRequest>, ) -> Result<get_username_availability::v3::Response> { // Validate user id let user_id = - UserId::parse_with_server_name(body.username.to_lowercase(), db.globals.server_name()) + UserId::parse_with_server_name(body.username.to_lowercase(), services().globals.server_name()) .ok() .filter(|user_id| { - !user_id.is_historical() && user_id.server_name() == db.globals.server_name() + !user_id.is_historical() && user_id.server_name() == services().globals.server_name() }) .ok_or(Error::BadRequest( ErrorKind::InvalidUsername, @@ -58,7 +55,7 @@ pub async fn get_register_available_route( ))?; // Check if username is creative enough - if db.users.exists(&user_id)? { + if services().users.exists(&user_id)? { return Err(Error::BadRequest( ErrorKind::UserInUse, "Desired user ID is already taken.", @@ -85,10 +82,9 @@ pub async fn get_register_available_route( /// - Creates a new account and populates it with default account data /// - If `inhibit_login` is false: Creates a device and returns device id and access_token pub async fn register_route( - db: DatabaseGuard, body: Ruma<register::v3::IncomingRequest>, ) -> Result<register::v3::Response> { - if !db.globals.allow_registration() && !body.from_appservice { + if !services().globals.allow_registration() && !body.from_appservice { return Err(Error::BadRequest( ErrorKind::Forbidden, "Registration has been disabled.", @@ -100,17 +96,17 @@ pub async fn register_route( let user_id = match (&body.username, is_guest) { (Some(username), false) => { let proposed_user_id = - UserId::parse_with_server_name(username.to_lowercase(), db.globals.server_name()) + UserId::parse_with_server_name(username.to_lowercase(), services().globals.server_name()) .ok() .filter(|user_id| { !user_id.is_historical() - && user_id.server_name() == db.globals.server_name() + && user_id.server_name() == services().globals.server_name() }) .ok_or(Error::BadRequest( ErrorKind::InvalidUsername, "Username is invalid.", ))?; - if db.users.exists(&proposed_user_id)? { + if services().users.exists(&proposed_user_id)? { return Err(Error::BadRequest( ErrorKind::UserInUse, "Desired user ID is already taken.", @@ -121,10 +117,10 @@ pub async fn register_route( _ => loop { let proposed_user_id = UserId::parse_with_server_name( utils::random_string(RANDOM_USER_ID_LENGTH).to_lowercase(), - db.globals.server_name(), + services().globals.server_name(), ) .unwrap(); - if !db.users.exists(&proposed_user_id)? { + if !services().users.exists(&proposed_user_id)? { break proposed_user_id; } }, @@ -143,14 +139,12 @@ pub async fn register_route( if !body.from_appservice { if let Some(auth) = &body.auth { - let (worked, uiaainfo) = db.uiaa.try_auth( - &UserId::parse_with_server_name("", db.globals.server_name()) + let (worked, uiaainfo) = services().uiaa.try_auth( + &UserId::parse_with_server_name("", services().globals.server_name()) .expect("we know this is valid"), "".into(), auth, &uiaainfo, - &db.users, - &db.globals, )?; if !worked { return Err(Error::Uiaa(uiaainfo)); @@ -158,8 +152,8 @@ pub async fn register_route( // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - db.uiaa.create( - &UserId::parse_with_server_name("", db.globals.server_name()) + services().uiaa.create( + &UserId::parse_with_server_name("", services().globals.server_name()) .expect("we know this is valid"), "".into(), &uiaainfo, @@ -178,15 +172,15 @@ pub async fn register_route( }; // Create user - db.users.create(&user_id, password)?; + services().users.create(&user_id, password)?; // Default to pretty displayname let displayname = format!("{} ⚡️", user_id.localpart()); - db.users + services().users .set_displayname(&user_id, Some(displayname.clone()))?; // Initial account data - db.account_data.update( + services().account_data.update( None, &user_id, GlobalAccountDataEventType::PushRules.to_string().into(), @@ -195,7 +189,6 @@ pub async fn register_route( global: push::Ruleset::server_default(&user_id), }, }, - &db.globals, )?; // Inhibit login does not work for guests @@ -219,7 +212,7 @@ pub async fn register_route( let token = utils::random_string(TOKEN_LENGTH); // Create device for this account - db.users.create_device( + services().users.create_device( &user_id, &device_id, &token, @@ -227,7 +220,7 @@ pub async fn register_route( )?; info!("New user {} registered on this server.", user_id); - db.admin + services().admin .send_message(RoomMessageEventContent::notice_plain(format!( "New user {} registered on this server.", user_id @@ -235,14 +228,12 @@ pub async fn register_route( // If this is the first real user, grant them admin privileges // Note: the server user, @conduit:servername, is generated first - if db.users.count()? == 2 { - make_user_admin(&db, &user_id, displayname).await?; + if services().users.count()? == 2 { + services().admin.make_user_admin(&user_id, displayname).await?; warn!("Granting {} admin privileges as the first user", user_id); } - db.flush()?; - Ok(register::v3::Response { access_token: Some(token), user_id, @@ -265,7 +256,6 @@ pub async fn register_route( /// - Forgets to-device events /// - Triggers device list updates pub async fn change_password_route( - db: DatabaseGuard, body: Ruma<change_password::v3::IncomingRequest>, ) -> Result<change_password::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -282,13 +272,11 @@ pub async fn change_password_route( }; if let Some(auth) = &body.auth { - let (worked, uiaainfo) = db.uiaa.try_auth( + let (worked, uiaainfo) = services().uiaa.try_auth( sender_user, sender_device, auth, &uiaainfo, - &db.users, - &db.globals, )?; if !worked { return Err(Error::Uiaa(uiaainfo)); @@ -296,32 +284,30 @@ pub async fn change_password_route( // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - db.uiaa + services().uiaa .create(sender_user, sender_device, &uiaainfo, &json)?; return Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } - db.users + services().users .set_password(sender_user, Some(&body.new_password))?; if body.logout_devices { // Logout all devices except the current one - for id in db + for id in services() .users .all_device_ids(sender_user) .filter_map(|id| id.ok()) .filter(|id| id != sender_device) { - db.users.remove_device(sender_user, &id)?; + services().users.remove_device(sender_user, &id)?; } } - db.flush()?; - info!("User {} changed their password.", sender_user); - db.admin + services().admin .send_message(RoomMessageEventContent::notice_plain(format!( "User {} changed their password.", sender_user @@ -336,7 +322,6 @@ pub async fn change_password_route( /// /// Note: Also works for Application Services pub async fn whoami_route( - db: DatabaseGuard, body: Ruma<whoami::v3::Request>, ) -> Result<whoami::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -345,7 +330,7 @@ pub async fn whoami_route( Ok(whoami::v3::Response { user_id: sender_user.clone(), device_id, - is_guest: db.users.is_deactivated(&sender_user)?, + is_guest: services().users.is_deactivated(&sender_user)?, }) } @@ -360,7 +345,6 @@ pub async fn whoami_route( /// - Triggers device list updates /// - Removes ability to log in again pub async fn deactivate_route( - db: DatabaseGuard, body: Ruma<deactivate::v3::IncomingRequest>, ) -> Result<deactivate::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -377,13 +361,11 @@ pub async fn deactivate_route( }; if let Some(auth) = &body.auth { - let (worked, uiaainfo) = db.uiaa.try_auth( + let (worked, uiaainfo) = services().uiaa.try_auth( sender_user, sender_device, auth, &uiaainfo, - &db.users, - &db.globals, )?; if !worked { return Err(Error::Uiaa(uiaainfo)); @@ -391,7 +373,7 @@ pub async fn deactivate_route( // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - db.uiaa + services().uiaa .create(sender_user, sender_device, &uiaainfo, &json)?; return Err(Error::Uiaa(uiaainfo)); } else { @@ -399,20 +381,18 @@ pub async fn deactivate_route( } // Make the user leave all rooms before deactivation - db.rooms.leave_all_rooms(&sender_user, &db).await?; + services().rooms.leave_all_rooms(&sender_user).await?; // Remove devices and mark account as deactivated - db.users.deactivate_account(sender_user)?; + services().users.deactivate_account(sender_user)?; info!("User {} deactivated their account.", sender_user); - db.admin + services().admin .send_message(RoomMessageEventContent::notice_plain(format!( "User {} deactivated their account.", sender_user ))); - db.flush()?; - Ok(deactivate::v3::Response { id_server_unbind_result: ThirdPartyIdRemovalStatus::NoSupport, }) diff --git a/src/api/client_server/alias.rs b/src/api/client_server/alias.rs index 90e9d2c..7aa5fb2 100644 --- a/src/api/client_server/alias.rs +++ b/src/api/client_server/alias.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Database, Error, Result, Ruma}; +use crate::{Error, Result, Ruma, services}; use regex::Regex; use ruma::{ api::{ @@ -16,24 +16,21 @@ use ruma::{ /// /// Creates a new room alias on this server. pub async fn create_alias_route( - db: DatabaseGuard, body: Ruma<create_alias::v3::IncomingRequest>, ) -> Result<create_alias::v3::Response> { - if body.room_alias.server_name() != db.globals.server_name() { + if body.room_alias.server_name() != services().globals.server_name() { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Alias is from another server.", )); } - if db.rooms.id_from_alias(&body.room_alias)?.is_some() { + if services().rooms.id_from_alias(&body.room_alias)?.is_some() { return Err(Error::Conflict("Alias already exists.")); } - db.rooms - .set_alias(&body.room_alias, Some(&body.room_id), &db.globals)?; - - db.flush()?; + services().rooms + .set_alias(&body.room_alias, Some(&body.room_id))?; Ok(create_alias::v3::Response::new()) } @@ -45,22 +42,19 @@ pub async fn create_alias_route( /// - TODO: additional access control checks /// - TODO: Update canonical alias event pub async fn delete_alias_route( - db: DatabaseGuard, body: Ruma<delete_alias::v3::IncomingRequest>, ) -> Result<delete_alias::v3::Response> { - if body.room_alias.server_name() != db.globals.server_name() { + if body.room_alias.server_name() != services().globals.server_name() { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Alias is from another server.", )); } - db.rooms.set_alias(&body.room_alias, None, &db.globals)?; + services().rooms.set_alias(&body.room_alias, None)?; // TODO: update alt_aliases? - db.flush()?; - Ok(delete_alias::v3::Response::new()) } @@ -70,21 +64,18 @@ pub async fn delete_alias_route( /// /// - TODO: Suggest more servers to join via pub async fn get_alias_route( - db: DatabaseGuard, body: Ruma<get_alias::v3::IncomingRequest>, ) -> Result<get_alias::v3::Response> { - get_alias_helper(&db, &body.room_alias).await + get_alias_helper(&body.room_alias).await } pub(crate) async fn get_alias_helper( - db: &Database, room_alias: &RoomAliasId, ) -> Result<get_alias::v3::Response> { - if room_alias.server_name() != db.globals.server_name() { - let response = db + if room_alias.server_name() != services().globals.server_name() { + let response = services() .sending .send_federation_request( - &db.globals, room_alias.server_name(), federation::query::get_room_information::v1::Request { room_alias }, ) @@ -97,10 +88,10 @@ pub(crate) async fn get_alias_helper( } let mut room_id = None; - match db.rooms.id_from_alias(room_alias)? { + match services().rooms.id_from_alias(room_alias)? { Some(r) => room_id = Some(r), None => { - for (_id, registration) in db.appservice.all()? { + for (_id, registration) in services().appservice.all()? { let aliases = registration .get("namespaces") .and_then(|ns| ns.get("aliases")) @@ -115,17 +106,16 @@ pub(crate) async fn get_alias_helper( if aliases .iter() .any(|aliases| aliases.is_match(room_alias.as_str())) - && db + && services() .sending .send_appservice_request( - &db.globals, registration, appservice::query::query_room_alias::v1::Request { room_alias }, ) .await .is_ok() { - room_id = Some(db.rooms.id_from_alias(room_alias)?.ok_or_else(|| { + room_id = Some(services().rooms.id_from_alias(room_alias)?.ok_or_else(|| { Error::bad_config("Appservice lied to us. Room does not exist.") })?); break; @@ -146,6 +136,6 @@ pub(crate) async fn get_alias_helper( Ok(get_alias::v3::Response::new( room_id, - vec![db.globals.server_name().to_owned()], + vec![services().globals.server_name().to_owned()], )) } diff --git a/src/api/client_server/backup.rs b/src/api/client_server/backup.rs index 067f20c..e413893 100644 --- a/src/api/client_server/backup.rs +++ b/src/api/client_server/backup.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Error, Result, Ruma}; +use crate::{Error, Result, Ruma, services}; use ruma::api::client::{ backup::{ add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session, @@ -14,15 +14,12 @@ use ruma::api::client::{ /// /// Creates a new backup. pub async fn create_backup_version_route( - db: DatabaseGuard, body: Ruma<create_backup_version::v3::Request>, ) -> Result<create_backup_version::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let version = db + let version = services() .key_backups - .create_backup(sender_user, &body.algorithm, &db.globals)?; - - db.flush()?; + .create_backup(sender_user, &body.algorithm)?; Ok(create_backup_version::v3::Response { version }) } @@ -31,14 +28,11 @@ pub async fn create_backup_version_route( /// /// Update information about an existing backup. Only `auth_data` can be modified. pub async fn update_backup_version_route( - db: DatabaseGuard, body: Ruma<update_backup_version::v3::IncomingRequest>, ) -> Result<update_backup_version::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - db.key_backups - .update_backup(sender_user, &body.version, &body.algorithm, &db.globals)?; - - db.flush()?; + services().key_backups + .update_backup(sender_user, &body.version, &body.algorithm)?; Ok(update_backup_version::v3::Response {}) } @@ -47,13 +41,12 @@ pub async fn update_backup_version_route( /// /// Get information about the latest backup version. pub async fn get_latest_backup_info_route( - db: DatabaseGuard, body: Ruma<get_latest_backup_info::v3::Request>, ) -> Result<get_latest_backup_info::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let (version, algorithm) = - db.key_backups + services().key_backups .get_latest_backup(sender_user)? .ok_or(Error::BadRequest( ErrorKind::NotFound, @@ -62,8 +55,8 @@ pub async fn get_latest_backup_info_route( Ok(get_latest_backup_info::v3::Response { algorithm, - count: (db.key_backups.count_keys(sender_user, &version)? as u32).into(), - etag: db.key_backups.get_etag(sender_user, &version)?, + count: (services().key_backups.count_keys(sender_user, &version)? as u32).into(), + etag: services().key_backups.get_etag(sender_user, &version)?, version, }) } @@ -72,11 +65,10 @@ pub async fn get_latest_backup_info_route( /// /// Get information about an existing backup. pub async fn get_backup_info_route( - db: DatabaseGuard, body: Ruma<get_backup_info::v3::IncomingRequest>, ) -> Result<get_backup_info::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let algorithm = db + let algorithm = services() .key_backups .get_backup(sender_user, &body.version)? .ok_or(Error::BadRequest( @@ -86,8 +78,8 @@ pub async fn get_backup_info_route( Ok(get_backup_info::v3::Response { algorithm, - count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), - etag: db.key_backups.get_etag(sender_user, &body.version)?, + count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(), + etag: services().key_backups.get_etag(sender_user, &body.version)?, version: body.version.to_owned(), }) } @@ -98,14 +90,11 @@ pub async fn get_backup_info_route( /// /// - Deletes both information about the backup, as well as all key data related to the backup pub async fn delete_backup_version_route( - db: DatabaseGuard, body: Ruma<delete_backup_version::v3::IncomingRequest>, ) -> Result<delete_backup_version::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - db.key_backups.delete_backup(sender_user, &body.version)?; - - db.flush()?; + services().key_backups.delete_backup(sender_user, &body.version)?; Ok(delete_backup_version::v3::Response {}) } @@ -118,13 +107,12 @@ pub async fn delete_backup_version_route( /// - Adds the keys to the backup /// - Returns the new number of keys in this backup and the etag pub async fn add_backup_keys_route( - db: DatabaseGuard, body: Ruma<add_backup_keys::v3::IncomingRequest>, ) -> Result<add_backup_keys::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); if Some(&body.version) - != db + != services() .key_backups .get_latest_backup_version(sender_user)? .as_ref() @@ -137,22 +125,19 @@ pub async fn add_backup_keys_route( for (room_id, room) in &body.rooms { for (session_id, key_data) in &room.sessions { - db.key_backups.add_key( + services().key_backups.add_key( sender_user, &body.version, room_id, session_id, key_data, - &db.globals, )? } } - db.flush()?; - Ok(add_backup_keys::v3::Response { - count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), - etag: db.key_backups.get_etag(sender_user, &body.version)?, + count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(), + etag: services().key_backups.get_etag(sender_user, &body.version)?, }) } @@ -164,13 +149,12 @@ pub async fn add_backup_keys_route( /// - Adds the keys to the backup /// - Returns the new number of keys in this backup and the etag pub async fn add_backup_keys_for_room_route( - db: DatabaseGuard, body: Ruma<add_backup_keys_for_room::v3::IncomingRequest>, ) -> Result<add_backup_keys_for_room::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); if Some(&body.version) - != db + != services() .key_backups .get_latest_backup_version(sender_user)? .as_ref() @@ -182,21 +166,18 @@ pub async fn add_backup_keys_for_room_route( } for (session_id, key_data) in &body.sessions { - db.key_backups.add_key( + services().key_backups.add_key( sender_user, &body.version, &body.room_id, session_id, key_data, - &db.globals, )? } - db.flush()?; - Ok(add_backup_keys_for_room::v3::Response { - count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), - etag: db.key_backups.get_etag(sender_user, &body.version)?, + count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(), + etag: services().key_backups.get_etag(sender_user, &body.version)?, }) } @@ -208,13 +189,12 @@ pub async fn add_backup_keys_for_room_route( /// - Adds the keys to the backup /// - Returns the new number of keys in this backup and the etag pub async fn add_backup_keys_for_session_route( - db: DatabaseGuard, body: Ruma<add_backup_keys_for_session::v3::IncomingRequest>, ) -> Result<add_backup_keys_for_session::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); if Some(&body.version) - != db + != services() .key_backups .get_latest_backup_version(sender_user)? .as_ref() @@ -225,20 +205,17 @@ pub async fn add_backup_keys_for_session_route( )); } - db.key_backups.add_key( + services().key_backups.add_key( sender_user, &body.version, &body.room_id, &body.session_id, &body.session_data, - &db.globals, )?; - db.flush()?; - Ok(add_backup_keys_for_session::v3::Response { - count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), - etag: db.key_backups.get_etag(sender_user, &body.version)?, + count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(), + etag: services().key_backups.get_etag(sender_user, &body.version)?, }) } @@ -246,12 +223,11 @@ pub async fn add_backup_keys_for_session_route( /// /// Retrieves all keys from the backup. pub async fn get_backup_keys_route( - db: DatabaseGuard, body: Ruma<get_backup_keys::v3::IncomingRequest>, ) -> Result<get_backup_keys::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let rooms = db.key_backups.get_all(sender_user, &body.version)?; + let rooms = services().key_backups.get_all(sender_user, &body.version)?; Ok(get_backup_keys::v3::Response { rooms }) } @@ -260,12 +236,11 @@ pub async fn get_backup_keys_route( /// /// Retrieves all keys from the backup for a given room. pub async fn get_backup_keys_for_room_route( - db: DatabaseGuard, body: Ruma<get_backup_keys_for_room::v3::IncomingRequest>, ) -> Result<get_backup_keys_for_room::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sessions = db + let sessions = services() .key_backups .get_room(sender_user, &body.version, &body.room_id)?; @@ -276,12 +251,11 @@ pub async fn get_backup_keys_for_room_route( /// /// Retrieves a key from the backup. pub async fn get_backup_keys_for_session_route( - db: DatabaseGuard, body: Ruma<get_backup_keys_for_session::v3::IncomingRequest>, ) -> Result<get_backup_keys_for_session::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let key_data = db + let key_data = services() .key_backups .get_session(sender_user, &body.version, &body.room_id, &body.session_id)? .ok_or(Error::BadRequest( @@ -296,18 +270,15 @@ pub async fn get_backup_keys_for_session_route( /// /// Delete the keys from the backup. pub async fn delete_backup_keys_route( - db: DatabaseGuard, body: Ruma<delete_backup_keys::v3::IncomingRequest>, ) -> Result<delete_backup_keys::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - db.key_backups.delete_all_keys(sender_user, &body.version)?; - - db.flush()?; + services().key_backups.delete_all_keys(sender_user, &body.version)?; Ok(delete_backup_keys::v3::Response { - count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), - etag: db.key_backups.get_etag(sender_user, &body.version)?, + count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(), + etag: services().key_backups.get_etag(sender_user, &body.version)?, }) } @@ -315,19 +286,16 @@ pub async fn delete_backup_keys_route( /// /// Delete the keys from the backup for a given room. pub async fn delete_backup_keys_for_room_route( - db: DatabaseGuard, body: Ruma<delete_backup_keys_for_room::v3::IncomingRequest>, ) -> Result<delete_backup_keys_for_room::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - db.key_backups + services().key_backups .delete_room_keys(sender_user, &body.version, &body.room_id)?; - db.flush()?; - Ok(delete_backup_keys_for_room::v3::Response { - count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), - etag: db.key_backups.get_etag(sender_user, &body.version)?, + count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(), + etag: services().key_backups.get_etag(sender_user, &body.version)?, }) } @@ -335,18 +303,15 @@ pub async fn delete_backup_keys_for_room_route( /// /// Delete a key from the backup. pub async fn delete_backup_keys_for_session_route( - db: DatabaseGuard, body: Ruma<delete_backup_keys_for_session::v3::IncomingRequest>, ) -> Result<delete_backup_keys_for_session::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - db.key_backups + services().key_backups .delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id)?; - db.flush()?; - Ok(delete_backup_keys_for_session::v3::Response { - count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), - etag: db.key_backups.get_etag(sender_user, &body.version)?, + count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(), + etag: services().key_backups.get_etag(sender_user, &body.version)?, }) } diff --git a/src/api/client_server/capabilities.rs b/src/api/client_server/capabilities.rs index 417ad29..e4283b7 100644 --- a/src/api/client_server/capabilities.rs +++ b/src/api/client_server/capabilities.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Result, Ruma}; +use crate::{Result, Ruma, services}; use ruma::api::client::discovery::get_capabilities::{ self, Capabilities, RoomVersionStability, RoomVersionsCapability, }; @@ -8,26 +8,25 @@ use std::collections::BTreeMap; /// /// Get information on the supported feature set and other relevent capabilities of this server. pub async fn get_capabilities_route( - db: DatabaseGuard, _body: Ruma<get_capabilities::v3::IncomingRequest>, ) -> Result<get_capabilities::v3::Response> { let mut available = BTreeMap::new(); - if db.globals.allow_unstable_room_versions() { - for room_version in &db.globals.unstable_room_versions { + if services().globals.allow_unstable_room_versions() { + for room_version in &services().globals.unstable_room_versions { available.insert(room_version.clone(), RoomVersionStability::Stable); } } else { - for room_version in &db.globals.unstable_room_versions { + for room_version in &services().globals.unstable_room_versions { available.insert(room_version.clone(), RoomVersionStability::Unstable); } } - for room_version in &db.globals.stable_room_versions { + for room_version in &services().globals.stable_room_versions { available.insert(room_version.clone(), RoomVersionStability::Stable); } let mut capabilities = Capabilities::new(); capabilities.room_versions = RoomVersionsCapability { - default: db.globals.default_room_version(), + default: services().globals.default_room_version(), available, }; diff --git a/src/api/client_server/config.rs b/src/api/client_server/config.rs index 6184e0b..36f4fcb 100644 --- a/src/api/client_server/config.rs +++ b/src/api/client_server/config.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Error, Result, Ruma}; +use crate::{Error, Result, Ruma, services}; use ruma::{ api::client::{ config::{ @@ -17,7 +17,6 @@ use serde_json::{json, value::RawValue as RawJsonValue}; /// /// Sets some account data for the sender user. pub async fn set_global_account_data_route( - db: DatabaseGuard, body: Ruma<set_global_account_data::v3::IncomingRequest>, ) -> Result<set_global_account_data::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -27,7 +26,7 @@ pub async fn set_global_account_data_route( let event_type = body.event_type.to_string(); - db.account_data.update( + services().account_data.update( None, sender_user, event_type.clone().into(), @@ -35,11 +34,8 @@ pub async fn set_global_account_data_route( "type": event_type, "content": data, }), - &db.globals, )?; - db.flush()?; - Ok(set_global_account_data::v3::Response {}) } @@ -47,7 +43,6 @@ pub async fn set_global_account_data_route( /// /// Sets some room account data for the sender user. pub async fn set_room_account_data_route( - db: DatabaseGuard, body: Ruma<set_room_account_data::v3::IncomingRequest>, ) -> Result<set_room_account_data::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -57,7 +52,7 @@ pub async fn set_room_account_data_route( let event_type = body.event_type.to_string(); - db.account_data.update( + services().account_data.update( Some(&body.room_id), sender_user, event_type.clone().into(), @@ -65,11 +60,8 @@ pub async fn set_room_account_data_route( "type": event_type, "content": data, }), - &db.globals, )?; - db.flush()?; - Ok(set_room_account_data::v3::Response {}) } @@ -77,12 +69,11 @@ pub async fn set_room_account_data_route( /// /// Gets some account data for the sender user. pub async fn get_global_account_data_route( - db: DatabaseGuard, body: Ruma<get_global_account_data::v3::IncomingRequest>, ) -> Result<get_global_account_data::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event: Box<RawJsonValue> = db + let event: Box<RawJsonValue> = services() .account_data .get(None, sender_user, body.event_type.clone().into())? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; @@ -98,12 +89,11 @@ pub async fn get_global_account_data_route( /// /// Gets some room account data for the sender user. pub async fn get_room_account_data_route( - db: DatabaseGuard, body: Ruma<get_room_account_data::v3::IncomingRequest>, ) -> Result<get_room_account_data::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event: Box<RawJsonValue> = db + let event: Box<RawJsonValue> = services() .account_data .get( Some(&body.room_id), diff --git a/src/api/client_server/context.rs b/src/api/client_server/context.rs index e93f5a5..3551dcf 100644 --- a/src/api/client_server/context.rs +++ b/src/api/client_server/context.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Error, Result, Ruma}; +use crate::{Error, Result, Ruma, services}; use ruma::{ api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions}, events::StateEventType, @@ -13,7 +13,6 @@ use tracing::error; /// - Only works if the user is joined (TODO: always allow, but only show events if the user was /// joined, depending on history_visibility) pub async fn get_context_route( - db: DatabaseGuard, body: Ruma<get_context::v3::IncomingRequest>, ) -> Result<get_context::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -28,7 +27,7 @@ pub async fn get_context_route( let mut lazy_loaded = HashSet::new(); - let base_pdu_id = db + let base_pdu_id = services() .rooms .get_pdu_id(&body.event_id)? .ok_or(Error::BadRequest( @@ -36,9 +35,9 @@ pub async fn get_context_route( "Base event id not found.", ))?; - let base_token = db.rooms.pdu_count(&base_pdu_id)?; + let base_token = services().rooms.pdu_count(&base_pdu_id)?; - let base_event = db + let base_event = services() .rooms .get_pdu_from_id(&base_pdu_id)? .ok_or(Error::BadRequest( @@ -48,14 +47,14 @@ pub async fn get_context_route( let room_id = base_event.room_id.clone(); - if !db.rooms.is_joined(sender_user, &room_id)? { + if !services().rooms.is_joined(sender_user, &room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view this room.", )); } - if !db.rooms.lazy_load_was_sent_before( + if !services().rooms.lazy_load_was_sent_before( sender_user, sender_device, &room_id, @@ -67,7 +66,7 @@ pub async fn get_context_route( let base_event = base_event.to_room_event(); - let events_before: Vec<_> = db + let events_before: Vec<_> = services() .rooms .pdus_until(sender_user, &room_id, base_token)? .take( @@ -80,7 +79,7 @@ pub async fn get_context_route( .collect(); for (_, event) in &events_before { - if !db.rooms.lazy_load_was_sent_before( + if !services().rooms.lazy_load_was_sent_before( sender_user, sender_device, &room_id, @@ -93,7 +92,7 @@ pub async fn get_context_route( let start_token = events_before .last() - .and_then(|(pdu_id, _)| db.rooms.pdu_count(pdu_id).ok()) + .and_then(|(pdu_id, _)| services().rooms.pdu_count(pdu_id).ok()) .map(|count| count.to_string()); let events_before: Vec<_> = events_before @@ -101,7 +100,7 @@ pub async fn get_context_route( .map(|(_, pdu)| pdu.to_room_event()) .collect(); - let events_after: Vec<_> = db + let events_after: Vec<_> = services() .rooms .pdus_after(sender_user, &room_id, base_token)? .take( @@ -114,7 +113,7 @@ pub async fn get_context_route( .collect(); for (_, event) in &events_after { - if !db.rooms.lazy_load_was_sent_before( + if !services().rooms.lazy_load_was_sent_before( sender_user, sender_device, &room_id, @@ -125,23 +124,23 @@ pub async fn get_context_route( } } - let shortstatehash = match db.rooms.pdu_shortstatehash( + let shortstatehash = match services().rooms.pdu_shortstatehash( events_after .last() .map_or(&*body.event_id, |(_, e)| &*e.event_id), )? { Some(s) => s, - None => db + None => services() .rooms .current_shortstatehash(&room_id)? .expect("All rooms have state"), }; - let state_ids = db.rooms.state_full_ids(shortstatehash).await?; + let state_ids = services().rooms.state_full_ids(shortstatehash).await?; let end_token = events_after .last() - .and_then(|(pdu_id, _)| db.rooms.pdu_count(pdu_id).ok()) + .and_then(|(pdu_id, _)| services().rooms.pdu_count(pdu_id).ok()) .map(|count| count.to_string()); let events_after: Vec<_> = events_after @@ -152,10 +151,10 @@ pub async fn get_context_route( let mut state = Vec::new(); for (shortstatekey, id) in state_ids { - let (event_type, state_key) = db.rooms.get_statekey_from_short(shortstatekey)?; + let (event_type, state_key) = services().rooms.get_statekey_from_short(shortstatekey)?; if event_type != StateEventType::RoomMember { - let pdu = match db.rooms.get_pdu(&id)? { + let pdu = match services().rooms.get_pdu(&id)? { Some(pdu) => pdu, None => { error!("Pdu in state not found: {}", id); @@ -164,7 +163,7 @@ pub async fn get_context_route( }; state.push(pdu.to_state_event()); } else if !lazy_load_enabled || lazy_loaded.contains(&state_key) { - let pdu = match db.rooms.get_pdu(&id)? { + let pdu = match services().rooms.get_pdu(&id)? { Some(pdu) => pdu, None => { error!("Pdu in state not found: {}", id); diff --git a/src/api/client_server/device.rs b/src/api/client_server/device.rs index b100bf2..2f55993 100644 --- a/src/api/client_server/device.rs +++ b/src/api/client_server/device.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, utils, Error, Result, Ruma}; +use crate::{utils, Error, Result, Ruma, services}; use ruma::api::client::{ device::{self, delete_device, delete_devices, get_device, get_devices, update_device}, error::ErrorKind, @@ -11,12 +11,11 @@ use super::SESSION_ID_LENGTH; /// /// Get metadata on all devices of the sender user. pub async fn get_devices_route( - db: DatabaseGuard, body: Ruma<get_devices::v3::Request>, ) -> Result<get_devices::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let devices: Vec<device::Device> = db + let devices: Vec<device::Device> = services() .users .all_devices_metadata(sender_user) .filter_map(|r| r.ok()) // Filter out buggy devices @@ -29,12 +28,11 @@ pub async fn get_devices_route( /// /// Get metadata on a single device of the sender user. pub async fn get_device_route( - db: DatabaseGuard, body: Ruma<get_device::v3::IncomingRequest>, ) -> Result<get_device::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let device = db + let device = services() .users .get_device_metadata(sender_user, &body.body.device_id)? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; @@ -46,23 +44,20 @@ pub async fn get_device_route( /// /// Updates the metadata on a given device of the sender user. pub async fn update_device_route( - db: DatabaseGuard, body: Ruma<update_device::v3::IncomingRequest>, ) -> Result<update_device::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let mut device = db + let mut device = services() .users .get_device_metadata(sender_user, &body.device_id)? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; device.display_name = body.display_name.clone(); - db.users + services().users .update_device_metadata(sender_user, &body.device_id, &device)?; - db.flush()?; - Ok(update_device::v3::Response {}) } @@ -76,7 +71,6 @@ pub async fn update_device_route( /// - Forgets to-device events /// - Triggers device list updates pub async fn delete_device_route( - db: DatabaseGuard, body: Ruma<delete_device::v3::IncomingRequest>, ) -> Result<delete_device::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -94,13 +88,11 @@ pub async fn delete_device_route( }; if let Some(auth) = &body.auth { - let (worked, uiaainfo) = db.uiaa.try_auth( + let (worked, uiaainfo) = services().uiaa.try_auth( sender_user, sender_device, auth, &uiaainfo, - &db.users, - &db.globals, )?; if !worked { return Err(Error::Uiaa(uiaainfo)); @@ -108,16 +100,14 @@ pub async fn delete_device_route( // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - db.uiaa + services().uiaa .create(sender_user, sender_device, &uiaainfo, &json)?; return Err(Error::Uiaa(uiaainfo)); } else { return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); } - db.users.remove_device(sender_user, &body.device_id)?; - - db.flush()?; + services().users.remove_device(sender_user, &body.device_id)?; Ok(delete_device::v3::Response {}) } @@ -134,7 +124,6 @@ pub async fn delete_device_route( /// - Forgets to-device events /// - Triggers device list updates pub async fn delete_devices_route( - db: DatabaseGuard, body: Ruma<delete_devices::v3::IncomingRequest>, ) -> Result<delete_devices::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -152,13 +141,11 @@ pub async fn delete_devices_route( }; if let Some(auth) = &body.auth { - let (worked, uiaainfo) = db.uiaa.try_auth( + let (worked, uiaainfo) = services().uiaa.try_auth( sender_user, sender_device, auth, &uiaainfo, - &db.users, - &db.globals, )?; if !worked { return Err(Error::Uiaa(uiaainfo)); @@ -166,7 +153,7 @@ pub async fn delete_devices_route( // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - db.uiaa + services().uiaa .create(sender_user, sender_device, &uiaainfo, &json)?; return Err(Error::Uiaa(uiaainfo)); } else { @@ -174,10 +161,8 @@ pub async fn delete_devices_route( } for device_id in &body.devices { - db.users.remove_device(sender_user, device_id)? + services().users.remove_device(sender_user, device_id)? } - db.flush()?; - Ok(delete_devices::v3::Response {}) } diff --git a/src/api/client_server/directory.rs b/src/api/client_server/directory.rs index 4e4a322..87493fa 100644 --- a/src/api/client_server/directory.rs +++ b/src/api/client_server/directory.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Database, Error, Result, Ruma}; +use crate::{Error, Result, Ruma, services}; use ruma::{ api::{ client::{ @@ -37,11 +37,9 @@ use tracing::{info, warn}; /// /// - Rooms are ordered by the number of joined members pub async fn get_public_rooms_filtered_route( - db: DatabaseGuard, body: Ruma<get_public_rooms_filtered::v3::IncomingRequest>, ) -> Result<get_public_rooms_filtered::v3::Response> { get_public_rooms_filtered_helper( - &db, body.server.as_deref(), body.limit, body.since.as_deref(), @@ -57,11 +55,9 @@ pub async fn get_public_rooms_filtered_route( /// /// - Rooms are ordered by the number of joined members pub async fn get_public_rooms_route( - db: DatabaseGuard, body: Ruma<get_public_rooms::v3::IncomingRequest>, ) -> Result<get_public_rooms::v3::Response> { let response = get_public_rooms_filtered_helper( - &db, body.server.as_deref(), body.limit, body.since.as_deref(), @@ -84,17 +80,16 @@ pub async fn get_public_rooms_route( /// /// - TODO: Access control checks pub async fn set_room_visibility_route( - db: DatabaseGuard, body: Ruma<set_room_visibility::v3::IncomingRequest>, ) -> Result<set_room_visibility::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); match &body.visibility { room::Visibility::Public => { - db.rooms.set_public(&body.room_id, true)?; + services().rooms.set_public(&body.room_id, true)?; info!("{} made {} public", sender_user, body.room_id); } - room::Visibility::Private => db.rooms.set_public(&body.room_id, false)?, + room::Visibility::Private => services().rooms.set_public(&body.room_id, false)?, _ => { return Err(Error::BadRequest( ErrorKind::InvalidParam, @@ -103,8 +98,6 @@ pub async fn set_room_visibility_route( } } - db.flush()?; - Ok(set_room_visibility::v3::Response {}) } @@ -112,11 +105,10 @@ pub async fn set_room_visibility_route( /// /// Gets the visibility of a given room in the room directory. pub async fn get_room_visibility_route( - db: DatabaseGuard, body: Ruma<get_room_visibility::v3::IncomingRequest>, ) -> Result<get_room_visibility::v3::Response> { Ok(get_room_visibility::v3::Response { - visibility: if db.rooms.is_public_room(&body.room_id)? { + visibility: if services().rooms.is_public_room(&body.room_id)? { room::Visibility::Public } else { room::Visibility::Private @@ -125,19 +117,17 @@ pub async fn get_room_visibility_route( } pub(crate) async fn get_public_rooms_filtered_helper( - db: &Database, server: Option<&ServerName>, limit: Option<UInt>, since: Option<&str>, filter: &IncomingFilter, _network: &IncomingRoomNetwork, ) -> Result<get_public_rooms_filtered::v3::Response> { - if let Some(other_server) = server.filter(|server| *server != db.globals.server_name().as_str()) + if let Some(other_server) = server.filter(|server| *server != services().globals.server_name().as_str()) { - let response = db + let response = services() .sending .send_federation_request( - &db.globals, other_server, federation::directory::get_public_rooms_filtered::v1::Request { limit, @@ -184,14 +174,14 @@ pub(crate) async fn get_public_rooms_filtered_helper( } } - let mut all_rooms: Vec<_> = db + let mut all_rooms: Vec<_> = services() .rooms .public_rooms() .map(|room_id| { let room_id = room_id?; let chunk = PublicRoomsChunk { - canonical_alias: db + canonical_alias: services() .rooms .room_state_get(&room_id, &StateEventType::RoomCanonicalAlias, "")? .map_or(Ok(None), |s| { @@ -201,7 +191,7 @@ pub(crate) async fn get_public_rooms_filtered_helper( Error::bad_database("Invalid canonical alias event in database.") }) })?, - name: db + name: services() .rooms .room_state_get(&room_id, &StateEventType::RoomName, "")? .map_or(Ok(None), |s| { @@ -211,7 +201,7 @@ pub(crate) async fn get_public_rooms_filtered_helper( Error::bad_database("Invalid room name event in database.") }) })?, - num_joined_members: db + num_joined_members: services() .rooms .room_joined_count(&room_id)? .unwrap_or_else(|| { @@ -220,7 +210,7 @@ pub(crate) async fn get_public_rooms_filtered_helper( }) .try_into() .expect("user count should not be that big"), - topic: db + topic: services() .rooms .room_state_get(&room_id, &StateEventType::RoomTopic, "")? .map_or(Ok(None), |s| { @@ -230,7 +220,7 @@ pub(crate) async fn get_public_rooms_filtered_helper( Error::bad_database("Invalid room topic event in database.") }) })?, - world_readable: db + world_readable: services() .rooms .room_state_get(&room_id, &StateEventType::RoomHistoryVisibility, "")? .map_or(Ok(false), |s| { @@ -244,7 +234,7 @@ pub(crate) async fn get_public_rooms_filtered_helper( ) }) })?, - guest_can_join: db + guest_can_join: services() .rooms .room_state_get(&room_id, &StateEventType::RoomGuestAccess, "")? .map_or(Ok(false), |s| { @@ -256,7 +246,7 @@ pub(crate) async fn get_public_rooms_filtered_helper( Error::bad_database("Invalid room guest access event in database.") }) })?, - avatar_url: db + avatar_url: services() .rooms .room_state_get(&room_id, &StateEventType::RoomAvatar, "")? .map(|s| { @@ -269,7 +259,7 @@ pub(crate) async fn get_public_rooms_filtered_helper( .transpose()? // url is now an Option<String> so we must flatten .flatten(), - join_rule: db + join_rule: services() .rooms .room_state_get(&room_id, &StateEventType::RoomJoinRules, "")? .map(|s| { diff --git a/src/api/client_server/filter.rs b/src/api/client_server/filter.rs index 6522c90..e0c9506 100644 --- a/src/api/client_server/filter.rs +++ b/src/api/client_server/filter.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Error, Result, Ruma}; +use crate::{Error, Result, Ruma, services}; use ruma::api::client::{ error::ErrorKind, filter::{create_filter, get_filter}, @@ -10,11 +10,10 @@ use ruma::api::client::{ /// /// - A user can only access their own filters pub async fn get_filter_route( - db: DatabaseGuard, body: Ruma<get_filter::v3::IncomingRequest>, ) -> Result<get_filter::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let filter = match db.users.get_filter(sender_user, &body.filter_id)? { + let filter = match services().users.get_filter(sender_user, &body.filter_id)? { Some(filter) => filter, None => return Err(Error::BadRequest(ErrorKind::NotFound, "Filter not found.")), }; @@ -26,11 +25,10 @@ pub async fn get_filter_route( /// /// Creates a new filter to be used by other endpoints. pub async fn create_filter_route( - db: DatabaseGuard, body: Ruma<create_filter::v3::IncomingRequest>, ) -> Result<create_filter::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); Ok(create_filter::v3::Response::new( - db.users.create_filter(sender_user, &body.filter)?, + services().users.create_filter(sender_user, &body.filter)?, )) } diff --git a/src/api/client_server/keys.rs b/src/api/client_server/keys.rs index c4f91cb..698bd1e 100644 --- a/src/api/client_server/keys.rs +++ b/src/api/client_server/keys.rs @@ -1,5 +1,5 @@ use super::SESSION_ID_LENGTH; -use crate::{database::DatabaseGuard, utils, Database, Error, Result, Ruma}; +use crate::{utils, Error, Result, Ruma, services}; use futures_util::{stream::FuturesUnordered, StreamExt}; use ruma::{ api::{ @@ -26,39 +26,34 @@ use std::collections::{BTreeMap, HashMap, HashSet}; /// - Adds one time keys /// - If there are no device keys yet: Adds device keys (TODO: merge with existing keys?) pub async fn upload_keys_route( - db: DatabaseGuard, body: Ruma<upload_keys::v3::Request>, ) -> Result<upload_keys::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); for (key_key, key_value) in &body.one_time_keys { - db.users - .add_one_time_key(sender_user, sender_device, key_key, key_value, &db.globals)?; + services().users + .add_one_time_key(sender_user, sender_device, key_key, key_value)?; } if let Some(device_keys) = &body.device_keys { // TODO: merge this and the existing event? // This check is needed to assure that signatures are kept - if db + if services() .users .get_device_keys(sender_user, sender_device)? .is_none() { - db.users.add_device_keys( + services().users.add_device_keys( sender_user, sender_device, device_keys, - &db.rooms, - &db.globals, )?; } } - db.flush()?; - Ok(upload_keys::v3::Response { - one_time_key_counts: db.users.count_one_time_keys(sender_user, sender_device)?, + one_time_key_counts: services().users.count_one_time_keys(sender_user, sender_device)?, }) } @@ -70,7 +65,6 @@ pub async fn upload_keys_route( /// - Gets master keys, self-signing keys, user signing keys and device keys. /// - The master and self-signing keys contain signatures that the user is allowed to see pub async fn get_keys_route( - db: DatabaseGuard, body: Ruma<get_keys::v3::IncomingRequest>, ) -> Result<get_keys::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -79,7 +73,6 @@ pub async fn get_keys_route( Some(sender_user), &body.device_keys, |u| u == sender_user, - &db, ) .await?; @@ -90,12 +83,9 @@ pub async fn get_keys_route( /// /// Claims one-time keys pub async fn claim_keys_route( - db: DatabaseGuard, body: Ruma<claim_keys::v3::Request>, ) -> Result<claim_keys::v3::Response> { - let response = claim_keys_helper(&body.one_time_keys, &db).await?; - - db.flush()?; + let response = claim_keys_helper(&body.one_time_keys).await?; Ok(response) } @@ -106,7 +96,6 @@ pub async fn claim_keys_route( /// /// - Requires UIAA to verify password pub async fn upload_signing_keys_route( - db: DatabaseGuard, body: Ruma<upload_signing_keys::v3::IncomingRequest>, ) -> Result<upload_signing_keys::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -124,13 +113,11 @@ pub async fn upload_signing_keys_route( }; if let Some(auth) = &body.auth { - let (worked, uiaainfo) = db.uiaa.try_auth( + let (worked, uiaainfo) = services().uiaa.try_auth( sender_user, sender_device, auth, &uiaainfo, - &db.users, - &db.globals, )?; if !worked { return Err(Error::Uiaa(uiaainfo)); @@ -138,7 +125,7 @@ pub async fn upload_signing_keys_route( // Success! } else if let Some(json) = body.json_body { uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - db.uiaa + services().uiaa .create(sender_user, sender_device, &uiaainfo, &json)?; return Err(Error::Uiaa(uiaainfo)); } else { @@ -146,18 +133,14 @@ pub async fn upload_signing_keys_route( } if let Some(master_key) = &body.master_key { - db.users.add_cross_signing_keys( + services().users.add_cross_signing_keys( sender_user, master_key, &body.self_signing_key, &body.user_signing_key, - &db.rooms, - &db.globals, )?; } - db.flush()?; - Ok(upload_signing_keys::v3::Response {}) } @@ -165,7 +148,6 @@ pub async fn upload_signing_keys_route( /// /// Uploads end-to-end key signatures from the sender user. pub async fn upload_signatures_route( - db: DatabaseGuard, body: Ruma<upload_signatures::v3::Request>, ) -> Result<upload_signatures::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -205,20 +187,16 @@ pub async fn upload_signatures_route( ))? .to_owned(), ); - db.users.sign_key( + services().users.sign_key( user_id, key_id, signature, sender_user, - &db.rooms, - &db.globals, )?; } } } - db.flush()?; - Ok(upload_signatures::v3::Response { failures: BTreeMap::new(), // TODO: integrate }) @@ -230,7 +208,6 @@ pub async fn upload_signatures_route( /// /// - TODO: left users pub async fn get_key_changes_route( - db: DatabaseGuard, body: Ruma<get_key_changes::v3::IncomingRequest>, ) -> Result<get_key_changes::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -238,7 +215,7 @@ pub async fn get_key_changes_route( let mut device_list_updates = HashSet::new(); device_list_updates.extend( - db.users + services().users .keys_changed( sender_user.as_str(), body.from @@ -253,9 +230,9 @@ pub async fn get_key_changes_route( .filter_map(|r| r.ok()), ); - for room_id in db.rooms.rooms_joined(sender_user).filter_map(|r| r.ok()) { + for room_id in services().rooms.rooms_joined(sender_user).filter_map(|r| r.ok()) { device_list_updates.extend( - db.users + services().users .keys_changed( &room_id.to_string(), body.from.parse().map_err(|_| { @@ -278,7 +255,6 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>( sender_user: Option<&UserId>, device_keys_input: &BTreeMap<Box<UserId>, Vec<Box<DeviceId>>>, allowed_signatures: F, - db: &Database, ) -> Result<get_keys::v3::Response> { let mut master_keys = BTreeMap::new(); let mut self_signing_keys = BTreeMap::new(); @@ -290,7 +266,7 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>( for (user_id, device_ids) in device_keys_input { let user_id: &UserId = &**user_id; - if user_id.server_name() != db.globals.server_name() { + if user_id.server_name() != services().globals.server_name() { get_over_federation .entry(user_id.server_name()) .or_insert_with(Vec::new) @@ -300,10 +276,10 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>( if device_ids.is_empty() { let mut container = BTreeMap::new(); - for device_id in db.users.all_device_ids(user_id) { + for device_id in services().users.all_device_ids(user_id) { let device_id = device_id?; - if let Some(mut keys) = db.users.get_device_keys(user_id, &device_id)? { - let metadata = db + if let Some(mut keys) = services().users.get_device_keys(user_id, &device_id)? { + let metadata = services() .users .get_device_metadata(user_id, &device_id)? .ok_or_else(|| { @@ -319,8 +295,8 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>( } else { for device_id in device_ids { let mut container = BTreeMap::new(); - if let Some(mut keys) = db.users.get_device_keys(user_id, device_id)? { - let metadata = db.users.get_device_metadata(user_id, device_id)?.ok_or( + if let Some(mut keys) = services().users.get_device_keys(user_id, device_id)? { + let metadata = services().users.get_device_metadata(user_id, device_id)?.ok_or( Error::BadRequest( ErrorKind::InvalidParam, "Tried to get keys for nonexistent device.", @@ -335,17 +311,17 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>( } } - if let Some(master_key) = db.users.get_master_key(user_id, &allowed_signatures)? { + if let Some(master_key) = services().users.get_master_key(user_id, &allowed_signatures)? { master_keys.insert(user_id.to_owned(), master_key); } - if let Some(self_signing_key) = db + if let Some(self_signing_key) = services() .users .get_self_signing_key(user_id, &allowed_signatures)? { self_signing_keys.insert(user_id.to_owned(), self_signing_key); } if Some(user_id) == sender_user { - if let Some(user_signing_key) = db.users.get_user_signing_key(user_id)? { + if let Some(user_signing_key) = services().users.get_user_signing_key(user_id)? { user_signing_keys.insert(user_id.to_owned(), user_signing_key); } } @@ -362,9 +338,8 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>( } ( server, - db.sending + services().sending .send_federation_request( - &db.globals, server, federation::keys::get_keys::v1::Request { device_keys: device_keys_input_fed, @@ -417,14 +392,13 @@ fn add_unsigned_device_display_name( pub(crate) async fn claim_keys_helper( one_time_keys_input: &BTreeMap<Box<UserId>, BTreeMap<Box<DeviceId>, DeviceKeyAlgorithm>>, - db: &Database, ) -> Result<claim_keys::v3::Response> { let mut one_time_keys = BTreeMap::new(); let mut get_over_federation = BTreeMap::new(); for (user_id, map) in one_time_keys_input { - if user_id.server_name() != db.globals.server_name() { + if user_id.server_name() != services().globals.server_name() { get_over_federation .entry(user_id.server_name()) .or_insert_with(Vec::new) @@ -434,8 +408,8 @@ pub(crate) async fn claim_keys_helper( let mut container = BTreeMap::new(); for (device_id, key_algorithm) in map { if let Some(one_time_keys) = - db.users - .take_one_time_key(user_id, device_id, key_algorithm, &db.globals)? + services().users + .take_one_time_key(user_id, device_id, key_algorithm)? { let mut c = BTreeMap::new(); c.insert(one_time_keys.0, one_time_keys.1); @@ -453,10 +427,9 @@ pub(crate) async fn claim_keys_helper( one_time_keys_input_fed.insert(user_id.clone(), keys.clone()); } // Ignore failures - if let Ok(keys) = db + if let Ok(keys) = services() .sending .send_federation_request( - &db.globals, server, federation::keys::claim_keys::v1::Request { one_time_keys: one_time_keys_input_fed, diff --git a/src/api/client_server/media.rs b/src/api/client_server/media.rs index a9a6d6c..f0da084 100644 --- a/src/api/client_server/media.rs +++ b/src/api/client_server/media.rs @@ -1,6 +1,5 @@ use crate::{ - database::{media::FileMeta, DatabaseGuard}, - utils, Error, Result, Ruma, + utils, Error, Result, Ruma, services, service::media::FileMeta, }; use ruma::api::client::{ error::ErrorKind, @@ -16,11 +15,10 @@ const MXC_LENGTH: usize = 32; /// /// Returns max upload size. pub async fn get_media_config_route( - db: DatabaseGuard, _body: Ruma<get_media_config::v3::Request>, ) -> Result<get_media_config::v3::Response> { Ok(get_media_config::v3::Response { - upload_size: db.globals.max_request_size().into(), + upload_size: services().globals.max_request_size().into(), }) } @@ -31,19 +29,17 @@ pub async fn get_media_config_route( /// - Some metadata will be saved in the database /// - Media will be saved in the media/ directory pub async fn create_content_route( - db: DatabaseGuard, body: Ruma<create_content::v3::IncomingRequest>, ) -> Result<create_content::v3::Response> { let mxc = format!( "mxc://{}/{}", - db.globals.server_name(), + services().globals.server_name(), utils::random_string(MXC_LENGTH) ); - db.media + services().media .create( mxc.clone(), - &db.globals, &body .filename .as_ref() @@ -54,8 +50,6 @@ pub async fn create_content_route( ) .await?; - db.flush()?; - Ok(create_content::v3::Response { content_uri: mxc.try_into().expect("Invalid mxc:// URI"), blurhash: None, @@ -63,15 +57,13 @@ pub async fn create_content_route( } pub async fn get_remote_content( - db: &DatabaseGuard, mxc: &str, server_name: &ruma::ServerName, media_id: &str, ) -> Result<get_content::v3::Response, Error> { - let content_response = db + let content_response = services() .sending .send_federation_request( - &db.globals, server_name, get_content::v3::Request { allow_remote: false, @@ -81,10 +73,9 @@ pub async fn get_remote_content( ) .await?; - db.media + services().media .create( mxc.to_string(), - &db.globals, &content_response.content_disposition.as_deref(), &content_response.content_type.as_deref(), &content_response.file, @@ -100,7 +91,6 @@ pub async fn get_remote_content( /// /// - Only allows federation if `allow_remote` is true pub async fn get_content_route( - db: DatabaseGuard, body: Ruma<get_content::v3::IncomingRequest>, ) -> Result<get_content::v3::Response> { let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); @@ -109,16 +99,16 @@ pub async fn get_content_route( content_disposition, content_type, file, - }) = db.media.get(&db.globals, &mxc).await? + }) = services().media.get(&mxc).await? { Ok(get_content::v3::Response { file, content_type, content_disposition, }) - } else if &*body.server_name != db.globals.server_name() && body.allow_remote { + } else if &*body.server_name != services().globals.server_name() && body.allow_remote { let remote_content_response = - get_remote_content(&db, &mxc, &body.server_name, &body.media_id).await?; + get_remote_content(&mxc, &body.server_name, &body.media_id).await?; Ok(remote_content_response) } else { Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")) @@ -131,7 +121,6 @@ pub async fn get_content_route( /// /// - Only allows federation if `allow_remote` is true pub async fn get_content_as_filename_route( - db: DatabaseGuard, body: Ruma<get_content_as_filename::v3::IncomingRequest>, ) -> Result<get_content_as_filename::v3::Response> { let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); @@ -140,16 +129,16 @@ pub async fn get_content_as_filename_route( content_disposition: _, content_type, file, - }) = db.media.get(&db.globals, &mxc).await? + }) = services().media.get(&mxc).await? { Ok(get_content_as_filename::v3::Response { file, content_type, content_disposition: Some(format!("inline; filename={}", body.filename)), }) - } else if &*body.server_name != db.globals.server_name() && body.allow_remote { + } else if &*body.server_name != services().globals.server_name() && body.allow_remote { let remote_content_response = - get_remote_content(&db, &mxc, &body.server_name, &body.media_id).await?; + get_remote_content(&mxc, &body.server_name, &body.media_id).await?; Ok(get_content_as_filename::v3::Response { content_disposition: Some(format!("inline: filename={}", body.filename)), @@ -167,18 +156,16 @@ pub async fn get_content_as_filename_route( /// /// - Only allows federation if `allow_remote` is true pub async fn get_content_thumbnail_route( - db: DatabaseGuard, body: Ruma<get_content_thumbnail::v3::IncomingRequest>, ) -> Result<get_content_thumbnail::v3::Response> { let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); if let Some(FileMeta { content_type, file, .. - }) = db + }) = services() .media .get_thumbnail( &mxc, - &db.globals, body.width .try_into() .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?, @@ -189,11 +176,10 @@ pub async fn get_content_thumbnail_route( .await? { Ok(get_content_thumbnail::v3::Response { file, content_type }) - } else if &*body.server_name != db.globals.server_name() && body.allow_remote { - let get_thumbnail_response = db + } else if &*body.server_name != services().globals.server_name() && body.allow_remote { + let get_thumbnail_response = services() .sending .send_federation_request( - &db.globals, &body.server_name, get_content_thumbnail::v3::Request { allow_remote: false, @@ -206,10 +192,9 @@ pub async fn get_content_thumbnail_route( ) .await?; - db.media + services().media .upload_thumbnail( mxc, - &db.globals, &None, &get_thumbnail_response.content_type, body.width.try_into().expect("all UInts are valid u32s"), diff --git a/src/api/client_server/membership.rs b/src/api/client_server/membership.rs index ecd26d1..b000ec1 100644 --- a/src/api/client_server/membership.rs +++ b/src/api/client_server/membership.rs @@ -1,9 +1,3 @@ -use crate::{ - client_server, - database::DatabaseGuard, - pdu::{EventHash, PduBuilder, PduEvent}, - server_server, utils, Database, Error, Result, Ruma, -}; use ruma::{ api::{ client::{ @@ -29,13 +23,17 @@ use ruma::{ }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use std::{ - collections::{hash_map::Entry, BTreeMap, HashMap}, + collections::{hash_map::Entry, BTreeMap, HashMap, HashSet}, iter, sync::{Arc, RwLock}, time::{Duration, Instant}, }; use tracing::{debug, error, warn}; +use crate::{services, PduEvent, service::pdu::{gen_event_id_canonical_json, PduBuilder}, Error, api::{server_server}, utils, Ruma}; + +use super::get_alias_helper; + /// # `POST /_matrix/client/r0/rooms/{roomId}/join` /// /// Tries to join the sender user into a room. @@ -43,14 +41,13 @@ use tracing::{debug, error, warn}; /// - If the server knowns about this room: creates the join event and does auth rules locally /// - If the server does not know about the room: asks other servers over federation pub async fn join_room_by_id_route( - db: DatabaseGuard, body: Ruma<join_room_by_id::v3::IncomingRequest>, ) -> Result<join_room_by_id::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let mut servers = Vec::new(); // There is no body.server_name for /roomId/join servers.extend( - db.rooms + services().rooms .invite_state(sender_user, &body.room_id)? .unwrap_or_default() .iter() @@ -64,7 +61,6 @@ pub async fn join_room_by_id_route( servers.push(body.room_id.server_name().to_owned()); let ret = join_room_by_id_helper( - &db, body.sender_user.as_deref(), &body.room_id, &servers, @@ -72,8 +68,6 @@ pub async fn join_room_by_id_route( ) .await; - db.flush()?; - ret } @@ -84,7 +78,6 @@ pub async fn join_room_by_id_route( /// - If the server knowns about this room: creates the join event and does auth rules locally /// - If the server does not know about the room: asks other servers over federation pub async fn join_room_by_id_or_alias_route( - db: DatabaseGuard, body: Ruma<join_room_by_id_or_alias::v3::IncomingRequest>, ) -> Result<join_room_by_id_or_alias::v3::Response> { let sender_user = body.sender_user.as_deref().expect("user is authenticated"); @@ -94,7 +87,7 @@ pub async fn join_room_by_id_or_alias_route( Ok(room_id) => { let mut servers = body.server_name.clone(); servers.extend( - db.rooms + services().rooms .invite_state(sender_user, &room_id)? .unwrap_or_default() .iter() @@ -109,14 +102,13 @@ pub async fn join_room_by_id_or_alias_route( (servers, room_id) } Err(room_alias) => { - let response = client_server::get_alias_helper(&db, &room_alias).await?; + let response = get_alias_helper(&room_alias).await?; (response.servers.into_iter().collect(), response.room_id) } }; let join_room_response = join_room_by_id_helper( - &db, Some(sender_user), &room_id, &servers, @@ -124,8 +116,6 @@ pub async fn join_room_by_id_or_alias_route( ) .await?; - db.flush()?; - Ok(join_room_by_id_or_alias::v3::Response { room_id: join_room_response.room_id, }) @@ -137,14 +127,11 @@ pub async fn join_room_by_id_or_alias_route( /// /// - This should always work if the user is currently joined. pub async fn leave_room_route( - db: DatabaseGuard, body: Ruma<leave_room::v3::IncomingRequest>, ) -> Result<leave_room::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - db.rooms.leave_room(sender_user, &body.room_id, &db).await?; - - db.flush()?; + services().rooms.leave_room(sender_user, &body.room_id).await?; Ok(leave_room::v3::Response::new()) } @@ -153,14 +140,12 @@ pub async fn leave_room_route( /// /// Tries to send an invite event into the room. pub async fn invite_user_route( - db: DatabaseGuard, body: Ruma<invite_user::v3::IncomingRequest>, ) -> Result<invite_user::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); if let invite_user::v3::IncomingInvitationRecipient::UserId { user_id } = &body.recipient { - invite_helper(sender_user, user_id, &body.room_id, &db, false).await?; - db.flush()?; + invite_helper(sender_user, user_id, &body.room_id, false).await?; Ok(invite_user::v3::Response {}) } else { Err(Error::BadRequest(ErrorKind::NotFound, "User not found.")) @@ -171,13 +156,12 @@ pub async fn invite_user_route( /// /// Tries to send a kick event into the room. pub async fn kick_user_route( - db: DatabaseGuard, body: Ruma<kick_user::v3::IncomingRequest>, ) -> Result<kick_user::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let mut event: RoomMemberEventContent = serde_json::from_str( - db.rooms + services().rooms .room_state_get( &body.room_id, &StateEventType::RoomMember, @@ -196,7 +180,7 @@ pub async fn kick_user_route( // TODO: reason let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -205,7 +189,7 @@ pub async fn kick_user_route( ); let state_lock = mutex_state.lock().await; - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomMember, content: to_raw_value(&event).expect("event is valid, we just created it"), @@ -215,14 +199,11 @@ pub async fn kick_user_route( }, sender_user, &body.room_id, - &db, &state_lock, )?; drop(state_lock); - db.flush()?; - Ok(kick_user::v3::Response::new()) } @@ -230,14 +211,13 @@ pub async fn kick_user_route( /// /// Tries to send a ban event into the room. pub async fn ban_user_route( - db: DatabaseGuard, body: Ruma<ban_user::v3::IncomingRequest>, ) -> Result<ban_user::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); // TODO: reason - let event = db + let event = services() .rooms .room_state_get( &body.room_id, @@ -247,11 +227,11 @@ pub async fn ban_user_route( .map_or( Ok(RoomMemberEventContent { membership: MembershipState::Ban, - displayname: db.users.displayname(&body.user_id)?, - avatar_url: db.users.avatar_url(&body.user_id)?, + displayname: services().users.displayname(&body.user_id)?, + avatar_url: services().users.avatar_url(&body.user_id)?, is_direct: None, third_party_invite: None, - blurhash: db.users.blurhash(&body.user_id)?, + blurhash: services().users.blurhash(&body.user_id)?, reason: None, join_authorized_via_users_server: None, }), @@ -266,7 +246,7 @@ pub async fn ban_user_route( )?; let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -275,7 +255,7 @@ pub async fn ban_user_route( ); let state_lock = mutex_state.lock().await; - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomMember, content: to_raw_value(&event).expect("event is valid, we just created it"), @@ -285,14 +265,11 @@ pub async fn ban_user_route( }, sender_user, &body.room_id, - &db, &state_lock, )?; drop(state_lock); - db.flush()?; - Ok(ban_user::v3::Response::new()) } @@ -300,13 +277,12 @@ pub async fn ban_user_route( /// /// Tries to send an unban event into the room. pub async fn unban_user_route( - db: DatabaseGuard, body: Ruma<unban_user::v3::IncomingRequest>, ) -> Result<unban_user::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let mut event: RoomMemberEventContent = serde_json::from_str( - db.rooms + services().rooms .room_state_get( &body.room_id, &StateEventType::RoomMember, @@ -324,7 +300,7 @@ pub async fn unban_user_route( event.membership = MembershipState::Leave; let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -333,7 +309,7 @@ pub async fn unban_user_route( ); let state_lock = mutex_state.lock().await; - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomMember, content: to_raw_value(&event).expect("event is valid, we just created it"), @@ -343,14 +319,11 @@ pub async fn unban_user_route( }, sender_user, &body.room_id, - &db, &state_lock, )?; drop(state_lock); - db.flush()?; - Ok(unban_user::v3::Response::new()) } @@ -363,14 +336,11 @@ pub async fn unban_user_route( /// Note: Other devices of the user have no way of knowing the room was forgotten, so this has to /// be called from every device pub async fn forget_room_route( - db: DatabaseGuard, body: Ruma<forget_room::v3::IncomingRequest>, ) -> Result<forget_room::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - db.rooms.forget(&body.room_id, sender_user)?; - - db.flush()?; + services().rooms.forget(&body.room_id, sender_user)?; Ok(forget_room::v3::Response::new()) } @@ -379,13 +349,12 @@ pub async fn forget_room_route( /// /// Lists all rooms the user has joined. pub async fn joined_rooms_route( - db: DatabaseGuard, body: Ruma<joined_rooms::v3::Request>, ) -> Result<joined_rooms::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); Ok(joined_rooms::v3::Response { - joined_rooms: db + joined_rooms: services() .rooms .rooms_joined(sender_user) .filter_map(|r| r.ok()) @@ -399,13 +368,12 @@ pub async fn joined_rooms_route( /// /// - Only works if the user is currently joined pub async fn get_member_events_route( - db: DatabaseGuard, body: Ruma<get_member_events::v3::IncomingRequest>, ) -> Result<get_member_events::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); // TODO: check history visibility? - if !db.rooms.is_joined(sender_user, &body.room_id)? { + if !services().rooms.is_joined(sender_user, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view this room.", @@ -413,7 +381,7 @@ pub async fn get_member_events_route( } Ok(get_member_events::v3::Response { - chunk: db + chunk: services() .rooms .room_state_full(&body.room_id) .await? @@ -431,12 +399,11 @@ pub async fn get_member_events_route( /// - The sender user must be in the room /// - TODO: An appservice just needs a puppet joined pub async fn joined_members_route( - db: DatabaseGuard, body: Ruma<joined_members::v3::IncomingRequest>, ) -> Result<joined_members::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !db.rooms.is_joined(sender_user, &body.room_id)? { + if !services().rooms.is_joined(sender_user, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "You aren't a member of the room.", @@ -444,9 +411,9 @@ pub async fn joined_members_route( } let mut joined = BTreeMap::new(); - for user_id in db.rooms.room_members(&body.room_id).filter_map(|r| r.ok()) { - let display_name = db.users.displayname(&user_id)?; - let avatar_url = db.users.avatar_url(&user_id)?; + for user_id in services().rooms.room_members(&body.room_id).filter_map(|r| r.ok()) { + let display_name = services().users.displayname(&user_id)?; + let avatar_url = services().users.avatar_url(&user_id)?; joined.insert( user_id, @@ -460,9 +427,7 @@ pub async fn joined_members_route( Ok(joined_members::v3::Response { joined }) } -#[tracing::instrument(skip(db))] async fn join_room_by_id_helper( - db: &Database, sender_user: Option<&UserId>, room_id: &RoomId, servers: &[Box<ServerName>], @@ -471,7 +436,7 @@ async fn join_room_by_id_helper( let sender_user = sender_user.expect("user is authenticated"); let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -481,21 +446,20 @@ async fn join_room_by_id_helper( let state_lock = mutex_state.lock().await; // Ask a remote server if we don't have this room - if !db.rooms.exists(room_id)? { + if !services().rooms.exists(room_id)? { let mut make_join_response_and_server = Err(Error::BadServerResponse( "No server available to assist in joining.", )); for remote_server in servers { - let make_join_response = db + let make_join_response = services() .sending .send_federation_request( - &db.globals, remote_server, federation::membership::prepare_join_event::v1::Request { room_id, user_id: sender_user, - ver: &db.globals.supported_room_versions(), + ver: &services().globals.supported_room_versions(), }, ) .await; @@ -510,7 +474,7 @@ async fn join_room_by_id_helper( let (make_join_response, remote_server) = make_join_response_and_server?; let room_version = match make_join_response.room_version { - Some(room_version) if db.rooms.is_supported_version(&db, &room_version) => room_version, + Some(room_version) if services().rooms.is_supported_version(&room_version) => room_version, _ => return Err(Error::BadServerResponse("Room version is not supported")), }; @@ -522,7 +486,7 @@ async fn join_room_by_id_helper( // TODO: Is origin needed? join_event_stub.insert( "origin".to_owned(), - CanonicalJsonValue::String(db.globals.server_name().as_str().to_owned()), + CanonicalJsonValue::String(services().globals.server_name().as_str().to_owned()), ); join_event_stub.insert( "origin_server_ts".to_owned(), @@ -536,11 +500,11 @@ async fn join_room_by_id_helper( "content".to_owned(), to_canonical_value(RoomMemberEventContent { membership: MembershipState::Join, - displayname: db.users.displayname(sender_user)?, - avatar_url: db.users.avatar_url(sender_user)?, + displayname: services().users.displayname(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, is_direct: None, third_party_invite: None, - blurhash: db.users.blurhash(sender_user)?, + blurhash: services().users.blurhash(sender_user)?, reason: None, join_authorized_via_users_server: None, }) @@ -552,8 +516,8 @@ async fn join_room_by_id_helper( // In order to create a compatible ref hash (EventID) the `hashes` field needs to be present ruma::signatures::hash_and_sign_event( - db.globals.server_name().as_str(), - db.globals.keypair(), + services().globals.server_name().as_str(), + services().globals.keypair(), &mut join_event_stub, &room_version, ) @@ -577,10 +541,9 @@ async fn join_room_by_id_helper( // It has enough fields to be called a proper event now let join_event = join_event_stub; - let send_join_response = db + let send_join_response = services() .sending .send_federation_request( - &db.globals, remote_server, federation::membership::create_join_event::v2::Request { room_id, @@ -590,7 +553,7 @@ async fn join_room_by_id_helper( ) .await?; - db.rooms.get_or_create_shortroomid(room_id, &db.globals)?; + services().rooms.get_or_create_shortroomid(room_id, &services().globals)?; let parsed_pdu = PduEvent::from_id_val(event_id, join_event.clone()) .map_err(|_| Error::BadServerResponse("Invalid join event PDU."))?; @@ -602,7 +565,6 @@ async fn join_room_by_id_helper( &send_join_response, &room_version, &pub_key_map, - db, ) .await?; @@ -610,7 +572,7 @@ async fn join_room_by_id_helper( .room_state .state .iter() - .map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map, db)) + .map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map)) { let (event_id, value) = match result { Ok(t) => t, @@ -622,29 +584,27 @@ async fn join_room_by_id_helper( Error::BadServerResponse("Invalid PDU in send_join response.") })?; - db.rooms.add_pdu_outlier(&event_id, &value)?; + services().rooms.add_pdu_outlier(&event_id, &value)?; if let Some(state_key) = &pdu.state_key { - let shortstatekey = db.rooms.get_or_create_shortstatekey( + let shortstatekey = services().rooms.get_or_create_shortstatekey( &pdu.kind.to_string().into(), state_key, - &db.globals, )?; state.insert(shortstatekey, pdu.event_id.clone()); } } - let incoming_shortstatekey = db.rooms.get_or_create_shortstatekey( + let incoming_shortstatekey = services().rooms.get_or_create_shortstatekey( &parsed_pdu.kind.to_string().into(), parsed_pdu .state_key .as_ref() .expect("Pdu is a membership state event"), - &db.globals, )?; state.insert(incoming_shortstatekey, parsed_pdu.event_id.clone()); - let create_shortstatekey = db + let create_shortstatekey = services() .rooms .get_shortstatekey(&StateEventType::RoomCreate, "")? .expect("Room exists"); @@ -653,56 +613,54 @@ async fn join_room_by_id_helper( return Err(Error::BadServerResponse("State contained no create event.")); } - db.rooms.force_state( + services().rooms.force_state( room_id, state .into_iter() - .map(|(k, id)| db.rooms.compress_state_event(k, &id, &db.globals)) + .map(|(k, id)| services().rooms.compress_state_event(k, &id)) .collect::<Result<_>>()?, - db, )?; for result in send_join_response .room_state .auth_chain .iter() - .map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map, db)) + .map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map)) { let (event_id, value) = match result { Ok(t) => t, Err(_) => continue, }; - db.rooms.add_pdu_outlier(&event_id, &value)?; + services().rooms.add_pdu_outlier(&event_id, &value)?; } // We append to state before appending the pdu, so we don't have a moment in time with the // pdu without it's state. This is okay because append_pdu can't fail. - let statehashid = db.rooms.append_to_state(&parsed_pdu, &db.globals)?; + let statehashid = services().rooms.append_to_state(&parsed_pdu)?; - db.rooms.append_pdu( + services().rooms.append_pdu( &parsed_pdu, join_event, iter::once(&*parsed_pdu.event_id), - db, )?; // We set the room state after inserting the pdu, so that we never have a moment in time // where events in the current room state do not exist - db.rooms.set_room_state(room_id, statehashid)?; + services().rooms.set_room_state(room_id, statehashid)?; } else { let event = RoomMemberEventContent { membership: MembershipState::Join, - displayname: db.users.displayname(sender_user)?, - avatar_url: db.users.avatar_url(sender_user)?, + displayname: services().users.displayname(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, is_direct: None, third_party_invite: None, - blurhash: db.users.blurhash(sender_user)?, + blurhash: services().users.blurhash(sender_user)?, reason: None, join_authorized_via_users_server: None, }; - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomMember, content: to_raw_value(&event).expect("event is valid, we just created it"), @@ -712,15 +670,13 @@ async fn join_room_by_id_helper( }, sender_user, room_id, - db, + services(), &state_lock, )?; } drop(state_lock); - db.flush()?; - Ok(join_room_by_id::v3::Response::new(room_id.to_owned())) } @@ -728,7 +684,6 @@ fn validate_and_add_event_id( pdu: &RawJsonValue, room_version: &RoomVersionId, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, - db: &Database, ) -> Result<(Box<EventId>, CanonicalJsonObject)> { let mut value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); @@ -741,14 +696,14 @@ fn validate_and_add_event_id( )) .expect("ruma's reference hashes are valid event ids"); - let back_off = |id| match db.globals.bad_event_ratelimiter.write().unwrap().entry(id) { + let back_off = |id| match services().globals.bad_event_ratelimiter.write().unwrap().entry(id) { Entry::Vacant(e) => { e.insert((Instant::now(), 1)); } Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), }; - if let Some((time, tries)) = db + if let Some((time, tries)) = services() .globals .bad_event_ratelimiter .read() @@ -791,13 +746,12 @@ pub(crate) async fn invite_helper<'a>( sender_user: &UserId, user_id: &UserId, room_id: &RoomId, - db: &Database, is_direct: bool, ) -> Result<()> { - if user_id.server_name() != db.globals.server_name() { - let (room_version_id, pdu_json, invite_room_state) = { + if user_id.server_name() != services().globals.server_name() { + let (pdu_json, invite_room_state) = { let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -818,36 +772,38 @@ pub(crate) async fn invite_helper<'a>( }) .expect("member event is valid value"); - let state_key = user_id.to_string(); - let kind = StateEventType::RoomMember; - - let (pdu, pdu_json) = create_hash_and_sign_event(); + let (pdu, pdu_json) = services().rooms.timeline.create_hash_and_sign_event(PduBuilder { + event_type: RoomEventType::RoomMember, + content, + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + }, sender_user, room_id, &state_lock); - let invite_room_state = db.rooms.calculate_invite_state(&pdu)?; + let invite_room_state = services().rooms.calculate_invite_state(&pdu)?; drop(state_lock); - (room_version_id, pdu_json, invite_room_state) + (pdu_json, invite_room_state) }; // Generate event id let expected_event_id = format!( "${}", - ruma::signatures::reference_hash(&pdu_json, &room_version_id) + ruma::signatures::reference_hash(&pdu_json, &services().rooms.state.get_room_version(&room_id)?) .expect("ruma can calculate reference hashes") ); let expected_event_id = <&EventId>::try_from(expected_event_id.as_str()) .expect("ruma's reference hashes are valid event ids"); - let response = db + let response = services() .sending .send_federation_request( - &db.globals, user_id.server_name(), create_invite::v2::Request { room_id, event_id: expected_event_id, - room_version: &room_version_id, + room_version: &services().state.get_room_version(&room_id)?, event: &PduEvent::convert_to_outgoing_federation_event(pdu_json.clone()), invite_room_state: &invite_room_state, }, @@ -857,7 +813,7 @@ pub(crate) async fn invite_helper<'a>( let pub_key_map = RwLock::new(BTreeMap::new()); // We do not add the event_id field to the pdu here because of signature and hashes checks - let (event_id, value) = match crate::pdu::gen_event_id_canonical_json(&response.event, &db) + let (event_id, value) = match gen_event_id_canonical_json(&response.event) { Ok(t) => t, Err(_) => { @@ -882,13 +838,12 @@ pub(crate) async fn invite_helper<'a>( ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; - let pdu_id = server_server::handle_incoming_pdu( + let pdu_id = services().rooms.event_handler.handle_incoming_pdu( &origin, &event_id, room_id, value, true, - db, &pub_key_map, ) .await @@ -903,18 +858,18 @@ pub(crate) async fn invite_helper<'a>( "Could not accept incoming PDU as timeline event.", ))?; - let servers = db + let servers = services() .rooms .room_servers(room_id) .filter_map(|r| r.ok()) - .filter(|server| &**server != db.globals.server_name()); + .filter(|server| &**server != services().globals.server_name()); - db.sending.send_pdu(servers, &pdu_id)?; + services().sending.send_pdu(servers, &pdu_id)?; return Ok(()); } - if !db.rooms.is_joined(sender_user, &room_id)? { + if !services().rooms.is_joined(sender_user, &room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view this room.", @@ -922,7 +877,7 @@ pub(crate) async fn invite_helper<'a>( } let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -931,16 +886,16 @@ pub(crate) async fn invite_helper<'a>( ); let state_lock = mutex_state.lock().await; - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomMember, content: to_raw_value(&RoomMemberEventContent { membership: MembershipState::Invite, - displayname: db.users.displayname(user_id)?, - avatar_url: db.users.avatar_url(user_id)?, + displayname: services().users.displayname(user_id)?, + avatar_url: services().users.avatar_url(user_id)?, is_direct: Some(is_direct), third_party_invite: None, - blurhash: db.users.blurhash(user_id)?, + blurhash: services().users.blurhash(user_id)?, reason: None, join_authorized_via_users_server: None, }) @@ -951,7 +906,6 @@ pub(crate) async fn invite_helper<'a>( }, sender_user, room_id, - db, &state_lock, )?; @@ -960,208 +914,196 @@ pub(crate) async fn invite_helper<'a>( Ok(()) } - // Make a user leave all their joined rooms - #[tracing::instrument(skip(self, db))] - pub async fn leave_all_rooms(&self, user_id: &UserId, db: &Database) -> Result<()> { - let all_rooms = db - .rooms - .rooms_joined(user_id) - .chain(db.rooms.rooms_invited(user_id).map(|t| t.map(|(r, _)| r))) - .collect::<Vec<_>>(); - - for room_id in all_rooms { - let room_id = match room_id { - Ok(room_id) => room_id, - Err(_) => continue, - }; - - let _ = self.leave_room(user_id, &room_id, db).await; - } +// Make a user leave all their joined rooms +pub async fn leave_all_rooms(user_id: &UserId) -> Result<()> { + let all_rooms = services() + .rooms + .rooms_joined(user_id) + .chain(services().rooms.rooms_invited(user_id).map(|t| t.map(|(r, _)| r))) + .collect::<Vec<_>>(); + + for room_id in all_rooms { + let room_id = match room_id { + Ok(room_id) => room_id, + Err(_) => continue, + }; - Ok(()) + let _ = leave_room(user_id, &room_id).await; } - #[tracing::instrument(skip(self, db))] - pub async fn leave_room( - &self, - user_id: &UserId, - room_id: &RoomId, - db: &Database, - ) -> Result<()> { - // Ask a remote server if we don't have this room - if !self.exists(room_id)? && room_id.server_name() != db.globals.server_name() { - if let Err(e) = self.remote_leave_room(user_id, room_id, db).await { - warn!("Failed to leave room {} remotely: {}", user_id, e); - // Don't tell the client about this error - } - - let last_state = self - .invite_state(user_id, room_id)? - .map_or_else(|| self.left_state(user_id, room_id), |s| Ok(Some(s)))?; + Ok(()) +} - // We always drop the invite, we can't rely on other servers - self.update_membership( - room_id, - user_id, - MembershipState::Leave, - user_id, - last_state, - db, - true, - )?; - } else { - let mutex_state = Arc::clone( - db.globals - .roomid_mutex_state - .write() - .unwrap() - .entry(room_id.to_owned()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; +pub async fn leave_room( + user_id: &UserId, + room_id: &RoomId, +) -> Result<()> { + // Ask a remote server if we don't have this room + if !services().rooms.metadata.exists(room_id)? && room_id.server_name() != services().globals.server_name() { + if let Err(e) = remote_leave_room(user_id, room_id).await { + warn!("Failed to leave room {} remotely: {}", user_id, e); + // Don't tell the client about this error + } - let mut event: RoomMemberEventContent = serde_json::from_str( - self.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())? - .ok_or(Error::BadRequest( - ErrorKind::BadState, - "Cannot leave a room you are not a member of.", - ))? - .content - .get(), - ) - .map_err(|_| Error::bad_database("Invalid member event in database."))?; + let last_state = services().rooms.state_cache + .invite_state(user_id, room_id)? + .map_or_else(|| services().rooms.left_state(user_id, room_id), |s| Ok(Some(s)))?; - event.membership = MembershipState::Leave; + // We always drop the invite, we can't rely on other servers + services().rooms.state_cache.update_membership( + room_id, + user_id, + MembershipState::Leave, + user_id, + last_state, + true, + )?; + } else { + let mutex_state = Arc::clone( + services().globals + .roomid_mutex_state + .write() + .unwrap() + .entry(room_id.to_owned()) + .or_default(), + ); + let state_lock = mutex_state.lock().await; + + let mut event: RoomMemberEventContent = serde_json::from_str( + services().rooms.state.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())? + .ok_or(Error::BadRequest( + ErrorKind::BadState, + "Cannot leave a room you are not a member of.", + ))? + .content + .get(), + ) + .map_err(|_| Error::bad_database("Invalid member event in database."))?; - self.build_and_append_pdu( - PduBuilder { - event_type: RoomEventType::RoomMember, - content: to_raw_value(&event).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - }, - user_id, - room_id, - db, - &state_lock, - )?; - } + event.membership = MembershipState::Leave; - Ok(()) + services().rooms.timeline.build_and_append_pdu( + PduBuilder { + event_type: RoomEventType::RoomMember, + content: to_raw_value(&event).expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + }, + user_id, + room_id, + &state_lock, + )?; } - #[tracing::instrument(skip(self, db))] - async fn remote_leave_room( - &self, - user_id: &UserId, - room_id: &RoomId, - db: &Database, - ) -> Result<()> { - let mut make_leave_response_and_server = Err(Error::BadServerResponse( - "No server available to assist in leaving.", - )); + Ok(()) +} - let invite_state = db - .rooms - .invite_state(user_id, room_id)? - .ok_or(Error::BadRequest( - ErrorKind::BadState, - "User is not invited.", - ))?; +async fn remote_leave_room( + user_id: &UserId, + room_id: &RoomId, +) -> Result<()> { + let mut make_leave_response_and_server = Err(Error::BadServerResponse( + "No server available to assist in leaving.", + )); - let servers: HashSet<_> = invite_state - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()) - .collect(); + let invite_state = services() + .rooms + .invite_state(user_id, room_id)? + .ok_or(Error::BadRequest( + ErrorKind::BadState, + "User is not invited.", + ))?; - for remote_server in servers { - let make_leave_response = db - .sending - .send_federation_request( - &db.globals, - &remote_server, - federation::membership::prepare_leave_event::v1::Request { room_id, user_id }, - ) - .await; + let servers: HashSet<_> = invite_state + .iter() + .filter_map(|event| serde_json::from_str(event.json().get()).ok()) + .filter_map(|event: serde_json::Value| event.get("sender").cloned()) + .filter_map(|sender| sender.as_str().map(|s| s.to_owned())) + .filter_map(|sender| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()) + .collect(); + + for remote_server in servers { + let make_leave_response = services() + .sending + .send_federation_request( + &remote_server, + federation::membership::prepare_leave_event::v1::Request { room_id, user_id }, + ) + .await; - make_leave_response_and_server = make_leave_response.map(|r| (r, remote_server)); + make_leave_response_and_server = make_leave_response.map(|r| (r, remote_server)); - if make_leave_response_and_server.is_ok() { - break; - } + if make_leave_response_and_server.is_ok() { + break; } + } - let (make_leave_response, remote_server) = make_leave_response_and_server?; - - let room_version_id = match make_leave_response.room_version { - Some(version) if self.is_supported_version(&db, &version) => version, - _ => return Err(Error::BadServerResponse("Room version is not supported")), - }; - - let mut leave_event_stub = - serde_json::from_str::<CanonicalJsonObject>(make_leave_response.event.get()).map_err( - |_| Error::BadServerResponse("Invalid make_leave event json received from server."), - )?; + let (make_leave_response, remote_server) = make_leave_response_and_server?; - // TODO: Is origin needed? - leave_event_stub.insert( - "origin".to_owned(), - CanonicalJsonValue::String(db.globals.server_name().as_str().to_owned()), - ); - leave_event_stub.insert( - "origin_server_ts".to_owned(), - CanonicalJsonValue::Integer( - utils::millis_since_unix_epoch() - .try_into() - .expect("Timestamp is valid js_int value"), - ), - ); - // We don't leave the event id in the pdu because that's only allowed in v1 or v2 rooms - leave_event_stub.remove("event_id"); + let room_version_id = match make_leave_response.room_version { + Some(version) if services().rooms.is_supported_version(&version) => version, + _ => return Err(Error::BadServerResponse("Room version is not supported")), + }; - // In order to create a compatible ref hash (EventID) the `hashes` field needs to be present - ruma::signatures::hash_and_sign_event( - db.globals.server_name().as_str(), - db.globals.keypair(), - &mut leave_event_stub, - &room_version_id, - ) - .expect("event is valid, we just created it"); + let mut leave_event_stub = + serde_json::from_str::<CanonicalJsonObject>(make_leave_response.event.get()).map_err( + |_| Error::BadServerResponse("Invalid make_leave event json received from server."), + )?; - // Generate event id - let event_id = EventId::parse(format!( - "${}", - ruma::signatures::reference_hash(&leave_event_stub, &room_version_id) - .expect("ruma can calculate reference hashes") - )) - .expect("ruma's reference hashes are valid event ids"); + // TODO: Is origin needed? + leave_event_stub.insert( + "origin".to_owned(), + CanonicalJsonValue::String(services().globals.server_name().as_str().to_owned()), + ); + leave_event_stub.insert( + "origin_server_ts".to_owned(), + CanonicalJsonValue::Integer( + utils::millis_since_unix_epoch() + .try_into() + .expect("Timestamp is valid js_int value"), + ), + ); + // We don't leave the event id in the pdu because that's only allowed in v1 or v2 rooms + leave_event_stub.remove("event_id"); + + // In order to create a compatible ref hash (EventID) the `hashes` field needs to be present + ruma::signatures::hash_and_sign_event( + services().globals.server_name().as_str(), + services().globals.keypair(), + &mut leave_event_stub, + &room_version_id, + ) + .expect("event is valid, we just created it"); - // Add event_id back - leave_event_stub.insert( - "event_id".to_owned(), - CanonicalJsonValue::String(event_id.as_str().to_owned()), - ); + // Generate event id + let event_id = EventId::parse(format!( + "${}", + ruma::signatures::reference_hash(&leave_event_stub, &room_version_id) + .expect("ruma can calculate reference hashes") + )) + .expect("ruma's reference hashes are valid event ids"); - // It has enough fields to be called a proper event now - let leave_event = leave_event_stub; + // Add event_id back + leave_event_stub.insert( + "event_id".to_owned(), + CanonicalJsonValue::String(event_id.as_str().to_owned()), + ); - db.sending - .send_federation_request( - &db.globals, - &remote_server, - federation::membership::create_leave_event::v2::Request { - room_id, - event_id: &event_id, - pdu: &PduEvent::convert_to_outgoing_federation_event(leave_event.clone()), - }, - ) - .await?; + // It has enough fields to be called a proper event now + let leave_event = leave_event_stub; - Ok(()) - } + services().sending + .send_federation_request( + &remote_server, + federation::membership::create_leave_event::v2::Request { + room_id, + event_id: &event_id, + pdu: &PduEvent::convert_to_outgoing_federation_event(leave_event.clone()), + }, + ) + .await?; + Ok(()) +} diff --git a/src/api/client_server/message.rs b/src/api/client_server/message.rs index 1348132..861f9c1 100644 --- a/src/api/client_server/message.rs +++ b/src/api/client_server/message.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, pdu::PduBuilder, utils, Error, Result, Ruma}; +use crate::{utils, Error, Result, Ruma, services, service::pdu::PduBuilder}; use ruma::{ api::client::{ error::ErrorKind, @@ -19,14 +19,13 @@ use std::{ /// - The only requirement for the content is that it has to be valid json /// - Tries to send the event into the room, auth rules will determine if it is allowed pub async fn send_message_event_route( - db: DatabaseGuard, body: Ruma<send_message_event::v3::IncomingRequest>, ) -> Result<send_message_event::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_deref(); let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -37,7 +36,7 @@ pub async fn send_message_event_route( // Forbid m.room.encrypted if encryption is disabled if RoomEventType::RoomEncrypted == body.event_type.to_string().into() - && !db.globals.allow_encryption() + && !services().globals.allow_encryption() { return Err(Error::BadRequest( ErrorKind::Forbidden, @@ -47,7 +46,7 @@ pub async fn send_message_event_route( // Check if this is a new transaction id if let Some(response) = - db.transaction_ids + services().transaction_ids .existing_txnid(sender_user, sender_device, &body.txn_id)? { // The client might have sent a txnid of the /sendToDevice endpoint @@ -69,7 +68,7 @@ pub async fn send_message_event_route( let mut unsigned = BTreeMap::new(); unsigned.insert("transaction_id".to_owned(), body.txn_id.to_string().into()); - let event_id = db.rooms.build_and_append_pdu( + let event_id = services().rooms.build_and_append_pdu( PduBuilder { event_type: body.event_type.to_string().into(), content: serde_json::from_str(body.body.body.json().get()) @@ -80,11 +79,10 @@ pub async fn send_message_event_route( }, sender_user, &body.room_id, - &db, &state_lock, )?; - db.transaction_ids.add_txnid( + services().transaction_ids.add_txnid( sender_user, sender_device, &body.txn_id, @@ -93,8 +91,6 @@ pub async fn send_message_event_route( drop(state_lock); - db.flush()?; - Ok(send_message_event::v3::Response::new( (*event_id).to_owned(), )) @@ -107,13 +103,12 @@ pub async fn send_message_event_route( /// - Only works if the user is joined (TODO: always allow, but only show events where the user was /// joined, depending on history_visibility) pub async fn get_message_events_route( - db: DatabaseGuard, body: Ruma<get_message_events::v3::IncomingRequest>, ) -> Result<get_message_events::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); - if !db.rooms.is_joined(sender_user, &body.room_id)? { + if !services().rooms.is_joined(sender_user, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view this room.", @@ -133,7 +128,7 @@ pub async fn get_message_events_route( let to = body.to.as_ref().map(|t| t.parse()); - db.rooms + services().rooms .lazy_load_confirm_delivery(sender_user, sender_device, &body.room_id, from)?; // Use limit or else 10 @@ -147,13 +142,13 @@ pub async fn get_message_events_route( match body.dir { get_message_events::v3::Direction::Forward => { - let events_after: Vec<_> = db + let events_after: Vec<_> = services() .rooms .pdus_after(sender_user, &body.room_id, from)? .take(limit) .filter_map(|r| r.ok()) // Filter out buggy events .filter_map(|(pdu_id, pdu)| { - db.rooms + services().rooms .pdu_count(&pdu_id) .map(|pdu_count| (pdu_count, pdu)) .ok() @@ -162,7 +157,7 @@ pub async fn get_message_events_route( .collect(); for (_, event) in &events_after { - if !db.rooms.lazy_load_was_sent_before( + if !services().rooms.lazy_load_was_sent_before( sender_user, sender_device, &body.room_id, @@ -184,13 +179,13 @@ pub async fn get_message_events_route( resp.chunk = events_after; } get_message_events::v3::Direction::Backward => { - let events_before: Vec<_> = db + let events_before: Vec<_> = services() .rooms .pdus_until(sender_user, &body.room_id, from)? .take(limit) .filter_map(|r| r.ok()) // Filter out buggy events .filter_map(|(pdu_id, pdu)| { - db.rooms + services().rooms .pdu_count(&pdu_id) .map(|pdu_count| (pdu_count, pdu)) .ok() @@ -199,7 +194,7 @@ pub async fn get_message_events_route( .collect(); for (_, event) in &events_before { - if !db.rooms.lazy_load_was_sent_before( + if !services().rooms.lazy_load_was_sent_before( sender_user, sender_device, &body.room_id, @@ -225,7 +220,7 @@ pub async fn get_message_events_route( resp.state = Vec::new(); for ll_id in &lazy_loaded { if let Some(member_event) = - db.rooms + services().rooms .room_state_get(&body.room_id, &StateEventType::RoomMember, ll_id.as_str())? { resp.state.push(member_event.to_state_event()); @@ -233,7 +228,7 @@ pub async fn get_message_events_route( } if let Some(next_token) = next_token { - db.rooms.lazy_load_mark_sent( + services().rooms.lazy_load_mark_sent( sender_user, sender_device, &body.room_id, diff --git a/src/api/client_server/presence.rs b/src/api/client_server/presence.rs index 773fef4..bc220b8 100644 --- a/src/api/client_server/presence.rs +++ b/src/api/client_server/presence.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, utils, Result, Ruma}; +use crate::{utils, Result, Ruma, services}; use ruma::api::client::presence::{get_presence, set_presence}; use std::time::Duration; @@ -6,22 +6,21 @@ use std::time::Duration; /// /// Sets the presence state of the sender user. pub async fn set_presence_route( - db: DatabaseGuard, body: Ruma<set_presence::v3::IncomingRequest>, ) -> Result<set_presence::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - for room_id in db.rooms.rooms_joined(sender_user) { + for room_id in services().rooms.rooms_joined(sender_user) { let room_id = room_id?; - db.rooms.edus.update_presence( + services().rooms.edus.update_presence( sender_user, &room_id, ruma::events::presence::PresenceEvent { content: ruma::events::presence::PresenceEventContent { - avatar_url: db.users.avatar_url(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, currently_active: None, - displayname: db.users.displayname(sender_user)?, + displayname: services().users.displayname(sender_user)?, last_active_ago: Some( utils::millis_since_unix_epoch() .try_into() @@ -32,12 +31,9 @@ pub async fn set_presence_route( }, sender: sender_user.clone(), }, - &db.globals, )?; } - db.flush()?; - Ok(set_presence::v3::Response {}) } @@ -47,20 +43,19 @@ pub async fn set_presence_route( /// /// - Only works if you share a room with the user pub async fn get_presence_route( - db: DatabaseGuard, body: Ruma<get_presence::v3::IncomingRequest>, ) -> Result<get_presence::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let mut presence_event = None; - for room_id in db + for room_id in services() .rooms .get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])? { let room_id = room_id?; - if let Some(presence) = db + if let Some(presence) = services() .rooms .edus .get_last_presence_event(sender_user, &room_id)? diff --git a/src/api/client_server/profile.rs b/src/api/client_server/profile.rs index acea19f..7a87bcd 100644 --- a/src/api/client_server/profile.rs +++ b/src/api/client_server/profile.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, pdu::PduBuilder, utils, Error, Result, Ruma}; +use crate::{utils, Error, Result, Ruma, services, service::pdu::PduBuilder}; use ruma::{ api::{ client::{ @@ -20,16 +20,15 @@ use std::sync::Arc; /// /// - Also makes sure other users receive the update using presence EDUs pub async fn set_displayname_route( - db: DatabaseGuard, body: Ruma<set_display_name::v3::IncomingRequest>, ) -> Result<set_display_name::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - db.users + services().users .set_displayname(sender_user, body.displayname.clone())?; // Send a new membership event and presence update into all joined rooms - let all_rooms_joined: Vec<_> = db + let all_rooms_joined: Vec<_> = services() .rooms .rooms_joined(sender_user) .filter_map(|r| r.ok()) @@ -40,7 +39,7 @@ pub async fn set_displayname_route( content: to_raw_value(&RoomMemberEventContent { displayname: body.displayname.clone(), ..serde_json::from_str( - db.rooms + services().rooms .room_state_get( &room_id, &StateEventType::RoomMember, @@ -70,7 +69,7 @@ pub async fn set_displayname_route( for (pdu_builder, room_id) in all_rooms_joined { let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -79,19 +78,19 @@ pub async fn set_displayname_route( ); let state_lock = mutex_state.lock().await; - let _ = db + let _ = services() .rooms - .build_and_append_pdu(pdu_builder, sender_user, &room_id, &db, &state_lock); + .build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock); // Presence update - db.rooms.edus.update_presence( + services().rooms.edus.update_presence( sender_user, &room_id, ruma::events::presence::PresenceEvent { content: ruma::events::presence::PresenceEventContent { - avatar_url: db.users.avatar_url(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, currently_active: None, - displayname: db.users.displayname(sender_user)?, + displayname: services().users.displayname(sender_user)?, last_active_ago: Some( utils::millis_since_unix_epoch() .try_into() @@ -102,12 +101,9 @@ pub async fn set_displayname_route( }, sender: sender_user.clone(), }, - &db.globals, )?; } - db.flush()?; - Ok(set_display_name::v3::Response {}) } @@ -117,14 +113,12 @@ pub async fn set_displayname_route( /// /// - If user is on another server: Fetches displayname over federation pub async fn get_displayname_route( - db: DatabaseGuard, body: Ruma<get_display_name::v3::IncomingRequest>, ) -> Result<get_display_name::v3::Response> { - if body.user_id.server_name() != db.globals.server_name() { - let response = db + if body.user_id.server_name() != services().globals.server_name() { + let response = services() .sending .send_federation_request( - &db.globals, body.user_id.server_name(), federation::query::get_profile_information::v1::Request { user_id: &body.user_id, @@ -139,7 +133,7 @@ pub async fn get_displayname_route( } Ok(get_display_name::v3::Response { - displayname: db.users.displayname(&body.user_id)?, + displayname: services().users.displayname(&body.user_id)?, }) } @@ -149,18 +143,17 @@ pub async fn get_displayname_route( /// /// - Also makes sure other users receive the update using presence EDUs pub async fn set_avatar_url_route( - db: DatabaseGuard, body: Ruma<set_avatar_url::v3::IncomingRequest>, ) -> Result<set_avatar_url::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - db.users + services().users .set_avatar_url(sender_user, body.avatar_url.clone())?; - db.users.set_blurhash(sender_user, body.blurhash.clone())?; + services().users.set_blurhash(sender_user, body.blurhash.clone())?; // Send a new membership event and presence update into all joined rooms - let all_joined_rooms: Vec<_> = db + let all_joined_rooms: Vec<_> = services() .rooms .rooms_joined(sender_user) .filter_map(|r| r.ok()) @@ -171,7 +164,7 @@ pub async fn set_avatar_url_route( content: to_raw_value(&RoomMemberEventContent { avatar_url: body.avatar_url.clone(), ..serde_json::from_str( - db.rooms + services().rooms .room_state_get( &room_id, &StateEventType::RoomMember, @@ -201,7 +194,7 @@ pub async fn set_avatar_url_route( for (pdu_builder, room_id) in all_joined_rooms { let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -210,19 +203,19 @@ pub async fn set_avatar_url_route( ); let state_lock = mutex_state.lock().await; - let _ = db + let _ = services() .rooms - .build_and_append_pdu(pdu_builder, sender_user, &room_id, &db, &state_lock); + .build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock); // Presence update - db.rooms.edus.update_presence( + services().rooms.edus.update_presence( sender_user, &room_id, ruma::events::presence::PresenceEvent { content: ruma::events::presence::PresenceEventContent { - avatar_url: db.users.avatar_url(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, currently_active: None, - displayname: db.users.displayname(sender_user)?, + displayname: services().users.displayname(sender_user)?, last_active_ago: Some( utils::millis_since_unix_epoch() .try_into() @@ -233,12 +226,10 @@ pub async fn set_avatar_url_route( }, sender: sender_user.clone(), }, - &db.globals, + &services().globals, )?; } - db.flush()?; - Ok(set_avatar_url::v3::Response {}) } @@ -248,14 +239,12 @@ pub async fn set_avatar_url_route( /// /// - If user is on another server: Fetches avatar_url and blurhash over federation pub async fn get_avatar_url_route( - db: DatabaseGuard, body: Ruma<get_avatar_url::v3::IncomingRequest>, ) -> Result<get_avatar_url::v3::Response> { - if body.user_id.server_name() != db.globals.server_name() { - let response = db + if body.user_id.server_name() != services().globals.server_name() { + let response = services() .sending .send_federation_request( - &db.globals, body.user_id.server_name(), federation::query::get_profile_information::v1::Request { user_id: &body.user_id, @@ -271,8 +260,8 @@ pub async fn get_avatar_url_route( } Ok(get_avatar_url::v3::Response { - avatar_url: db.users.avatar_url(&body.user_id)?, - blurhash: db.users.blurhash(&body.user_id)?, + avatar_url: services().users.avatar_url(&body.user_id)?, + blurhash: services().users.blurhash(&body.user_id)?, }) } @@ -282,14 +271,12 @@ pub async fn get_avatar_url_route( /// /// - If user is on another server: Fetches profile over federation pub async fn get_profile_route( - db: DatabaseGuard, body: Ruma<get_profile::v3::IncomingRequest>, ) -> Result<get_profile::v3::Response> { - if body.user_id.server_name() != db.globals.server_name() { - let response = db + if body.user_id.server_name() != services().globals.server_name() { + let response = services() .sending .send_federation_request( - &db.globals, body.user_id.server_name(), federation::query::get_profile_information::v1::Request { user_id: &body.user_id, @@ -305,7 +292,7 @@ pub async fn get_profile_route( }); } - if !db.users.exists(&body.user_id)? { + if !services().users.exists(&body.user_id)? { // Return 404 if this user doesn't exist return Err(Error::BadRequest( ErrorKind::NotFound, @@ -314,8 +301,8 @@ pub async fn get_profile_route( } Ok(get_profile::v3::Response { - avatar_url: db.users.avatar_url(&body.user_id)?, - blurhash: db.users.blurhash(&body.user_id)?, - displayname: db.users.displayname(&body.user_id)?, + avatar_url: services().users.avatar_url(&body.user_id)?, + blurhash: services().users.blurhash(&body.user_id)?, + displayname: services().users.displayname(&body.user_id)?, }) } diff --git a/src/api/client_server/push.rs b/src/api/client_server/push.rs index dc45ea0..112fa00 100644 --- a/src/api/client_server/push.rs +++ b/src/api/client_server/push.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Error, Result, Ruma}; +use crate::{Error, Result, Ruma, services}; use ruma::{ api::client::{ error::ErrorKind, @@ -16,12 +16,11 @@ use ruma::{ /// /// Retrieves the push rules event for this user. pub async fn get_pushrules_all_route( - db: DatabaseGuard, body: Ruma<get_pushrules_all::v3::Request>, ) -> Result<get_pushrules_all::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event: PushRulesEvent = db + let event: PushRulesEvent = services() .account_data .get( None, @@ -42,12 +41,11 @@ pub async fn get_pushrules_all_route( /// /// Retrieves a single specified push rule for this user. pub async fn get_pushrule_route( - db: DatabaseGuard, body: Ruma<get_pushrule::v3::IncomingRequest>, ) -> Result<get_pushrule::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event: PushRulesEvent = db + let event: PushRulesEvent = services() .account_data .get( None, @@ -98,7 +96,6 @@ pub async fn get_pushrule_route( /// /// Creates a single specified push rule for this user. pub async fn set_pushrule_route( - db: DatabaseGuard, body: Ruma<set_pushrule::v3::IncomingRequest>, ) -> Result<set_pushrule::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -111,7 +108,7 @@ pub async fn set_pushrule_route( )); } - let mut event: PushRulesEvent = db + let mut event: PushRulesEvent = services() .account_data .get( None, @@ -186,16 +183,13 @@ pub async fn set_pushrule_route( _ => {} } - db.account_data.update( + services().account_data.update( None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into(), &event, - &db.globals, )?; - db.flush()?; - Ok(set_pushrule::v3::Response {}) } @@ -203,7 +197,6 @@ pub async fn set_pushrule_route( /// /// Gets the actions of a single specified push rule for this user. pub async fn get_pushrule_actions_route( - db: DatabaseGuard, body: Ruma<get_pushrule_actions::v3::IncomingRequest>, ) -> Result<get_pushrule_actions::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -215,7 +208,7 @@ pub async fn get_pushrule_actions_route( )); } - let mut event: PushRulesEvent = db + let mut event: PushRulesEvent = services() .account_data .get( None, @@ -252,8 +245,6 @@ pub async fn get_pushrule_actions_route( _ => None, }; - db.flush()?; - Ok(get_pushrule_actions::v3::Response { actions: actions.unwrap_or_default(), }) @@ -263,7 +254,6 @@ pub async fn get_pushrule_actions_route( /// /// Sets the actions of a single specified push rule for this user. pub async fn set_pushrule_actions_route( - db: DatabaseGuard, body: Ruma<set_pushrule_actions::v3::IncomingRequest>, ) -> Result<set_pushrule_actions::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -275,7 +265,7 @@ pub async fn set_pushrule_actions_route( )); } - let mut event: PushRulesEvent = db + let mut event: PushRulesEvent = services() .account_data .get( None, @@ -322,16 +312,13 @@ pub async fn set_pushrule_actions_route( _ => {} }; - db.account_data.update( + services().account_data.update( None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into(), &event, - &db.globals, )?; - db.flush()?; - Ok(set_pushrule_actions::v3::Response {}) } @@ -339,7 +326,6 @@ pub async fn set_pushrule_actions_route( /// /// Gets the enabled status of a single specified push rule for this user. pub async fn get_pushrule_enabled_route( - db: DatabaseGuard, body: Ruma<get_pushrule_enabled::v3::IncomingRequest>, ) -> Result<get_pushrule_enabled::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -351,7 +337,7 @@ pub async fn get_pushrule_enabled_route( )); } - let mut event: PushRulesEvent = db + let mut event: PushRulesEvent = services() .account_data .get( None, @@ -393,8 +379,6 @@ pub async fn get_pushrule_enabled_route( _ => false, }; - db.flush()?; - Ok(get_pushrule_enabled::v3::Response { enabled }) } @@ -402,7 +386,6 @@ pub async fn get_pushrule_enabled_route( /// /// Sets the enabled status of a single specified push rule for this user. pub async fn set_pushrule_enabled_route( - db: DatabaseGuard, body: Ruma<set_pushrule_enabled::v3::IncomingRequest>, ) -> Result<set_pushrule_enabled::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -414,7 +397,7 @@ pub async fn set_pushrule_enabled_route( )); } - let mut event: PushRulesEvent = db + let mut event: PushRulesEvent = services() .account_data .get( None, @@ -466,16 +449,13 @@ pub async fn set_pushrule_enabled_route( _ => {} } - db.account_data.update( + services().account_data.update( None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into(), &event, - &db.globals, )?; - db.flush()?; - Ok(set_pushrule_enabled::v3::Response {}) } @@ -483,7 +463,6 @@ pub async fn set_pushrule_enabled_route( /// /// Deletes a single specified push rule for this user. pub async fn delete_pushrule_route( - db: DatabaseGuard, body: Ruma<delete_pushrule::v3::IncomingRequest>, ) -> Result<delete_pushrule::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -495,7 +474,7 @@ pub async fn delete_pushrule_route( )); } - let mut event: PushRulesEvent = db + let mut event: PushRulesEvent = services() .account_data .get( None, @@ -537,16 +516,13 @@ pub async fn delete_pushrule_route( _ => {} } - db.account_data.update( + services().account_data.update( None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into(), &event, - &db.globals, )?; - db.flush()?; - Ok(delete_pushrule::v3::Response {}) } @@ -554,13 +530,12 @@ pub async fn delete_pushrule_route( /// /// Gets all currently active pushers for the sender user. pub async fn get_pushers_route( - db: DatabaseGuard, body: Ruma<get_pushers::v3::Request>, ) -> Result<get_pushers::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); Ok(get_pushers::v3::Response { - pushers: db.pusher.get_pushers(sender_user)?, + pushers: services().pusher.get_pushers(sender_user)?, }) } @@ -570,15 +545,12 @@ pub async fn get_pushers_route( /// /// - TODO: Handle `append` pub async fn set_pushers_route( - db: DatabaseGuard, body: Ruma<set_pusher::v3::Request>, ) -> Result<set_pusher::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let pusher = body.pusher.clone(); - db.pusher.set_pusher(sender_user, pusher)?; - - db.flush()?; + services().pusher.set_pusher(sender_user, pusher)?; Ok(set_pusher::v3::Response::default()) } diff --git a/src/api/client_server/read_marker.rs b/src/api/client_server/read_marker.rs index 91988a4..284ae65 100644 --- a/src/api/client_server/read_marker.rs +++ b/src/api/client_server/read_marker.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Error, Result, Ruma}; +use crate::{Error, Result, Ruma, services}; use ruma::{ api::client::{error::ErrorKind, read_marker::set_read_marker, receipt::create_receipt}, events::RoomAccountDataEventType, @@ -14,7 +14,6 @@ use std::collections::BTreeMap; /// - Updates fully-read account data event to `fully_read` /// - If `read_receipt` is set: Update private marker and public read receipt EDU pub async fn set_read_marker_route( - db: DatabaseGuard, body: Ruma<set_read_marker::v3::IncomingRequest>, ) -> Result<set_read_marker::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -24,25 +23,23 @@ pub async fn set_read_marker_route( event_id: body.fully_read.clone(), }, }; - db.account_data.update( + services().account_data.update( Some(&body.room_id), sender_user, RoomAccountDataEventType::FullyRead, &fully_read_event, - &db.globals, )?; if let Some(event) = &body.read_receipt { - db.rooms.edus.private_read_set( + services().rooms.edus.private_read_set( &body.room_id, sender_user, - db.rooms.get_pdu_count(event)?.ok_or(Error::BadRequest( + services().rooms.get_pdu_count(event)?.ok_or(Error::BadRequest( ErrorKind::InvalidParam, "Event does not exist.", ))?, - &db.globals, )?; - db.rooms + services().rooms .reset_notification_counts(sender_user, &body.room_id)?; let mut user_receipts = BTreeMap::new(); @@ -59,19 +56,16 @@ pub async fn set_read_marker_route( let mut receipt_content = BTreeMap::new(); receipt_content.insert(event.to_owned(), receipts); - db.rooms.edus.readreceipt_update( + services().rooms.edus.readreceipt_update( sender_user, &body.room_id, ruma::events::receipt::ReceiptEvent { content: ruma::events::receipt::ReceiptEventContent(receipt_content), room_id: body.room_id.clone(), }, - &db.globals, )?; } - db.flush()?; - Ok(set_read_marker::v3::Response {}) } @@ -79,23 +73,21 @@ pub async fn set_read_marker_route( /// /// Sets private read marker and public read receipt EDU. pub async fn create_receipt_route( - db: DatabaseGuard, body: Ruma<create_receipt::v3::IncomingRequest>, ) -> Result<create_receipt::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - db.rooms.edus.private_read_set( + services().rooms.edus.private_read_set( &body.room_id, sender_user, - db.rooms + services().rooms .get_pdu_count(&body.event_id)? .ok_or(Error::BadRequest( ErrorKind::InvalidParam, "Event does not exist.", ))?, - &db.globals, )?; - db.rooms + services().rooms .reset_notification_counts(sender_user, &body.room_id)?; let mut user_receipts = BTreeMap::new(); @@ -111,17 +103,16 @@ pub async fn create_receipt_route( let mut receipt_content = BTreeMap::new(); receipt_content.insert(body.event_id.to_owned(), receipts); - db.rooms.edus.readreceipt_update( + services().rooms.edus.readreceipt_update( sender_user, &body.room_id, ruma::events::receipt::ReceiptEvent { content: ruma::events::receipt::ReceiptEventContent(receipt_content), room_id: body.room_id.clone(), }, - &db.globals, )?; - db.flush()?; + services().flush()?; Ok(create_receipt::v3::Response {}) } diff --git a/src/api/client_server/redact.rs b/src/api/client_server/redact.rs index 059e0f5..d6699bc 100644 --- a/src/api/client_server/redact.rs +++ b/src/api/client_server/redact.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use crate::{database::DatabaseGuard, pdu::PduBuilder, Result, Ruma}; +use crate::{Result, Ruma, services, service::pdu::PduBuilder}; use ruma::{ api::client::redact::redact_event, events::{room::redaction::RoomRedactionEventContent, RoomEventType}, @@ -14,14 +14,13 @@ use serde_json::value::to_raw_value; /// /// - TODO: Handle txn id pub async fn redact_event_route( - db: DatabaseGuard, body: Ruma<redact_event::v3::IncomingRequest>, ) -> Result<redact_event::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let body = body.body; let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -30,7 +29,7 @@ pub async fn redact_event_route( ); let state_lock = mutex_state.lock().await; - let event_id = db.rooms.build_and_append_pdu( + let event_id = services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomRedaction, content: to_raw_value(&RoomRedactionEventContent { @@ -43,14 +42,11 @@ pub async fn redact_event_route( }, sender_user, &body.room_id, - &db, &state_lock, )?; drop(state_lock); - db.flush()?; - let event_id = (*event_id).to_owned(); Ok(redact_event::v3::Response { event_id }) } diff --git a/src/api/client_server/report.rs b/src/api/client_server/report.rs index 14768e1..2c2a549 100644 --- a/src/api/client_server/report.rs +++ b/src/api/client_server/report.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, utils::HtmlEscape, Error, Result, Ruma}; +use crate::{utils::HtmlEscape, Error, Result, Ruma, services}; use ruma::{ api::client::{error::ErrorKind, room::report_content}, events::room::message, @@ -10,12 +10,11 @@ use ruma::{ /// Reports an inappropriate event to homeserver admins /// pub async fn report_event_route( - db: DatabaseGuard, body: Ruma<report_content::v3::IncomingRequest>, ) -> Result<report_content::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let pdu = match db.rooms.get_pdu(&body.event_id)? { + let pdu = match services().rooms.get_pdu(&body.event_id)? { Some(pdu) => pdu, _ => { return Err(Error::BadRequest( @@ -39,7 +38,7 @@ pub async fn report_event_route( )); }; - db.admin + services().admin .send_message(message::RoomMessageEventContent::text_html( format!( "Report received from: {}\n\n\ @@ -66,7 +65,5 @@ pub async fn report_event_route( ), )); - db.flush()?; - Ok(report_content::v3::Response {}) } diff --git a/src/api/client_server/room.rs b/src/api/client_server/room.rs index 5ae7224..14affc6 100644 --- a/src/api/client_server/room.rs +++ b/src/api/client_server/room.rs @@ -1,5 +1,5 @@ use crate::{ - client_server::invite_helper, database::DatabaseGuard, pdu::PduBuilder, Error, Result, Ruma, + Error, Result, Ruma, service::pdu::PduBuilder, services, api::client_server::invite_helper, }; use ruma::{ api::client::{ @@ -46,19 +46,18 @@ use tracing::{info, warn}; /// - Send events implied by `name` and `topic` /// - Send invite events pub async fn create_room_route( - db: DatabaseGuard, body: Ruma<create_room::v3::IncomingRequest>, ) -> Result<create_room::v3::Response> { use create_room::v3::RoomPreset; let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let room_id = RoomId::new(db.globals.server_name()); + let room_id = RoomId::new(services().globals.server_name()); - db.rooms.get_or_create_shortroomid(&room_id, &db.globals)?; + services().rooms.get_or_create_shortroomid(&room_id)?; let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -67,9 +66,9 @@ pub async fn create_room_route( ); let state_lock = mutex_state.lock().await; - if !db.globals.allow_room_creation() + if !services().globals.allow_room_creation() && !body.from_appservice - && !db.users.is_admin(sender_user, &db.rooms, &db.globals)? + && !services().users.is_admin(sender_user)? { return Err(Error::BadRequest( ErrorKind::Forbidden, @@ -83,12 +82,12 @@ pub async fn create_room_route( .map_or(Ok(None), |localpart| { // TODO: Check for invalid characters and maximum length let alias = - RoomAliasId::parse(format!("#{}:{}", localpart, db.globals.server_name())) + RoomAliasId::parse(format!("#{}:{}", localpart, services().globals.server_name())) .map_err(|_| { Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias.") })?; - if db.rooms.id_from_alias(&alias)?.is_some() { + if services().rooms.id_from_alias(&alias)?.is_some() { Err(Error::BadRequest( ErrorKind::RoomInUse, "Room alias already exists.", @@ -100,7 +99,7 @@ pub async fn create_room_route( let room_version = match body.room_version.clone() { Some(room_version) => { - if db.rooms.is_supported_version(&db, &room_version) { + if services().rooms.is_supported_version(&services(), &room_version) { room_version } else { return Err(Error::BadRequest( @@ -109,7 +108,7 @@ pub async fn create_room_route( )); } } - None => db.globals.default_room_version(), + None => services().globals.default_room_version(), }; let content = match &body.creation_content { @@ -163,7 +162,7 @@ pub async fn create_room_route( } // 1. The room create event - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomCreate, content: to_raw_value(&content).expect("event is valid, we just created it"), @@ -173,21 +172,20 @@ pub async fn create_room_route( }, sender_user, &room_id, - &db, &state_lock, )?; // 2. Let the room creator join - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomMember, content: to_raw_value(&RoomMemberEventContent { membership: MembershipState::Join, - displayname: db.users.displayname(sender_user)?, - avatar_url: db.users.avatar_url(sender_user)?, + displayname: services().users.displayname(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, is_direct: Some(body.is_direct), third_party_invite: None, - blurhash: db.users.blurhash(sender_user)?, + blurhash: services().users.blurhash(sender_user)?, reason: None, join_authorized_via_users_server: None, }) @@ -198,7 +196,6 @@ pub async fn create_room_route( }, sender_user, &room_id, - &db, &state_lock, )?; @@ -240,7 +237,7 @@ pub async fn create_room_route( } } - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomPowerLevels, content: to_raw_value(&power_levels_content) @@ -251,13 +248,12 @@ pub async fn create_room_route( }, sender_user, &room_id, - &db, &state_lock, )?; // 4. Canonical room alias if let Some(room_alias_id) = &alias { - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomCanonicalAlias, content: to_raw_value(&RoomCanonicalAliasEventContent { @@ -271,7 +267,6 @@ pub async fn create_room_route( }, sender_user, &room_id, - &db, &state_lock, )?; } @@ -279,7 +274,7 @@ pub async fn create_room_route( // 5. Events set by preset // 5.1 Join Rules - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomJoinRules, content: to_raw_value(&RoomJoinRulesEventContent::new(match preset { @@ -294,12 +289,11 @@ pub async fn create_room_route( }, sender_user, &room_id, - &db, &state_lock, )?; // 5.2 History Visibility - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomHistoryVisibility, content: to_raw_value(&RoomHistoryVisibilityEventContent::new( @@ -312,12 +306,11 @@ pub async fn create_room_route( }, sender_user, &room_id, - &db, &state_lock, )?; // 5.3 Guest Access - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomGuestAccess, content: to_raw_value(&RoomGuestAccessEventContent::new(match preset { @@ -331,7 +324,6 @@ pub async fn create_room_route( }, sender_user, &room_id, - &db, &state_lock, )?; @@ -346,18 +338,18 @@ pub async fn create_room_route( pdu_builder.state_key.get_or_insert_with(|| "".to_owned()); // Silently skip encryption events if they are not allowed - if pdu_builder.event_type == RoomEventType::RoomEncryption && !db.globals.allow_encryption() + if pdu_builder.event_type == RoomEventType::RoomEncryption && !services().globals.allow_encryption() { continue; } - db.rooms - .build_and_append_pdu(pdu_builder, sender_user, &room_id, &db, &state_lock)?; + services().rooms + .build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock)?; } // 7. Events implied by name and topic if let Some(name) = &body.name { - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomName, content: to_raw_value(&RoomNameEventContent::new(Some(name.clone()))) @@ -368,13 +360,12 @@ pub async fn create_room_route( }, sender_user, &room_id, - &db, &state_lock, )?; } if let Some(topic) = &body.topic { - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomTopic, content: to_raw_value(&RoomTopicEventContent { @@ -387,7 +378,6 @@ pub async fn create_room_route( }, sender_user, &room_id, - &db, &state_lock, )?; } @@ -395,22 +385,20 @@ pub async fn create_room_route( // 8. Events implied by invite (and TODO: invite_3pid) drop(state_lock); for user_id in &body.invite { - let _ = invite_helper(sender_user, user_id, &room_id, &db, body.is_direct).await; + let _ = invite_helper(sender_user, user_id, &room_id, body.is_direct).await; } // Homeserver specific stuff if let Some(alias) = alias { - db.rooms.set_alias(&alias, Some(&room_id), &db.globals)?; + services().rooms.set_alias(&alias, Some(&room_id))?; } if body.visibility == room::Visibility::Public { - db.rooms.set_public(&room_id, true)?; + services().rooms.set_public(&room_id, true)?; } info!("{} created a room", sender_user); - db.flush()?; - Ok(create_room::v3::Response::new(room_id)) } @@ -420,12 +408,11 @@ pub async fn create_room_route( /// /// - You have to currently be joined to the room (TODO: Respect history visibility) pub async fn get_room_event_route( - db: DatabaseGuard, body: Ruma<get_room_event::v3::IncomingRequest>, ) -> Result<get_room_event::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !db.rooms.is_joined(sender_user, &body.room_id)? { + if !services().rooms.is_joined(sender_user, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view this room.", @@ -433,7 +420,7 @@ pub async fn get_room_event_route( } Ok(get_room_event::v3::Response { - event: db + event: services() .rooms .get_pdu(&body.event_id)? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Event not found."))? @@ -447,12 +434,11 @@ pub async fn get_room_event_route( /// /// - Only users joined to the room are allowed to call this TODO: Allow any user to call it if history_visibility is world readable pub async fn get_room_aliases_route( - db: DatabaseGuard, body: Ruma<aliases::v3::IncomingRequest>, ) -> Result<aliases::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !db.rooms.is_joined(sender_user, &body.room_id)? { + if !services().rooms.is_joined(sender_user, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view this room.", @@ -460,7 +446,7 @@ pub async fn get_room_aliases_route( } Ok(aliases::v3::Response { - aliases: db + aliases: services() .rooms .room_aliases(&body.room_id) .filter_map(|a| a.ok()) @@ -479,12 +465,11 @@ pub async fn get_room_aliases_route( /// - Moves local aliases /// - Modifies old room power levels to prevent users from speaking pub async fn upgrade_room_route( - db: DatabaseGuard, body: Ruma<upgrade_room::v3::IncomingRequest>, ) -> Result<upgrade_room::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !db.rooms.is_supported_version(&db, &body.new_version) { + if !services().rooms.is_supported_version(&body.new_version) { return Err(Error::BadRequest( ErrorKind::UnsupportedRoomVersion, "This server does not support that room version.", @@ -492,12 +477,12 @@ pub async fn upgrade_room_route( } // Create a replacement room - let replacement_room = RoomId::new(db.globals.server_name()); - db.rooms - .get_or_create_shortroomid(&replacement_room, &db.globals)?; + let replacement_room = RoomId::new(services().globals.server_name()); + services().rooms + .get_or_create_shortroomid(&replacement_room)?; let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -508,7 +493,7 @@ pub async fn upgrade_room_route( // Send a m.room.tombstone event to the old room to indicate that it is not intended to be used any further // Fail if the sender does not have the required permissions - let tombstone_event_id = db.rooms.build_and_append_pdu( + let tombstone_event_id = services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomTombstone, content: to_raw_value(&RoomTombstoneEventContent { @@ -522,14 +507,13 @@ pub async fn upgrade_room_route( }, sender_user, &body.room_id, - &db, &state_lock, )?; // Change lock to replacement room drop(state_lock); let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -540,7 +524,7 @@ pub async fn upgrade_room_route( // Get the old room creation event let mut create_event_content = serde_json::from_str::<CanonicalJsonObject>( - db.rooms + services().rooms .room_state_get(&body.room_id, &StateEventType::RoomCreate, "")? .ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? .content @@ -588,7 +572,7 @@ pub async fn upgrade_room_route( )); } - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomCreate, content: to_raw_value(&create_event_content) @@ -599,21 +583,20 @@ pub async fn upgrade_room_route( }, sender_user, &replacement_room, - &db, &state_lock, )?; // Join the new room - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomMember, content: to_raw_value(&RoomMemberEventContent { membership: MembershipState::Join, - displayname: db.users.displayname(sender_user)?, - avatar_url: db.users.avatar_url(sender_user)?, + displayname: services().users.displayname(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, is_direct: None, third_party_invite: None, - blurhash: db.users.blurhash(sender_user)?, + blurhash: services().users.blurhash(sender_user)?, reason: None, join_authorized_via_users_server: None, }) @@ -624,7 +607,6 @@ pub async fn upgrade_room_route( }, sender_user, &replacement_room, - &db, &state_lock, )?; @@ -643,12 +625,12 @@ pub async fn upgrade_room_route( // Replicate transferable state events to the new room for event_type in transferable_state_events { - let event_content = match db.rooms.room_state_get(&body.room_id, &event_type, "")? { + let event_content = match services().rooms.room_state_get(&body.room_id, &event_type, "")? { Some(v) => v.content.clone(), None => continue, // Skipping missing events. }; - db.rooms.build_and_append_pdu( + services().rooms.build_and_append_pdu( PduBuilder { event_type: event_type.to_string().into(), content: event_content, @@ -658,20 +640,19 @@ pub async fn upgrade_room_route( }, sender_user, &replacement_room, - &db, &state_lock, )?; } // Moves any local aliases to the new room - for alias in db.rooms.room_aliases(&body.room_id).filter_map(|r| r.ok()) { - db.rooms - .set_alias(&alias, Some(&replacement_room), &db.globals)?; + for alias in services().rooms.room_aliases(&body.room_id).filter_map(|r| r.ok()) { + services().rooms + .set_alias(&alias, Some(&replacement_room))?; } // Get the old room power levels let mut power_levels_event_content: RoomPowerLevelsEventContent = serde_json::from_str( - db.rooms + services().rooms .room_state_get(&body.room_id, &StateEventType::RoomPowerLevels, "")? .ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? .content @@ -685,7 +666,7 @@ pub async fn upgrade_room_route( power_levels_event_content.invite = new_level; // Modify the power levels in the old room to prevent sending of events and inviting new users - let _ = db.rooms.build_and_append_pdu( + let _ = services().rooms.build_and_append_pdu( PduBuilder { event_type: RoomEventType::RoomPowerLevels, content: to_raw_value(&power_levels_event_content) @@ -696,35 +677,12 @@ pub async fn upgrade_room_route( }, sender_user, &body.room_id, - &db, &state_lock, )?; drop(state_lock); - db.flush()?; - // Return the replacement room id Ok(upgrade_room::v3::Response { replacement_room }) } - /// Returns the room's version. - #[tracing::instrument(skip(self))] - pub fn get_room_version(&self, room_id: &RoomId) -> Result<RoomVersionId> { - let create_event = self.room_state_get(room_id, &StateEventType::RoomCreate, "")?; - - let create_event_content: Option<RoomCreateEventContent> = create_event - .as_ref() - .map(|create_event| { - serde_json::from_str(create_event.content.get()).map_err(|e| { - warn!("Invalid create event: {}", e); - Error::bad_database("Invalid create event in db.") - }) - }) - .transpose()?; - let room_version = create_event_content - .map(|create_event| create_event.room_version) - .ok_or_else(|| Error::BadDatabase("Invalid room version"))?; - Ok(room_version) - } - diff --git a/src/api/client_server/search.rs b/src/api/client_server/search.rs index 686e3b5..b7eecd5 100644 --- a/src/api/client_server/search.rs +++ b/src/api/client_server/search.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Error, Result, Ruma}; +use crate::{Error, Result, Ruma, services}; use ruma::api::client::{ error::ErrorKind, search::search_events::{ @@ -15,7 +15,6 @@ use std::collections::BTreeMap; /// /// - Only works if the user is currently joined to the room (TODO: Respect history visibility) pub async fn search_events_route( - db: DatabaseGuard, body: Ruma<search_events::v3::IncomingRequest>, ) -> Result<search_events::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -24,7 +23,7 @@ pub async fn search_events_route( let filter = &search_criteria.filter; let room_ids = filter.rooms.clone().unwrap_or_else(|| { - db.rooms + services().rooms .rooms_joined(sender_user) .filter_map(|r| r.ok()) .collect() @@ -35,14 +34,14 @@ pub async fn search_events_route( let mut searches = Vec::new(); for room_id in room_ids { - if !db.rooms.is_joined(sender_user, &room_id)? { + if !services().rooms.is_joined(sender_user, &room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "You don't have permission to view this room.", )); } - if let Some(search) = db + if let Some(search) = services() .rooms .search_pdus(&room_id, &search_criteria.search_term)? { @@ -85,7 +84,7 @@ pub async fn search_events_route( start: None, }, rank: None, - result: db + result: services() .rooms .get_pdu_from_id(result)? .map(|pdu| pdu.to_room_event()), diff --git a/src/api/client_server/session.rs b/src/api/client_server/session.rs index c2a79ca..7feeb66 100644 --- a/src/api/client_server/session.rs +++ b/src/api/client_server/session.rs @@ -1,5 +1,5 @@ use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH}; -use crate::{database::DatabaseGuard, utils, Error, Result, Ruma}; +use crate::{utils, Error, Result, Ruma, services}; use ruma::{ api::client::{ error::ErrorKind, @@ -41,7 +41,6 @@ pub async fn get_login_types_route( /// Note: You can use [`GET /_matrix/client/r0/login`](fn.get_supported_versions_route.html) to see /// supported login types. pub async fn login_route( - db: DatabaseGuard, body: Ruma<login::v3::IncomingRequest>, ) -> Result<login::v3::Response> { // Validate login method @@ -57,11 +56,11 @@ pub async fn login_route( return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type.")); }; let user_id = - UserId::parse_with_server_name(username.to_owned(), db.globals.server_name()) + UserId::parse_with_server_name(username.to_owned(), services().globals.server_name()) .map_err(|_| { Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.") })?; - let hash = db.users.password_hash(&user_id)?.ok_or(Error::BadRequest( + let hash = services().users.password_hash(&user_id)?.ok_or(Error::BadRequest( ErrorKind::Forbidden, "Wrong username or password.", ))?; @@ -85,7 +84,7 @@ pub async fn login_route( user_id } login::v3::IncomingLoginInfo::Token(login::v3::IncomingToken { token }) => { - if let Some(jwt_decoding_key) = db.globals.jwt_decoding_key() { + if let Some(jwt_decoding_key) = services().globals.jwt_decoding_key() { let token = jsonwebtoken::decode::<Claims>( token, jwt_decoding_key, @@ -93,7 +92,7 @@ pub async fn login_route( ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid."))?; let username = token.claims.sub; - UserId::parse_with_server_name(username, db.globals.server_name()).map_err( + UserId::parse_with_server_name(username, services().globals.server_name()).map_err( |_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."), )? } else { @@ -122,15 +121,15 @@ pub async fn login_route( // Determine if device_id was provided and exists in the db for this user let device_exists = body.device_id.as_ref().map_or(false, |device_id| { - db.users + services().users .all_device_ids(&user_id) .any(|x| x.as_ref().map_or(false, |v| v == device_id)) }); if device_exists { - db.users.set_token(&user_id, &device_id, &token)?; + services().users.set_token(&user_id, &device_id, &token)?; } else { - db.users.create_device( + services().users.create_device( &user_id, &device_id, &token, @@ -140,12 +139,10 @@ pub async fn login_route( info!("{} logged in", user_id); - db.flush()?; - Ok(login::v3::Response { user_id, access_token: token, - home_server: Some(db.globals.server_name().to_owned()), + home_server: Some(services().globals.server_name().to_owned()), device_id, well_known: None, }) @@ -160,15 +157,12 @@ pub async fn login_route( /// - Forgets to-device events /// - Triggers device list updates pub async fn logout_route( - db: DatabaseGuard, body: Ruma<logout::v3::Request>, ) -> Result<logout::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated"); - db.users.remove_device(sender_user, sender_device)?; - - db.flush()?; + services().users.remove_device(sender_user, sender_device)?; Ok(logout::v3::Response::new()) } @@ -185,16 +179,13 @@ pub async fn logout_route( /// Note: This is equivalent to calling [`GET /_matrix/client/r0/logout`](fn.logout_route.html) /// from each device of this user. pub async fn logout_all_route( - db: DatabaseGuard, body: Ruma<logout_all::v3::Request>, ) -> Result<logout_all::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - for device_id in db.users.all_device_ids(sender_user).flatten() { - db.users.remove_device(sender_user, &device_id)?; + for device_id in services().users.all_device_ids(sender_user).flatten() { + services().users.remove_device(sender_user, &device_id)?; } - db.flush()?; - Ok(logout_all::v3::Response::new()) } diff --git a/src/api/client_server/state.rs b/src/api/client_server/state.rs index 4df953c..4e8d594 100644 --- a/src/api/client_server/state.rs +++ b/src/api/client_server/state.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use crate::{ - database::DatabaseGuard, pdu::PduBuilder, Database, Error, Result, Ruma, RumaResponse, + Error, Result, Ruma, RumaResponse, services, service::pdu::PduBuilder, }; use ruma::{ api::client::{ @@ -27,13 +27,11 @@ use ruma::{ /// - Tries to send the event into the room, auth rules will determine if it is allowed /// - If event is new canonical_alias: Rejects if alias is incorrect pub async fn send_state_event_for_key_route( - db: DatabaseGuard, body: Ruma<send_state_event::v3::IncomingRequest>, ) -> Result<send_state_event::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let event_id = send_state_event_for_key_helper( - &db, sender_user, &body.room_id, &body.event_type, @@ -42,8 +40,6 @@ pub async fn send_state_event_for_key_route( ) .await?; - db.flush()?; - let event_id = (*event_id).to_owned(); Ok(send_state_event::v3::Response { event_id }) } @@ -56,13 +52,12 @@ pub async fn send_state_event_for_key_route( /// - Tries to send the event into the room, auth rules will determine if it is allowed /// - If event is new canonical_alias: Rejects if alias is incorrect pub async fn send_state_event_for_empty_key_route( - db: DatabaseGuard, body: Ruma<send_state_event::v3::IncomingRequest>, ) -> Result<RumaResponse<send_state_event::v3::Response>> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); // Forbid m.room.encryption if encryption is disabled - if body.event_type == StateEventType::RoomEncryption && !db.globals.allow_encryption() { + if body.event_type == StateEventType::RoomEncryption && !services().globals.allow_encryption() { return Err(Error::BadRequest( ErrorKind::Forbidden, "Encryption has been disabled", @@ -70,7 +65,6 @@ pub async fn send_state_event_for_empty_key_route( } let event_id = send_state_event_for_key_helper( - &db, sender_user, &body.room_id, &body.event_type.to_string().into(), @@ -79,8 +73,6 @@ pub async fn send_state_event_for_empty_key_route( ) .await?; - db.flush()?; - let event_id = (*event_id).to_owned(); Ok(send_state_event::v3::Response { event_id }.into()) } @@ -91,7 +83,6 @@ pub async fn send_state_event_for_empty_key_route( /// /// - If not joined: Only works if current room history visibility is world readable pub async fn get_state_events_route( - db: DatabaseGuard, body: Ruma<get_state_events::v3::IncomingRequest>, ) -> Result<get_state_events::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -99,9 +90,9 @@ pub async fn get_state_events_route( #[allow(clippy::blocks_in_if_conditions)] // Users not in the room should not be able to access the state unless history_visibility is // WorldReadable - if !db.rooms.is_joined(sender_user, &body.room_id)? + if !services().rooms.is_joined(sender_user, &body.room_id)? && !matches!( - db.rooms + services().rooms .room_state_get(&body.room_id, &StateEventType::RoomHistoryVisibility, "")? .map(|event| { serde_json::from_str(event.content.get()) @@ -122,7 +113,7 @@ pub async fn get_state_events_route( } Ok(get_state_events::v3::Response { - room_state: db + room_state: services() .rooms .room_state_full(&body.room_id) .await? @@ -138,7 +129,6 @@ pub async fn get_state_events_route( /// /// - If not joined: Only works if current room history visibility is world readable pub async fn get_state_events_for_key_route( - db: DatabaseGuard, body: Ruma<get_state_events_for_key::v3::IncomingRequest>, ) -> Result<get_state_events_for_key::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -146,9 +136,9 @@ pub async fn get_state_events_for_key_route( #[allow(clippy::blocks_in_if_conditions)] // Users not in the room should not be able to access the state unless history_visibility is // WorldReadable - if !db.rooms.is_joined(sender_user, &body.room_id)? + if !services().rooms.is_joined(sender_user, &body.room_id)? && !matches!( - db.rooms + services().rooms .room_state_get(&body.room_id, &StateEventType::RoomHistoryVisibility, "")? .map(|event| { serde_json::from_str(event.content.get()) @@ -168,7 +158,7 @@ pub async fn get_state_events_for_key_route( )); } - let event = db + let event = services() .rooms .room_state_get(&body.room_id, &body.event_type, &body.state_key)? .ok_or(Error::BadRequest( @@ -188,7 +178,6 @@ pub async fn get_state_events_for_key_route( /// /// - If not joined: Only works if current room history visibility is world readable pub async fn get_state_events_for_empty_key_route( - db: DatabaseGuard, body: Ruma<get_state_events_for_key::v3::IncomingRequest>, ) -> Result<RumaResponse<get_state_events_for_key::v3::Response>> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); @@ -196,9 +185,9 @@ pub async fn get_state_events_for_empty_key_route( #[allow(clippy::blocks_in_if_conditions)] // Users not in the room should not be able to access the state unless history_visibility is // WorldReadable - if !db.rooms.is_joined(sender_user, &body.room_id)? + if !services().rooms.is_joined(sender_user, &body.room_id)? && !matches!( - db.rooms + services().rooms .room_state_get(&body.room_id, &StateEventType::RoomHistoryVisibility, "")? .map(|event| { serde_json::from_str(event.content.get()) @@ -218,7 +207,7 @@ pub async fn get_state_events_for_empty_key_route( )); } - let event = db + let event = services() .rooms .room_state_get(&body.room_id, &body.event_type, "")? .ok_or(Error::BadRequest( @@ -234,7 +223,6 @@ pub async fn get_state_events_for_empty_key_route( } async fn send_state_event_for_key_helper( - db: &Database, sender: &UserId, room_id: &RoomId, event_type: &StateEventType, @@ -255,8 +243,8 @@ async fn send_state_event_for_key_helper( } for alias in aliases { - if alias.server_name() != db.globals.server_name() - || db + if alias.server_name() != services().globals.server_name() + || services() .rooms .id_from_alias(&alias)? .filter(|room| room == room_id) // Make sure it's the right room @@ -272,7 +260,7 @@ async fn send_state_event_for_key_helper( } let mutex_state = Arc::clone( - db.globals + services().globals .roomid_mutex_state .write() .unwrap() @@ -281,7 +269,7 @@ async fn send_state_event_for_key_helper( ); let state_lock = mutex_state.lock().await; - let event_id = db.rooms.build_and_append_pdu( + let event_id = services().rooms.build_and_append_pdu( PduBuilder { event_type: event_type.to_string().into(), content: serde_json::from_str(json.json().get()).expect("content is valid json"), @@ -291,7 +279,6 @@ async fn send_state_event_for_key_helper( }, sender_user, room_id, - db, &state_lock, )?; diff --git a/src/api/client_server/sync.rs b/src/api/client_server/sync.rs index 0c294b7..cc4ebf6 100644 --- a/src/api/client_server/sync.rs +++ b/src/api/client_server/sync.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Database, Error, Result, Ruma, RumaResponse}; +use crate::{Error, Result, Ruma, RumaResponse, services}; use ruma::{ api::client::{ filter::{IncomingFilterDefinition, LazyLoadOptions}, @@ -55,16 +55,13 @@ use tracing::error; /// - Sync is handled in an async task, multiple requests from the same device with the same /// `since` will be cached pub async fn sync_events_route( - db: DatabaseGuard, body: Ruma<sync_events::v3::IncomingRequest>, ) -> Result<sync_events::v3::Response, RumaResponse<UiaaResponse>> { let sender_user = body.sender_user.expect("user is authenticated"); let sender_device = body.sender_device.expect("user is authenticated"); let body = body.body; - let arc_db = Arc::new(db); - - let mut rx = match arc_db + let mut rx = match services() .globals .sync_receivers .write() @@ -77,7 +74,6 @@ pub async fn sync_events_route( v.insert((body.since.to_owned(), rx.clone())); tokio::spawn(sync_helper_wrapper( - Arc::clone(&arc_db), sender_user.clone(), sender_device.clone(), body, @@ -93,7 +89,6 @@ pub async fn sync_events_route( o.insert((body.since.clone(), rx.clone())); tokio::spawn(sync_helper_wrapper( - Arc::clone(&arc_db), sender_user.clone(), sender_device.clone(), body, @@ -127,7 +122,6 @@ pub async fn sync_events_route( } async fn sync_helper_wrapper( - db: Arc<DatabaseGuard>, sender_user: Box<UserId>, sender_device: Box<DeviceId>, body: sync_events::v3::IncomingRequest, @@ -136,7 +130,6 @@ async fn sync_helper_wrapper( let since = body.since.clone(); let r = sync_helper( - Arc::clone(&db), sender_user.clone(), sender_device.clone(), body, @@ -145,7 +138,7 @@ async fn sync_helper_wrapper( if let Ok((_, caching_allowed)) = r { if !caching_allowed { - match db + match services() .globals .sync_receivers .write() @@ -163,13 +156,10 @@ async fn sync_helper_wrapper( } } - drop(db); - let _ = tx.send(Some(r.map(|(r, _)| r))); } async fn sync_helper( - db: Arc<DatabaseGuard>, sender_user: Box<UserId>, sender_device: Box<DeviceId>, body: sync_events::v3::IncomingRequest, @@ -182,19 +172,19 @@ async fn sync_helper( }; // TODO: match body.set_presence { - db.rooms.edus.ping_presence(&sender_user)?; + services().rooms.edus.ping_presence(&sender_user)?; // Setup watchers, so if there's no response, we can wait for them - let watcher = db.watch(&sender_user, &sender_device); + let watcher = services().watch(&sender_user, &sender_device); - let next_batch = db.globals.current_count()?; + let next_batch = services().globals.current_count()?; let next_batch_string = next_batch.to_string(); // Load filter let filter = match body.filter { None => IncomingFilterDefinition::default(), Some(IncomingFilter::FilterDefinition(filter)) => filter, - Some(IncomingFilter::FilterId(filter_id)) => db + Some(IncomingFilter::FilterId(filter_id)) => services() .users .get_filter(&sender_user, &filter_id)? .unwrap_or_default(), @@ -221,12 +211,12 @@ async fn sync_helper( // Look for device list updates of this account device_list_updates.extend( - db.users + services().users .keys_changed(&sender_user.to_string(), since, None) .filter_map(|r| r.ok()), ); - let all_joined_rooms = db.rooms.rooms_joined(&sender_user).collect::<Vec<_>>(); + let all_joined_rooms = services().rooms.rooms_joined(&sender_user).collect::<Vec<_>>(); for room_id in all_joined_rooms { let room_id = room_id?; @@ -234,7 +224,7 @@ async fn sync_helper( // Get and drop the lock to wait for remaining operations to finish // This will make sure the we have all events until next_batch let mutex_insert = Arc::clone( - db.globals + services().globals .roomid_mutex_insert .write() .unwrap() @@ -247,8 +237,8 @@ async fn sync_helper( let timeline_pdus; let limited; - if db.rooms.last_timeline_count(&sender_user, &room_id)? > since { - let mut non_timeline_pdus = db + if services().rooms.last_timeline_count(&sender_user, &room_id)? > since { + let mut non_timeline_pdus = services() .rooms .pdus_until(&sender_user, &room_id, u64::MAX)? .filter_map(|r| { @@ -259,7 +249,7 @@ async fn sync_helper( r.ok() }) .take_while(|(pduid, _)| { - db.rooms + services().rooms .pdu_count(pduid) .map_or(false, |count| count > since) }); @@ -282,7 +272,7 @@ async fn sync_helper( } let send_notification_counts = !timeline_pdus.is_empty() - || db + || services() .rooms .edus .last_privateread_update(&sender_user, &room_id)? @@ -293,24 +283,24 @@ async fn sync_helper( timeline_users.insert(event.sender.as_str().to_owned()); } - db.rooms + services().rooms .lazy_load_confirm_delivery(&sender_user, &sender_device, &room_id, since)?; // Database queries: - let current_shortstatehash = if let Some(s) = db.rooms.current_shortstatehash(&room_id)? { + let current_shortstatehash = if let Some(s) = services().rooms.current_shortstatehash(&room_id)? { s } else { error!("Room {} has no state", room_id); continue; }; - let since_shortstatehash = db.rooms.get_token_shortstatehash(&room_id, since)?; + let since_shortstatehash = services().rooms.get_token_shortstatehash(&room_id, since)?; // Calculates joined_member_count, invited_member_count and heroes let calculate_counts = || { - let joined_member_count = db.rooms.room_joined_count(&room_id)?.unwrap_or(0); - let invited_member_count = db.rooms.room_invited_count(&room_id)?.unwrap_or(0); + let joined_member_count = services().rooms.room_joined_count(&room_id)?.unwrap_or(0); + let invited_member_count = services().rooms.room_invited_count(&room_id)?.unwrap_or(0); // Recalculate heroes (first 5 members) let mut heroes = Vec::new(); @@ -319,7 +309,7 @@ async fn sync_helper( // Go through all PDUs and for each member event, check if the user is still joined or // invited until we have 5 or we reach the end - for hero in db + for hero in services() .rooms .all_pdus(&sender_user, &room_id)? .filter_map(|pdu| pdu.ok()) // Ignore all broken pdus @@ -339,8 +329,8 @@ async fn sync_helper( if matches!( content.membership, MembershipState::Join | MembershipState::Invite - ) && (db.rooms.is_joined(&user_id, &room_id)? - || db.rooms.is_invited(&user_id, &room_id)?) + ) && (services().rooms.is_joined(&user_id, &room_id)? + || services().rooms.is_invited(&user_id, &room_id)?) { Ok::<_, Error>(Some(state_key.clone())) } else { @@ -381,17 +371,17 @@ async fn sync_helper( let (joined_member_count, invited_member_count, heroes) = calculate_counts()?; - let current_state_ids = db.rooms.state_full_ids(current_shortstatehash).await?; + let current_state_ids = services().rooms.state_full_ids(current_shortstatehash).await?; let mut state_events = Vec::new(); let mut lazy_loaded = HashSet::new(); let mut i = 0; for (shortstatekey, id) in current_state_ids { - let (event_type, state_key) = db.rooms.get_statekey_from_short(shortstatekey)?; + let (event_type, state_key) = services().rooms.get_statekey_from_short(shortstatekey)?; if event_type != StateEventType::RoomMember { - let pdu = match db.rooms.get_pdu(&id)? { + let pdu = match services().rooms.get_pdu(&id)? { Some(pdu) => pdu, None => { error!("Pdu in state not found: {}", id); @@ -408,7 +398,7 @@ async fn sync_helper( || body.full_state || timeline_users.contains(&state_key) { - let pdu = match db.rooms.get_pdu(&id)? { + let pdu = match services().rooms.get_pdu(&id)? { Some(pdu) => pdu, None => { error!("Pdu in state not found: {}", id); @@ -430,12 +420,12 @@ async fn sync_helper( } // Reset lazy loading because this is an initial sync - db.rooms + services().rooms .lazy_load_reset(&sender_user, &sender_device, &room_id)?; // The state_events above should contain all timeline_users, let's mark them as lazy // loaded. - db.rooms.lazy_load_mark_sent( + services().rooms.lazy_load_mark_sent( &sender_user, &sender_device, &room_id, @@ -457,7 +447,7 @@ async fn sync_helper( // Incremental /sync let since_shortstatehash = since_shortstatehash.unwrap(); - let since_sender_member: Option<RoomMemberEventContent> = db + let since_sender_member: Option<RoomMemberEventContent> = services() .rooms .state_get( since_shortstatehash, @@ -477,12 +467,12 @@ async fn sync_helper( let mut lazy_loaded = HashSet::new(); if since_shortstatehash != current_shortstatehash { - let current_state_ids = db.rooms.state_full_ids(current_shortstatehash).await?; - let since_state_ids = db.rooms.state_full_ids(since_shortstatehash).await?; + let current_state_ids = services().rooms.state_full_ids(current_shortstatehash).await?; + let since_state_ids = services().rooms.state_full_ids(since_shortstatehash).await?; for (key, id) in current_state_ids { if body.full_state || since_state_ids.get(&key) != Some(&id) { - let pdu = match db.rooms.get_pdu(&id)? { + let pdu = match services().rooms.get_pdu(&id)? { Some(pdu) => pdu, None => { error!("Pdu in state not found: {}", id); @@ -515,14 +505,14 @@ async fn sync_helper( continue; } - if !db.rooms.lazy_load_was_sent_before( + if !services().rooms.lazy_load_was_sent_before( &sender_user, &sender_device, &room_id, &event.sender, )? || lazy_load_send_redundant { - if let Some(member_event) = db.rooms.room_state_get( + if let Some(member_event) = services().rooms.room_state_get( &room_id, &StateEventType::RoomMember, event.sender.as_str(), @@ -533,7 +523,7 @@ async fn sync_helper( } } - db.rooms.lazy_load_mark_sent( + services().rooms.lazy_load_mark_sent( &sender_user, &sender_device, &room_id, @@ -541,13 +531,13 @@ async fn sync_helper( next_batch, ); - let encrypted_room = db + let encrypted_room = services() .rooms .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")? .is_some(); let since_encryption = - db.rooms + services().rooms .state_get(since_shortstatehash, &StateEventType::RoomEncryption, "")?; // Calculations: @@ -580,7 +570,7 @@ async fn sync_helper( match new_membership { MembershipState::Join => { // A new user joined an encrypted room - if !share_encrypted_room(&db, &sender_user, &user_id, &room_id)? { + if !share_encrypted_room(&sender_user, &user_id, &room_id)? { device_list_updates.insert(user_id); } } @@ -597,7 +587,7 @@ async fn sync_helper( if joined_since_last_sync && encrypted_room || new_encrypted_room { // If the user is in a new encrypted room, give them all joined users device_list_updates.extend( - db.rooms + services().rooms .room_members(&room_id) .flatten() .filter(|user_id| { @@ -606,7 +596,7 @@ async fn sync_helper( }) .filter(|user_id| { // Only send keys if the sender doesn't share an encrypted room with the target already - !share_encrypted_room(&db, &sender_user, user_id, &room_id) + !share_encrypted_room(&sender_user, user_id, &room_id) .unwrap_or(false) }), ); @@ -629,14 +619,14 @@ async fn sync_helper( // Look for device list updates in this room device_list_updates.extend( - db.users + services().users .keys_changed(&room_id.to_string(), since, None) .filter_map(|r| r.ok()), ); let notification_count = if send_notification_counts { Some( - db.rooms + services().rooms .notification_count(&sender_user, &room_id)? .try_into() .expect("notification count can't go that high"), @@ -647,7 +637,7 @@ async fn sync_helper( let highlight_count = if send_notification_counts { Some( - db.rooms + services().rooms .highlight_count(&sender_user, &room_id)? .try_into() .expect("highlight count can't go that high"), @@ -659,7 +649,7 @@ async fn sync_helper( let prev_batch = timeline_pdus .first() .map_or(Ok::<_, Error>(None), |(pdu_id, _)| { - Ok(Some(db.rooms.pdu_count(pdu_id)?.to_string())) + Ok(Some(services().rooms.pdu_count(pdu_id)?.to_string())) })?; let room_events: Vec<_> = timeline_pdus @@ -667,7 +657,7 @@ async fn sync_helper( .map(|(_, pdu)| pdu.to_sync_room_event()) .collect(); - let mut edus: Vec<_> = db + let mut edus: Vec<_> = services() .rooms .edus .readreceipts_since(&room_id, since) @@ -675,10 +665,10 @@ async fn sync_helper( .map(|(_, _, v)| v) .collect(); - if db.rooms.edus.last_typing_update(&room_id, &db.globals)? > since { + if services().rooms.edus.last_typing_update(&room_id, &services().globals)? > since { edus.push( serde_json::from_str( - &serde_json::to_string(&db.rooms.edus.typings_all(&room_id)?) + &serde_json::to_string(&services().rooms.edus.typings_all(&room_id)?) .expect("event is valid, we just created it"), ) .expect("event is valid, we just created it"), @@ -686,12 +676,12 @@ async fn sync_helper( } // Save the state after this sync so we can send the correct state diff next sync - db.rooms + services().rooms .associate_token_shortstatehash(&room_id, next_batch, current_shortstatehash)?; let joined_room = JoinedRoom { account_data: RoomAccountData { - events: db + events: services() .account_data .changes_since(Some(&room_id), &sender_user, since)? .into_iter() @@ -731,9 +721,9 @@ async fn sync_helper( // Take presence updates from this room for (user_id, presence) in - db.rooms + services().rooms .edus - .presence_since(&room_id, since, &db.rooms, &db.globals)? + .presence_since(&room_id, since)? { match presence_updates.entry(user_id) { Entry::Vacant(v) => { @@ -765,14 +755,14 @@ async fn sync_helper( } let mut left_rooms = BTreeMap::new(); - let all_left_rooms: Vec<_> = db.rooms.rooms_left(&sender_user).collect(); + let all_left_rooms: Vec<_> = services().rooms.rooms_left(&sender_user).collect(); for result in all_left_rooms { let (room_id, left_state_events) = result?; { // Get and drop the lock to wait for remaining operations to finish let mutex_insert = Arc::clone( - db.globals + services().globals .roomid_mutex_insert .write() .unwrap() @@ -783,7 +773,7 @@ async fn sync_helper( drop(insert_lock); } - let left_count = db.rooms.get_left_count(&room_id, &sender_user)?; + let left_count = services().rooms.get_left_count(&room_id, &sender_user)?; // Left before last sync if Some(since) >= left_count { @@ -807,14 +797,14 @@ async fn sync_helper( } let mut invited_rooms = BTreeMap::new(); - let all_invited_rooms: Vec<_> = db.rooms.rooms_invited(&sender_user).collect(); + let all_invited_rooms: Vec<_> = services().rooms.rooms_invited(&sender_user).collect(); for result in all_invited_rooms { let (room_id, invite_state_events) = result?; { // Get and drop the lock to wait for remaining operations to finish let mutex_insert = Arc::clone( - db.globals + services().globals .roomid_mutex_insert .write() .unwrap() @@ -825,7 +815,7 @@ async fn sync_helper( drop(insert_lock); } - let invite_count = db.rooms.get_invite_count(&room_id, &sender_user)?; + let invite_count = services().rooms.get_invite_count(&room_id, &sender_user)?; // Invited before last sync if Some(since) >= invite_count { @@ -843,13 +833,13 @@ async fn sync_helper( } for user_id in left_encrypted_users { - let still_share_encrypted_room = db + let still_share_encrypted_room = services() .rooms .get_shared_rooms(vec![sender_user.clone(), user_id.clone()])? .filter_map(|r| r.ok()) .filter_map(|other_room_id| { Some( - db.rooms + services().rooms .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") .ok()? .is_some(), @@ -864,7 +854,7 @@ async fn sync_helper( } // Remove all to-device events the device received *last time* - db.users + services().users .remove_to_device_events(&sender_user, &sender_device, since)?; let response = sync_events::v3::Response { @@ -882,7 +872,7 @@ async fn sync_helper( .collect(), }, account_data: GlobalAccountData { - events: db + events: services() .account_data .changes_since(None, &sender_user, since)? .into_iter() @@ -897,9 +887,9 @@ async fn sync_helper( changed: device_list_updates.into_iter().collect(), left: device_list_left.into_iter().collect(), }, - device_one_time_keys_count: db.users.count_one_time_keys(&sender_user, &sender_device)?, + device_one_time_keys_count: services().users.count_one_time_keys(&sender_user, &sender_device)?, to_device: ToDevice { - events: db + events: services() .users .get_to_device_events(&sender_user, &sender_device)?, }, @@ -928,21 +918,19 @@ async fn sync_helper( } } -#[tracing::instrument(skip(db))] fn share_encrypted_room( - db: &Database, sender_user: &UserId, user_id: &UserId, ignore_room: &RoomId, ) -> Result<bool> { - Ok(db + Ok(services() .rooms .get_shared_rooms(vec![sender_user.to_owned(), user_id.to_owned()])? .filter_map(|r| r.ok()) .filter(|room_id| room_id != ignore_room) .filter_map(|other_room_id| { Some( - db.rooms + services().rooms .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") .ok()? .is_some(), diff --git a/src/api/client_server/tag.rs b/src/api/client_server/tag.rs index 98d895c..bbea2d5 100644 --- a/src/api/client_server/tag.rs +++ b/src/api/client_server/tag.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Result, Ruma}; +use crate::{Result, Ruma, services}; use ruma::{ api::client::tag::{create_tag, delete_tag, get_tags}, events::{ @@ -14,12 +14,11 @@ use std::collections::BTreeMap; /// /// - Inserts the tag into the tag event of the room account data. pub async fn update_tag_route( - db: DatabaseGuard, body: Ruma<create_tag::v3::IncomingRequest>, ) -> Result<create_tag::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let mut tags_event = db + let mut tags_event = services() .account_data .get( Some(&body.room_id), @@ -36,16 +35,13 @@ pub async fn update_tag_route( .tags .insert(body.tag.clone().into(), body.tag_info.clone()); - db.account_data.update( + services().account_data.update( Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag, &tags_event, - &db.globals, )?; - db.flush()?; - Ok(create_tag::v3::Response {}) } @@ -55,12 +51,11 @@ pub async fn update_tag_route( /// /// - Removes the tag from the tag event of the room account data. pub async fn delete_tag_route( - db: DatabaseGuard, body: Ruma<delete_tag::v3::IncomingRequest>, ) -> Result<delete_tag::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let mut tags_event = db + let mut tags_event = services() .account_data .get( Some(&body.room_id), @@ -74,16 +69,13 @@ pub async fn delete_tag_route( }); tags_event.content.tags.remove(&body.tag.clone().into()); - db.account_data.update( + services().account_data.update( Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag, &tags_event, - &db.globals, )?; - db.flush()?; - Ok(delete_tag::v3::Response {}) } @@ -93,13 +85,12 @@ pub async fn delete_tag_route( /// /// - Gets the tag event of the room account data. pub async fn get_tags_route( - db: DatabaseGuard, body: Ruma<get_tags::v3::IncomingRequest>, ) -> Result<get_tags::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); Ok(get_tags::v3::Response { - tags: db + tags: services() .account_data .get( Some(&body.room_id), diff --git a/src/api/client_server/to_device.rs b/src/api/client_server/to_device.rs index 51441dd..3a2f6c0 100644 --- a/src/api/client_server/to_device.rs +++ b/src/api/client_server/to_device.rs @@ -1,7 +1,7 @@ use ruma::events::ToDeviceEventType; use std::collections::BTreeMap; -use crate::{database::DatabaseGuard, Error, Result, Ruma}; +use crate::{Error, Result, Ruma, services}; use ruma::{ api::{ client::{error::ErrorKind, to_device::send_event_to_device}, @@ -14,14 +14,13 @@ use ruma::{ /// /// Send a to-device event to a set of client devices. pub async fn send_event_to_device_route( - db: DatabaseGuard, body: Ruma<send_event_to_device::v3::IncomingRequest>, ) -> Result<send_event_to_device::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_deref(); // Check if this is a new transaction id - if db + if services() .transaction_ids .existing_txnid(sender_user, sender_device, &body.txn_id)? .is_some() @@ -31,13 +30,13 @@ pub async fn send_event_to_device_route( for (target_user_id, map) in &body.messages { for (target_device_id_maybe, event) in map { - if target_user_id.server_name() != db.globals.server_name() { + if target_user_id.server_name() != services().globals.server_name() { let mut map = BTreeMap::new(); map.insert(target_device_id_maybe.clone(), event.clone()); let mut messages = BTreeMap::new(); messages.insert(target_user_id.clone(), map); - db.sending.send_reliable_edu( + services().sending.send_reliable_edu( target_user_id.server_name(), serde_json::to_vec(&federation::transactions::edu::Edu::DirectToDevice( DirectDeviceContent { @@ -48,14 +47,14 @@ pub async fn send_event_to_device_route( }, )) .expect("DirectToDevice EDU can be serialized"), - db.globals.next_count()?, + services().globals.next_count()?, )?; continue; } match target_device_id_maybe { - DeviceIdOrAllDevices::DeviceId(target_device_id) => db.users.add_to_device_event( + DeviceIdOrAllDevices::DeviceId(target_device_id) => services().users.add_to_device_event( sender_user, target_user_id, &target_device_id, @@ -63,12 +62,11 @@ pub async fn send_event_to_device_route( event.deserialize_as().map_err(|_| { Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid") })?, - &db.globals, )?, DeviceIdOrAllDevices::AllDevices => { - for target_device_id in db.users.all_device_ids(target_user_id) { - db.users.add_to_device_event( + for target_device_id in services().users.all_device_ids(target_user_id) { + services().users.add_to_device_event( sender_user, target_user_id, &target_device_id?, @@ -76,7 +74,6 @@ pub async fn send_event_to_device_route( event.deserialize_as().map_err(|_| { Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid") })?, - &db.globals, )?; } } @@ -85,10 +82,8 @@ pub async fn send_event_to_device_route( } // Save transaction id with empty data - db.transaction_ids + services().transaction_ids .add_txnid(sender_user, sender_device, &body.txn_id, &[])?; - db.flush()?; - Ok(send_event_to_device::v3::Response {}) } diff --git a/src/api/client_server/typing.rs b/src/api/client_server/typing.rs index cac5a5f..afd5d6b 100644 --- a/src/api/client_server/typing.rs +++ b/src/api/client_server/typing.rs @@ -1,18 +1,17 @@ -use crate::{database::DatabaseGuard, utils, Error, Result, Ruma}; +use crate::{utils, Error, Result, Ruma, services}; use ruma::api::client::{error::ErrorKind, typing::create_typing_event}; /// # `PUT /_matrix/client/r0/rooms/{roomId}/typing/{userId}` /// /// Sets the typing state of the sender user. pub async fn create_typing_event_route( - db: DatabaseGuard, body: Ruma<create_typing_event::v3::IncomingRequest>, ) -> Result<create_typing_event::v3::Response> { use create_typing_event::v3::Typing; let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !db.rooms.is_joined(sender_user, &body.room_id)? { + if !services().rooms.is_joined(sender_user, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "You are not in this room.", @@ -20,16 +19,15 @@ pub async fn create_typing_event_route( } if let Typing::Yes(duration) = body.state { - db.rooms.edus.typing_add( + services().rooms.edus.typing_add( sender_user, &body.room_id, duration.as_millis() as u64 + utils::millis_since_unix_epoch(), - &db.globals, )?; } else { - db.rooms + services().rooms .edus - .typing_remove(sender_user, &body.room_id, &db.globals)?; + .typing_remove(sender_user, &body.room_id)?; } Ok(create_typing_event::v3::Response {}) diff --git a/src/api/client_server/user_directory.rs b/src/api/client_server/user_directory.rs index 349c139..60b4e2f 100644 --- a/src/api/client_server/user_directory.rs +++ b/src/api/client_server/user_directory.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Result, Ruma}; +use crate::{Result, Ruma, services}; use ruma::{ api::client::user_directory::search_users, events::{ @@ -14,20 +14,19 @@ use ruma::{ /// - Hides any local users that aren't in any public rooms (i.e. those that have the join rule set to public) /// and don't share a room with the sender pub async fn search_users_route( - db: DatabaseGuard, body: Ruma<search_users::v3::IncomingRequest>, ) -> Result<search_users::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let limit = u64::from(body.limit) as usize; - let mut users = db.users.iter().filter_map(|user_id| { + let mut users = services().users.iter().filter_map(|user_id| { // Filter out buggy users (they should not exist, but you never know...) let user_id = user_id.ok()?; let user = search_users::v3::User { user_id: user_id.clone(), - display_name: db.users.displayname(&user_id).ok()?, - avatar_url: db.users.avatar_url(&user_id).ok()?, + display_name: services().users.displayname(&user_id).ok()?, + avatar_url: services().users.avatar_url(&user_id).ok()?, }; let user_id_matches = user @@ -50,11 +49,11 @@ pub async fn search_users_route( } let user_is_in_public_rooms = - db.rooms + services().rooms .rooms_joined(&user_id) .filter_map(|r| r.ok()) .any(|room| { - db.rooms + services().rooms .room_state_get(&room, &StateEventType::RoomJoinRules, "") .map_or(false, |event| { event.map_or(false, |event| { @@ -70,7 +69,7 @@ pub async fn search_users_route( return Some(user); } - let user_is_in_shared_rooms = db + let user_is_in_shared_rooms = services() .rooms .get_shared_rooms(vec![sender_user.clone(), user_id.clone()]) .ok()? diff --git a/src/api/client_server/voip.rs b/src/api/client_server/voip.rs index 7e9de31..2a804f9 100644 --- a/src/api/client_server/voip.rs +++ b/src/api/client_server/voip.rs @@ -1,4 +1,4 @@ -use crate::{database::DatabaseGuard, Result, Ruma}; +use crate::{Result, Ruma, services}; use hmac::{Hmac, Mac, NewMac}; use ruma::{api::client::voip::get_turn_server_info, SecondsSinceUnixEpoch}; use sha1::Sha1; @@ -10,16 +10,15 @@ type HmacSha1 = Hmac<Sha1>; /// /// TODO: Returns information about the recommended turn server. pub async fn turn_server_route( - db: DatabaseGuard, body: Ruma<get_turn_server_info::v3::IncomingRequest>, ) -> Result<get_turn_server_info::v3::Response> { let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let turn_secret = db.globals.turn_secret(); + let turn_secret = services().globals.turn_secret(); let (username, password) = if !turn_secret.is_empty() { let expiry = SecondsSinceUnixEpoch::from_system_time( - SystemTime::now() + Duration::from_secs(db.globals.turn_ttl()), + SystemTime::now() + Duration::from_secs(services().globals.turn_ttl()), ) .expect("time is valid"); @@ -34,15 +33,15 @@ pub async fn turn_server_route( (username, password) } else { ( - db.globals.turn_username().clone(), - db.globals.turn_password().clone(), + services().globals.turn_username().clone(), + services().globals.turn_password().clone(), ) }; Ok(get_turn_server_info::v3::Response { username, password, - uris: db.globals.turn_uris().to_vec(), - ttl: Duration::from_secs(db.globals.turn_ttl()), + uris: services().globals.turn_uris().to_vec(), + ttl: Duration::from_secs(services().globals.turn_ttl()), }) } diff --git a/src/api/mod.rs b/src/api/mod.rs new file mode 100644 index 0000000..68589be --- /dev/null +++ b/src/api/mod.rs @@ -0,0 +1,4 @@ +pub mod client_server; +pub mod server_server; +pub mod appservice_server; +pub mod ruma_wrapper; diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index 45e9d9a..babf2a7 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -24,7 +24,7 @@ use serde::Deserialize; use tracing::{debug, error, warn}; use super::{Ruma, RumaResponse}; -use crate::{database::DatabaseGuard, server_server, Error, Result}; +use crate::{Error, Result, api::server_server, services}; #[async_trait] impl<T, B> FromRequest<B> for Ruma<T> @@ -44,7 +44,6 @@ where } let metadata = T::METADATA; - let db = DatabaseGuard::from_request(req).await?; let auth_header = Option::<TypedHeader<Authorization<Bearer>>>::from_request(req).await?; let path_params = Path::<Vec<String>>::from_request(req).await?; @@ -71,7 +70,7 @@ where let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok(); - let appservices = db.appservice.all().unwrap(); + let appservices = services().appservice.all().unwrap(); let appservice_registration = appservices.iter().find(|(_id, registration)| { registration .get("as_token") @@ -91,14 +90,14 @@ where .unwrap() .as_str() .unwrap(), - db.globals.server_name(), + services().globals.server_name(), ) .unwrap() }, |s| UserId::parse(s).unwrap(), ); - if !db.users.exists(&user_id).unwrap() { + if !services().users.exists(&user_id).unwrap() { return Err(Error::BadRequest( ErrorKind::Forbidden, "User does not exist.", @@ -124,7 +123,7 @@ where } }; - match db.users.find_from_token(token).unwrap() { + match services().users.find_from_token(token).unwrap() { None => { return Err(Error::BadRequest( ErrorKind::UnknownToken { soft_logout: false }, @@ -185,7 +184,7 @@ where ( "destination".to_owned(), CanonicalJsonValue::String( - db.globals.server_name().as_str().to_owned(), + services().globals.server_name().as_str().to_owned(), ), ), ( @@ -199,7 +198,6 @@ where }; let keys_result = server_server::fetch_signing_keys( - &db, &x_matrix.origin, vec![x_matrix.key.to_owned()], ) @@ -251,7 +249,7 @@ where if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body { let user_id = sender_user.clone().unwrap_or_else(|| { - UserId::parse_with_server_name("", db.globals.server_name()) + UserId::parse_with_server_name("", services().globals.server_name()) .expect("we know this is valid") }); @@ -261,7 +259,7 @@ where .and_then(|auth| auth.get("session")) .and_then(|session| session.as_str()) .and_then(|session| { - db.uiaa.get_uiaa_request( + services().uiaa.get_uiaa_request( &user_id, &sender_device.clone().unwrap_or_else(|| "".into()), session, diff --git a/src/api/server_server.rs b/src/api/server_server.rs index f60f735..776777d 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -1,8 +1,6 @@ use crate::{ - client_server::{self, claim_keys_helper, get_keys_helper}, - database::{rooms::CompressedStateEvent, DatabaseGuard}, - pdu::EventHash, - utils, Database, Error, PduEvent, Result, Ruma, + api::client_server::{self, claim_keys_helper, get_keys_helper}, + utils, Error, PduEvent, Result, Ruma, services, service::pdu::{gen_event_id_canonical_json, PduBuilder}, }; use axum::{response::IntoResponse, Json}; use futures_util::{stream::FuturesUnordered, StreamExt}; @@ -126,22 +124,21 @@ impl FedDest { } } -#[tracing::instrument(skip(globals, request))] +#[tracing::instrument(skip(request))] pub(crate) async fn send_request<T: OutgoingRequest>( - globals: &crate::database::globals::Globals, destination: &ServerName, request: T, ) -> Result<T::IncomingResponse> where T: Debug, { - if !globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } let mut write_destination_to_cache = false; - let cached_result = globals + let cached_result = services().globals .actual_destination_cache .read() .unwrap() @@ -153,7 +150,7 @@ where } else { write_destination_to_cache = true; - let result = find_actual_destination(globals, destination).await; + let result = find_actual_destination(destination).await; (result.0, result.1.into_uri_string()) }; @@ -194,15 +191,15 @@ where .to_string() .into(), ); - request_map.insert("origin".to_owned(), globals.server_name().as_str().into()); + request_map.insert("origin".to_owned(), services().globals.server_name().as_str().into()); request_map.insert("destination".to_owned(), destination.as_str().into()); let mut request_json = serde_json::from_value(request_map.into()).expect("valid JSON is valid BTreeMap"); ruma::signatures::sign_json( - globals.server_name().as_str(), - globals.keypair(), + services().globals.server_name().as_str(), + services().globals.keypair(), &mut request_json, ) .expect("our request json is what ruma expects"); @@ -227,7 +224,7 @@ where AUTHORIZATION, HeaderValue::from_str(&format!( "X-Matrix origin={},key=\"{}\",sig=\"{}\"", - globals.server_name(), + services().globals.server_name(), s.0, s.1 )) @@ -241,7 +238,7 @@ where let url = reqwest_request.url().clone(); - let response = globals.federation_client().execute(reqwest_request).await; + let response = services().globals.federation_client().execute(reqwest_request).await; match response { Ok(mut response) => { @@ -281,7 +278,7 @@ where if status == 200 { let response = T::IncomingResponse::try_from_http_response(http_response); if response.is_ok() && write_destination_to_cache { - globals.actual_destination_cache.write().unwrap().insert( + services().globals.actual_destination_cache.write().unwrap().insert( Box::<ServerName>::from(destination), (actual_destination, host), ); @@ -332,9 +329,7 @@ fn add_port_to_hostname(destination_str: &str) -> FedDest { /// Returns: actual_destination, host header /// Implemented according to the specification at https://matrix.org/docs/spec/server_server/r0.1.4#resolving-server-names /// Numbers in comments below refer to bullet points in linked section of specification -#[tracing::instrument(skip(globals))] async fn find_actual_destination( - globals: &crate::database::globals::Globals, destination: &'_ ServerName, ) -> (FedDest, FedDest) { let destination_str = destination.as_str().to_owned(); @@ -350,7 +345,7 @@ async fn find_actual_destination( let (host, port) = destination_str.split_at(pos); FedDest::Named(host.to_owned(), port.to_owned()) } else { - match request_well_known(globals, destination.as_str()).await { + match request_well_known(destination.as_str()).await { // 3: A .well-known file is available Some(delegated_hostname) => { hostname = add_port_to_hostname(&delegated_hostname).into_uri_string(); @@ -364,17 +359,17 @@ async fn find_actual_destination( } else { // Delegated hostname has no port in this branch if let Some(hostname_override) = - query_srv_record(globals, &delegated_hostname).await + query_srv_record(&delegated_hostname).await { // 3.3: SRV lookup successful let force_port = hostname_override.port(); - if let Ok(override_ip) = globals + if let Ok(override_ip) = services().globals .dns_resolver() .lookup_ip(hostname_override.hostname()) .await { - globals.tls_name_override.write().unwrap().insert( + services().globals.tls_name_override.write().unwrap().insert( delegated_hostname.clone(), ( override_ip.iter().collect(), @@ -400,17 +395,17 @@ async fn find_actual_destination( } // 4: No .well-known or an error occured None => { - match query_srv_record(globals, &destination_str).await { + match query_srv_record(&destination_str).await { // 4: SRV record found Some(hostname_override) => { let force_port = hostname_override.port(); - if let Ok(override_ip) = globals + if let Ok(override_ip) = services().globals .dns_resolver() .lookup_ip(hostname_override.hostname()) .await { - globals.tls_name_override.write().unwrap().insert( + services().globals.tls_name_override.write().unwrap().insert( hostname.clone(), (override_ip.iter().collect(), force_port.unwrap_or(8448)), ); @@ -448,12 +443,10 @@ async fn find_actual_destination( (actual_destination, hostname) } -#[tracing::instrument(skip(globals))] async fn query_srv_record( - globals: &crate::database::globals::Globals, hostname: &'_ str, ) -> Option<FedDest> { - if let Ok(Some(host_port)) = globals + if let Ok(Some(host_port)) = services().globals .dns_resolver() .srv_lookup(format!("_matrix._tcp.{}", hostname)) .await @@ -472,13 +465,11 @@ async fn query_srv_record( } } -#[tracing::instrument(skip(globals))] async fn request_well_known( - globals: &crate::database::globals::Globals, destination: &str, ) -> Option<String> { let body: serde_json::Value = serde_json::from_str( - &globals + &services().globals .default_client() .get(&format!( "https://{}/.well-known/matrix/server", @@ -499,10 +490,9 @@ async fn request_well_known( /// /// Get version information on this server. pub async fn get_server_version_route( - db: DatabaseGuard, _body: Ruma<get_server_version::v1::Request>, ) -> Result<get_server_version::v1::Response> { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } @@ -521,24 +511,24 @@ pub async fn get_server_version_route( /// - Matrix does not support invalidating public keys, so the key returned by this will be valid /// forever. // Response type for this endpoint is Json because we need to calculate a signature for the response -pub async fn get_server_keys_route(db: DatabaseGuard) -> Result<impl IntoResponse> { - if !db.globals.allow_federation() { +pub async fn get_server_keys_route() -> Result<impl IntoResponse> { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } let mut verify_keys: BTreeMap<Box<ServerSigningKeyId>, VerifyKey> = BTreeMap::new(); verify_keys.insert( - format!("ed25519:{}", db.globals.keypair().version()) + format!("ed25519:{}", services().globals.keypair().version()) .try_into() .expect("found invalid server signing keys in DB"), VerifyKey { - key: Base64::new(db.globals.keypair().public_key().to_vec()), + key: Base64::new(services().globals.keypair().public_key().to_vec()), }, ); let mut response = serde_json::from_slice( get_server_keys::v2::Response { server_key: Raw::new(&ServerSigningKeys { - server_name: db.globals.server_name().to_owned(), + server_name: services().globals.server_name().to_owned(), verify_keys, old_verify_keys: BTreeMap::new(), signatures: BTreeMap::new(), @@ -556,8 +546,8 @@ pub async fn get_server_keys_route(db: DatabaseGuard) -> Result<impl IntoRespons .unwrap(); ruma::signatures::sign_json( - db.globals.server_name().as_str(), - db.globals.keypair(), + services().globals.server_name().as_str(), + services().globals.keypair(), &mut response, ) .unwrap(); @@ -571,23 +561,21 @@ pub async fn get_server_keys_route(db: DatabaseGuard) -> Result<impl IntoRespons /// /// - Matrix does not support invalidating public keys, so the key returned by this will be valid /// forever. -pub async fn get_server_keys_deprecated_route(db: DatabaseGuard) -> impl IntoResponse { - get_server_keys_route(db).await +pub async fn get_server_keys_deprecated_route() -> impl IntoResponse { + get_server_keys_route().await } /// # `POST /_matrix/federation/v1/publicRooms` /// /// Lists the public rooms on this server. pub async fn get_public_rooms_filtered_route( - db: DatabaseGuard, body: Ruma<get_public_rooms_filtered::v1::IncomingRequest>, ) -> Result<get_public_rooms_filtered::v1::Response> { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } let response = client_server::get_public_rooms_filtered_helper( - &db, None, body.limit, body.since.as_deref(), @@ -608,15 +596,13 @@ pub async fn get_public_rooms_filtered_route( /// /// Lists the public rooms on this server. pub async fn get_public_rooms_route( - db: DatabaseGuard, body: Ruma<get_public_rooms::v1::IncomingRequest>, ) -> Result<get_public_rooms::v1::Response> { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } let response = client_server::get_public_rooms_filtered_helper( - &db, None, body.limit, body.since.as_deref(), @@ -637,10 +623,9 @@ pub async fn get_public_rooms_route( /// /// Push EDUs and PDUs to this server. pub async fn send_transaction_message_route( - db: DatabaseGuard, body: Ruma<send_transaction_message::v1::IncomingRequest>, ) -> Result<send_transaction_message::v1::Response> { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } @@ -663,7 +648,7 @@ pub async fn send_transaction_message_route( for pdu in &body.pdus { // We do not add the event_id field to the pdu here because of signature and hashes checks - let (event_id, value) = match crate::pdu::gen_event_id_canonical_json(pdu, &db) { + let (event_id, value) = match gen_event_id_canonical_json(pdu) { Ok(t) => t, Err(_) => { // Event could not be converted to canonical json @@ -684,10 +669,10 @@ pub async fn send_transaction_message_route( } }; - acl_check(&sender_servername, &room_id, &db)?; + acl_check(&sender_servername, &room_id)?; let mutex = Arc::clone( - db.globals + services().globals .roomid_mutex_federation .write() .unwrap() @@ -698,13 +683,12 @@ pub async fn send_transaction_message_route( let start_time = Instant::now(); resolved_map.insert( event_id.clone(), - handle_incoming_pdu( + services().rooms.event_handler.handle_incoming_pdu( &sender_servername, &event_id, &room_id, value, true, - &db, &pub_key_map, ) .await @@ -743,7 +727,7 @@ pub async fn send_transaction_message_route( .event_ids .iter() .filter_map(|id| { - db.rooms.get_pdu_count(id).ok().flatten().map(|r| (id, r)) + services().rooms.get_pdu_count(id).ok().flatten().map(|r| (id, r)) }) .max_by_key(|(_, count)| *count) { @@ -760,11 +744,10 @@ pub async fn send_transaction_message_route( content: ReceiptEventContent(receipt_content), room_id: room_id.clone(), }; - db.rooms.edus.readreceipt_update( + services().rooms.edus.readreceipt_update( &user_id, &room_id, event, - &db.globals, )?; } else { // TODO fetch missing events @@ -774,26 +757,24 @@ pub async fn send_transaction_message_route( } } Edu::Typing(typing) => { - if db.rooms.is_joined(&typing.user_id, &typing.room_id)? { + if services().rooms.is_joined(&typing.user_id, &typing.room_id)? { if typing.typing { - db.rooms.edus.typing_add( + services().rooms.edus.typing_add( &typing.user_id, &typing.room_id, 3000 + utils::millis_since_unix_epoch(), - &db.globals, )?; } else { - db.rooms.edus.typing_remove( + services().rooms.edus.typing_remove( &typing.user_id, &typing.room_id, - &db.globals, )?; } } } Edu::DeviceListUpdate(DeviceListUpdateContent { user_id, .. }) => { - db.users - .mark_device_key_update(&user_id, &db.rooms, &db.globals)?; + services().users + .mark_device_key_update(&user_id)?; } Edu::DirectToDevice(DirectDeviceContent { sender, @@ -802,7 +783,7 @@ pub async fn send_transaction_message_route( messages, }) => { // Check if this is a new transaction id - if db + if services() .transaction_ids .existing_txnid(&sender, None, &message_id)? .is_some() @@ -814,7 +795,7 @@ pub async fn send_transaction_message_route( for (target_device_id_maybe, event) in map { match target_device_id_maybe { DeviceIdOrAllDevices::DeviceId(target_device_id) => { - db.users.add_to_device_event( + services().users.add_to_device_event( &sender, target_user_id, target_device_id, @@ -825,13 +806,12 @@ pub async fn send_transaction_message_route( "Event is invalid", ) })?, - &db.globals, )? } DeviceIdOrAllDevices::AllDevices => { - for target_device_id in db.users.all_device_ids(target_user_id) { - db.users.add_to_device_event( + for target_device_id in services().users.all_device_ids(target_user_id) { + services().users.add_to_device_event( &sender, target_user_id, &target_device_id?, @@ -842,7 +822,6 @@ pub async fn send_transaction_message_route( "Event is invalid", ) })?, - &db.globals, )?; } } @@ -851,7 +830,7 @@ pub async fn send_transaction_message_route( } // Save transaction id with empty data - db.transaction_ids + services().transaction_ids .add_txnid(&sender, None, &message_id, &[])?; } Edu::SigningKeyUpdate(SigningKeyUpdateContent { @@ -863,13 +842,11 @@ pub async fn send_transaction_message_route( continue; } if let Some(master_key) = master_key { - db.users.add_cross_signing_keys( + services().users.add_cross_signing_keys( &user_id, &master_key, &self_signing_key, &None, - &db.rooms, - &db.globals, )?; } } @@ -877,8 +854,6 @@ pub async fn send_transaction_message_route( } } - db.flush()?; - Ok(send_transaction_message::v1::Response { pdus: resolved_map }) } @@ -886,14 +861,13 @@ pub async fn send_transaction_message_route( /// fetch them from the server and save to our DB. #[tracing::instrument(skip_all)] pub(crate) async fn fetch_signing_keys( - db: &Database, origin: &ServerName, signature_ids: Vec<String>, ) -> Result<BTreeMap<String, Base64>> { let contains_all_ids = |keys: &BTreeMap<String, Base64>| signature_ids.iter().all(|id| keys.contains_key(id)); - let permit = db + let permit = services() .globals .servername_ratelimiter .read() @@ -904,7 +878,7 @@ pub(crate) async fn fetch_signing_keys( let permit = match permit { Some(p) => p, None => { - let mut write = db.globals.servername_ratelimiter.write().unwrap(); + let mut write = services().globals.servername_ratelimiter.write().unwrap(); let s = Arc::clone( write .entry(origin.to_owned()) @@ -916,7 +890,7 @@ pub(crate) async fn fetch_signing_keys( } .await; - let back_off = |id| match db + let back_off = |id| match services() .globals .bad_signature_ratelimiter .write() @@ -929,7 +903,7 @@ pub(crate) async fn fetch_signing_keys( hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), }; - if let Some((time, tries)) = db + if let Some((time, tries)) = services() .globals .bad_signature_ratelimiter .read() @@ -950,7 +924,7 @@ pub(crate) async fn fetch_signing_keys( trace!("Loading signing keys for {}", origin); - let mut result: BTreeMap<_, _> = db + let mut result: BTreeMap<_, _> = services() .globals .signing_keys_for(origin)? .into_iter() @@ -963,14 +937,14 @@ pub(crate) async fn fetch_signing_keys( debug!("Fetching signing keys for {} over federation", origin); - if let Some(server_key) = db + if let Some(server_key) = services() .sending - .send_federation_request(&db.globals, origin, get_server_keys::v2::Request::new()) + .send_federation_request(origin, get_server_keys::v2::Request::new()) .await .ok() .and_then(|resp| resp.server_key.deserialize().ok()) { - db.globals.add_signing_key(origin, server_key.clone())?; + services().globals.add_signing_key(origin, server_key.clone())?; result.extend( server_key @@ -990,12 +964,11 @@ pub(crate) async fn fetch_signing_keys( } } - for server in db.globals.trusted_servers() { + for server in services().globals.trusted_servers() { debug!("Asking {} for {}'s signing key", server, origin); - if let Some(server_keys) = db + if let Some(server_keys) = services() .sending .send_federation_request( - &db.globals, server, get_remote_server_keys::v2::Request::new( origin, @@ -1018,7 +991,7 @@ pub(crate) async fn fetch_signing_keys( { trace!("Got signing keys: {:?}", server_keys); for k in server_keys { - db.globals.add_signing_key(origin, k.clone())?; + services().globals.add_signing_key(origin, k.clone())?; result.extend( k.verify_keys .into_iter() @@ -1047,11 +1020,10 @@ pub(crate) async fn fetch_signing_keys( )) } -#[tracing::instrument(skip(starting_events, db))] +#[tracing::instrument(skip(starting_events))] pub(crate) async fn get_auth_chain<'a>( room_id: &RoomId, starting_events: Vec<Arc<EventId>>, - db: &'a Database, ) -> Result<impl Iterator<Item = Arc<EventId>> + 'a> { const NUM_BUCKETS: usize = 50; @@ -1059,7 +1031,7 @@ pub(crate) async fn get_auth_chain<'a>( let mut i = 0; for id in starting_events { - let short = db.rooms.get_or_create_shorteventid(&id, &db.globals)?; + let short = services().rooms.get_or_create_shorteventid(&id)?; let bucket_id = (short % NUM_BUCKETS as u64) as usize; buckets[bucket_id].insert((short, id.clone())); i += 1; @@ -1078,7 +1050,7 @@ pub(crate) async fn get_auth_chain<'a>( } let chunk_key: Vec<u64> = chunk.iter().map(|(short, _)| short).copied().collect(); - if let Some(cached) = db.rooms.get_auth_chain_from_cache(&chunk_key)? { + if let Some(cached) = services().rooms.get_auth_chain_from_cache(&chunk_key)? { hits += 1; full_auth_chain.extend(cached.iter().copied()); continue; @@ -1090,13 +1062,13 @@ pub(crate) async fn get_auth_chain<'a>( let mut misses2 = 0; let mut i = 0; for (sevent_id, event_id) in chunk { - if let Some(cached) = db.rooms.get_auth_chain_from_cache(&[sevent_id])? { + if let Some(cached) = services().rooms.get_auth_chain_from_cache(&[sevent_id])? { hits2 += 1; chunk_cache.extend(cached.iter().copied()); } else { misses2 += 1; - let auth_chain = Arc::new(get_auth_chain_inner(room_id, &event_id, db)?); - db.rooms + let auth_chain = Arc::new(get_auth_chain_inner(room_id, &event_id)?); + services().rooms .cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?; println!( "cache missed event {} with auth chain len {}", @@ -1118,7 +1090,7 @@ pub(crate) async fn get_auth_chain<'a>( misses2 ); let chunk_cache = Arc::new(chunk_cache); - db.rooms + services().rooms .cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?; full_auth_chain.extend(chunk_cache.iter()); } @@ -1132,28 +1104,27 @@ pub(crate) async fn get_auth_chain<'a>( Ok(full_auth_chain .into_iter() - .filter_map(move |sid| db.rooms.get_eventid_from_short(sid).ok())) + .filter_map(move |sid| services().rooms.get_eventid_from_short(sid).ok())) } -#[tracing::instrument(skip(event_id, db))] +#[tracing::instrument(skip(event_id))] fn get_auth_chain_inner( room_id: &RoomId, event_id: &EventId, - db: &Database, ) -> Result<HashSet<u64>> { let mut todo = vec![Arc::from(event_id)]; let mut found = HashSet::new(); while let Some(event_id) = todo.pop() { - match db.rooms.get_pdu(&event_id) { + match services().rooms.get_pdu(&event_id) { Ok(Some(pdu)) => { if pdu.room_id != room_id { return Err(Error::BadRequest(ErrorKind::Forbidden, "Evil event in db")); } for auth_event in &pdu.auth_events { - let sauthevent = db + let sauthevent = services() .rooms - .get_or_create_shorteventid(auth_event, &db.globals)?; + .get_or_create_shorteventid(auth_event)?; if !found.contains(&sauthevent) { found.insert(sauthevent); @@ -1179,10 +1150,9 @@ fn get_auth_chain_inner( /// /// - Only works if a user of this server is currently invited or joined the room pub async fn get_event_route( - db: DatabaseGuard, body: Ruma<get_event::v1::IncomingRequest>, ) -> Result<get_event::v1::Response> { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } @@ -1191,7 +1161,7 @@ pub async fn get_event_route( .as_ref() .expect("server is authenticated"); - let event = db + let event = services() .rooms .get_pdu_json(&body.event_id)? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Event not found."))?; @@ -1204,7 +1174,7 @@ pub async fn get_event_route( let room_id = <&RoomId>::try_from(room_id_str) .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; - if !db.rooms.server_in_room(sender_servername, room_id)? { + if !services().rooms.server_in_room(sender_servername, room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "Server is not in room", @@ -1212,7 +1182,7 @@ pub async fn get_event_route( } Ok(get_event::v1::Response { - origin: db.globals.server_name().to_owned(), + origin: services().globals.server_name().to_owned(), origin_server_ts: MilliSecondsSinceUnixEpoch::now(), pdu: PduEvent::convert_to_outgoing_federation_event(event), }) @@ -1222,10 +1192,9 @@ pub async fn get_event_route( /// /// Retrieves events that the sender is missing. pub async fn get_missing_events_route( - db: DatabaseGuard, body: Ruma<get_missing_events::v1::IncomingRequest>, ) -> Result<get_missing_events::v1::Response> { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } @@ -1234,21 +1203,21 @@ pub async fn get_missing_events_route( .as_ref() .expect("server is authenticated"); - if !db.rooms.server_in_room(sender_servername, &body.room_id)? { + if !services().rooms.server_in_room(sender_servername, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "Server is not in room", )); } - acl_check(sender_servername, &body.room_id, &db)?; + acl_check(sender_servername, &body.room_id)?; let mut queued_events = body.latest_events.clone(); let mut events = Vec::new(); let mut i = 0; while i < queued_events.len() && events.len() < u64::from(body.limit) as usize { - if let Some(pdu) = db.rooms.get_pdu_json(&queued_events[i])? { + if let Some(pdu) = services().rooms.get_pdu_json(&queued_events[i])? { let room_id_str = pdu .get("room_id") .and_then(|val| val.as_str()) @@ -1295,10 +1264,9 @@ pub async fn get_missing_events_route( /// /// - This does not include the event itself pub async fn get_event_authorization_route( - db: DatabaseGuard, body: Ruma<get_event_authorization::v1::IncomingRequest>, ) -> Result<get_event_authorization::v1::Response> { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } @@ -1307,16 +1275,16 @@ pub async fn get_event_authorization_route( .as_ref() .expect("server is authenticated"); - if !db.rooms.server_in_room(sender_servername, &body.room_id)? { + if !services().rooms.server_in_room(sender_servername, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "Server is not in room.", )); } - acl_check(sender_servername, &body.room_id, &db)?; + acl_check(sender_servername, &body.room_id)?; - let event = db + let event = services() .rooms .get_pdu_json(&body.event_id)? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Event not found."))?; @@ -1329,11 +1297,11 @@ pub async fn get_event_authorization_route( let room_id = <&RoomId>::try_from(room_id_str) .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; - let auth_chain_ids = get_auth_chain(room_id, vec![Arc::from(&*body.event_id)], &db).await?; + let auth_chain_ids = get_auth_chain(room_id, vec![Arc::from(&*body.event_id)]).await?; Ok(get_event_authorization::v1::Response { auth_chain: auth_chain_ids - .filter_map(|id| db.rooms.get_pdu_json(&id).ok()?) + .filter_map(|id| services().rooms.get_pdu_json(&id).ok()?) .map(PduEvent::convert_to_outgoing_federation_event) .collect(), }) @@ -1343,10 +1311,9 @@ pub async fn get_event_authorization_route( /// /// Retrieves the current state of the room. pub async fn get_room_state_route( - db: DatabaseGuard, body: Ruma<get_room_state::v1::IncomingRequest>, ) -> Result<get_room_state::v1::Response> { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } @@ -1355,16 +1322,16 @@ pub async fn get_room_state_route( .as_ref() .expect("server is authenticated"); - if !db.rooms.server_in_room(sender_servername, &body.room_id)? { + if !services().rooms.server_in_room(sender_servername, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "Server is not in room.", )); } - acl_check(sender_servername, &body.room_id, &db)?; + acl_check(sender_servername, &body.room_id)?; - let shortstatehash = db + let shortstatehash = services() .rooms .pdu_shortstatehash(&body.event_id)? .ok_or(Error::BadRequest( @@ -1372,25 +1339,25 @@ pub async fn get_room_state_route( "Pdu state not found.", ))?; - let pdus = db + let pdus = services() .rooms .state_full_ids(shortstatehash) .await? .into_iter() .map(|(_, id)| { PduEvent::convert_to_outgoing_federation_event( - db.rooms.get_pdu_json(&id).unwrap().unwrap(), + services().rooms.get_pdu_json(&id).unwrap().unwrap(), ) }) .collect(); let auth_chain_ids = - get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)], &db).await?; + get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)]).await?; Ok(get_room_state::v1::Response { auth_chain: auth_chain_ids .map(|id| { - db.rooms.get_pdu_json(&id).map(|maybe_json| { + services().rooms.get_pdu_json(&id).map(|maybe_json| { PduEvent::convert_to_outgoing_federation_event(maybe_json.unwrap()) }) }) @@ -1404,10 +1371,9 @@ pub async fn get_room_state_route( /// /// Retrieves the current state of the room. pub async fn get_room_state_ids_route( - db: DatabaseGuard, body: Ruma<get_room_state_ids::v1::IncomingRequest>, ) -> Result<get_room_state_ids::v1::Response> { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } @@ -1416,16 +1382,16 @@ pub async fn get_room_state_ids_route( .as_ref() .expect("server is authenticated"); - if !db.rooms.server_in_room(sender_servername, &body.room_id)? { + if !services().rooms.server_in_room(sender_servername, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "Server is not in room.", )); } - acl_check(sender_servername, &body.room_id, &db)?; + acl_check(sender_servername, &body.room_id)?; - let shortstatehash = db + let shortstatehash = services() .rooms .pdu_shortstatehash(&body.event_id)? .ok_or(Error::BadRequest( @@ -1433,7 +1399,7 @@ pub async fn get_room_state_ids_route( "Pdu state not found.", ))?; - let pdu_ids = db + let pdu_ids = services() .rooms .state_full_ids(shortstatehash) .await? @@ -1442,7 +1408,7 @@ pub async fn get_room_state_ids_route( .collect(); let auth_chain_ids = - get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)], &db).await?; + get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)]).await?; Ok(get_room_state_ids::v1::Response { auth_chain_ids: auth_chain_ids.map(|id| (*id).to_owned()).collect(), @@ -1454,14 +1420,13 @@ pub async fn get_room_state_ids_route( /// /// Creates a join template. pub async fn create_join_event_template_route( - db: DatabaseGuard, body: Ruma<prepare_join_event::v1::IncomingRequest>, ) -> Result<prepare_join_event::v1::Response> { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } - if !db.rooms.exists(&body.room_id)? { + if !services().rooms.exists(&body.room_id)? { return Err(Error::BadRequest( ErrorKind::NotFound, "Room is unknown to this server.", @@ -1473,11 +1438,21 @@ pub async fn create_join_event_template_route( .as_ref() .expect("server is authenticated"); - acl_check(sender_servername, &body.room_id, &db)?; + acl_check(sender_servername, &body.room_id)?; + + let mutex_state = Arc::clone( + services().globals + .roomid_mutex_state + .write() + .unwrap() + .entry(body.room_id.to_owned()) + .or_default(), + ); + let state_lock = mutex_state.lock().await; // TODO: Conduit does not implement restricted join rules yet, we always reject let join_rules_event = - db.rooms + services().rooms .room_state_get(&body.room_id, &StateEventType::RoomJoinRules, "")?; let join_rules_event_content: Option<RoomJoinRulesEventContent> = join_rules_event @@ -1502,7 +1477,8 @@ pub async fn create_join_event_template_route( } } - if !body.ver.contains(&room_version_id) { + let room_version_id = services().rooms.state.get_room_version(&body.room_id); + if !body.ver.contains(room_version_id) { return Err(Error::BadRequest( ErrorKind::IncompatibleRoomVersion { room_version: room_version_id, @@ -1523,10 +1499,15 @@ pub async fn create_join_event_template_route( }) .expect("member event is valid value"); - let state_key = body.user_id.to_string(); - let kind = StateEventType::RoomMember; + let (pdu, pdu_json) = services().rooms.timeline.create_hash_and_sign_event(PduBuilder { + event_type: RoomEventType::RoomMember, + content, + unsigned: None, + state_key: Some(body.user_id.to_string()), + redacts: None, + }, &body.user_id, &body.room_id, &state_lock); - let (pdu, pdu_json) = create_hash_and_sign_event(); + drop(state_lock); Ok(prepare_join_event::v1::Response { room_version: Some(room_version_id), @@ -1535,26 +1516,25 @@ pub async fn create_join_event_template_route( } async fn create_join_event( - db: &DatabaseGuard, sender_servername: &ServerName, room_id: &RoomId, pdu: &RawJsonValue, ) -> Result<RoomState> { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } - if !db.rooms.exists(room_id)? { + if !services().rooms.exists(room_id)? { return Err(Error::BadRequest( ErrorKind::NotFound, "Room is unknown to this server.", )); } - acl_check(sender_servername, room_id, db)?; + acl_check(sender_servername, room_id)?; // TODO: Conduit does not implement restricted join rules yet, we always reject - let join_rules_event = db + let join_rules_event = services() .rooms .room_state_get(room_id, &StateEventType::RoomJoinRules, "")?; @@ -1581,7 +1561,7 @@ async fn create_join_event( } // We need to return the state prior to joining, let's keep a reference to that here - let shortstatehash = db + let shortstatehash = services() .rooms .current_shortstatehash(room_id)? .ok_or(Error::BadRequest( @@ -1593,7 +1573,7 @@ async fn create_join_event( // let mut auth_cache = EventMap::new(); // We do not add the event_id field to the pdu here because of signature and hashes checks - let (event_id, value) = match crate::pdu::gen_event_id_canonical_json(pdu, &db) { + let (event_id, value) = match gen_event_id_canonical_json(pdu) { Ok(t) => t, Err(_) => { // Event could not be converted to canonical json @@ -1614,7 +1594,7 @@ async fn create_join_event( .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; let mutex = Arc::clone( - db.globals + services().globals .roomid_mutex_federation .write() .unwrap() @@ -1622,7 +1602,7 @@ async fn create_join_event( .or_default(), ); let mutex_lock = mutex.lock().await; - let pdu_id = handle_incoming_pdu(&origin, &event_id, room_id, value, true, db, &pub_key_map) + let pdu_id = services().rooms.event_handler.handle_incoming_pdu(&origin, &event_id, room_id, value, true, &pub_key_map) .await .map_err(|e| { warn!("Error while handling incoming send join PDU: {}", e); @@ -1637,32 +1617,29 @@ async fn create_join_event( ))?; drop(mutex_lock); - let state_ids = db.rooms.state_full_ids(shortstatehash).await?; + let state_ids = services().rooms.state_full_ids(shortstatehash).await?; let auth_chain_ids = get_auth_chain( room_id, state_ids.iter().map(|(_, id)| id.clone()).collect(), - db, ) .await?; - let servers = db + let servers = services() .rooms .room_servers(room_id) .filter_map(|r| r.ok()) - .filter(|server| &**server != db.globals.server_name()); - - db.sending.send_pdu(servers, &pdu_id)?; + .filter(|server| &**server != services().globals.server_name()); - db.flush()?; + services().sending.send_pdu(servers, &pdu_id)?; Ok(RoomState { auth_chain: auth_chain_ids - .filter_map(|id| db.rooms.get_pdu_json(&id).ok().flatten()) + .filter_map(|id| services().rooms.get_pdu_json(&id).ok().flatten()) .map(PduEvent::convert_to_outgoing_federation_event) .collect(), state: state_ids .iter() - .filter_map(|(_, id)| db.rooms.get_pdu_json(id).ok().flatten()) + .filter_map(|(_, id)| services().rooms.get_pdu_json(id).ok().flatten()) .map(PduEvent::convert_to_outgoing_federation_event) .collect(), }) @@ -1672,7 +1649,6 @@ async fn create_join_event( /// /// Submits a signed join event. pub async fn create_join_event_v1_route( - db: DatabaseGuard, body: Ruma<create_join_event::v1::IncomingRequest>, ) -> Result<create_join_event::v1::Response> { let sender_servername = body @@ -1680,7 +1656,7 @@ pub async fn create_join_event_v1_route( .as_ref() .expect("server is authenticated"); - let room_state = create_join_event(&db, sender_servername, &body.room_id, &body.pdu).await?; + let room_state = create_join_event(sender_servername, &body.room_id, &body.pdu).await?; Ok(create_join_event::v1::Response { room_state }) } @@ -1689,7 +1665,6 @@ pub async fn create_join_event_v1_route( /// /// Submits a signed join event. pub async fn create_join_event_v2_route( - db: DatabaseGuard, body: Ruma<create_join_event::v2::IncomingRequest>, ) -> Result<create_join_event::v2::Response> { let sender_servername = body @@ -1697,7 +1672,7 @@ pub async fn create_join_event_v2_route( .as_ref() .expect("server is authenticated"); - let room_state = create_join_event(&db, sender_servername, &body.room_id, &body.pdu).await?; + let room_state = create_join_event(sender_servername, &body.room_id, &body.pdu).await?; Ok(create_join_event::v2::Response { room_state }) } @@ -1706,10 +1681,9 @@ pub async fn create_join_event_v2_route( /// /// Invites a remote user to a room. pub async fn create_invite_route( - db: DatabaseGuard, body: Ruma<create_invite::v2::IncomingRequest>, ) -> Result<create_invite::v2::Response> { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } @@ -1718,9 +1692,9 @@ pub async fn create_invite_route( .as_ref() .expect("server is authenticated"); - acl_check(sender_servername, &body.room_id, &db)?; + acl_check(sender_servername, &body.room_id)?; - if !db.rooms.is_supported_version(&db, &body.room_version) { + if !services().rooms.is_supported_version(&body.room_version) { return Err(Error::BadRequest( ErrorKind::IncompatibleRoomVersion { room_version: body.room_version.clone(), @@ -1733,8 +1707,8 @@ pub async fn create_invite_route( .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invite event is invalid."))?; ruma::signatures::hash_and_sign_event( - db.globals.server_name().as_str(), - db.globals.keypair(), + services().globals.server_name().as_str(), + services().globals.keypair(), &mut signed_event, &body.room_version, ) @@ -1793,20 +1767,17 @@ pub async fn create_invite_route( invite_state.push(pdu.to_stripped_state_event()); // If the room already exists, the remote server will notify us about the join via /send - if !db.rooms.exists(&pdu.room_id)? { - db.rooms.update_membership( + if !services().rooms.exists(&pdu.room_id)? { + services().rooms.update_membership( &body.room_id, &invited_user, MembershipState::Invite, &sender, Some(invite_state), - &db, true, )?; } - db.flush()?; - Ok(create_invite::v2::Response { event: PduEvent::convert_to_outgoing_federation_event(signed_event), }) @@ -1816,10 +1787,9 @@ pub async fn create_invite_route( /// /// Gets information on all devices of the user. pub async fn get_devices_route( - db: DatabaseGuard, body: Ruma<get_devices::v1::IncomingRequest>, ) -> Result<get_devices::v1::Response> { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } @@ -1830,19 +1800,19 @@ pub async fn get_devices_route( Ok(get_devices::v1::Response { user_id: body.user_id.clone(), - stream_id: db + stream_id: services() .users .get_devicelist_version(&body.user_id)? .unwrap_or(0) .try_into() .expect("version will not grow that large"), - devices: db + devices: services() .users .all_devices_metadata(&body.user_id) .filter_map(|r| r.ok()) .filter_map(|metadata| { Some(UserDevice { - keys: db + keys: services() .users .get_device_keys(&body.user_id, &metadata.device_id) .ok()??, @@ -1851,10 +1821,10 @@ pub async fn get_devices_route( }) }) .collect(), - master_key: db + master_key: services() .users .get_master_key(&body.user_id, |u| u.server_name() == sender_servername)?, - self_signing_key: db + self_signing_key: services() .users .get_self_signing_key(&body.user_id, |u| u.server_name() == sender_servername)?, }) @@ -1864,14 +1834,13 @@ pub async fn get_devices_route( /// /// Resolve a room alias to a room id. pub async fn get_room_information_route( - db: DatabaseGuard, body: Ruma<get_room_information::v1::IncomingRequest>, ) -> Result<get_room_information::v1::Response> { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } - let room_id = db + let room_id = services() .rooms .id_from_alias(&body.room_alias)? .ok_or(Error::BadRequest( @@ -1881,7 +1850,7 @@ pub async fn get_room_information_route( Ok(get_room_information::v1::Response { room_id, - servers: vec![db.globals.server_name().to_owned()], + servers: vec![services().globals.server_name().to_owned()], }) } @@ -1889,10 +1858,9 @@ pub async fn get_room_information_route( /// /// Gets information on a profile. pub async fn get_profile_information_route( - db: DatabaseGuard, body: Ruma<get_profile_information::v1::IncomingRequest>, ) -> Result<get_profile_information::v1::Response> { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } @@ -1901,17 +1869,17 @@ pub async fn get_profile_information_route( let mut blurhash = None; match &body.field { - Some(ProfileField::DisplayName) => displayname = db.users.displayname(&body.user_id)?, + Some(ProfileField::DisplayName) => displayname = services().users.displayname(&body.user_id)?, Some(ProfileField::AvatarUrl) => { - avatar_url = db.users.avatar_url(&body.user_id)?; - blurhash = db.users.blurhash(&body.user_id)? + avatar_url = services().users.avatar_url(&body.user_id)?; + blurhash = services().users.blurhash(&body.user_id)? } // TODO: what to do with custom Some(_) => {} None => { - displayname = db.users.displayname(&body.user_id)?; - avatar_url = db.users.avatar_url(&body.user_id)?; - blurhash = db.users.blurhash(&body.user_id)?; + displayname = services().users.displayname(&body.user_id)?; + avatar_url = services().users.avatar_url(&body.user_id)?; + blurhash = services().users.blurhash(&body.user_id)?; } } @@ -1926,10 +1894,9 @@ pub async fn get_profile_information_route( /// /// Gets devices and identity keys for the given users. pub async fn get_keys_route( - db: DatabaseGuard, body: Ruma<get_keys::v1::Request>, ) -> Result<get_keys::v1::Response> { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } @@ -1937,12 +1904,9 @@ pub async fn get_keys_route( None, &body.device_keys, |u| Some(u.server_name()) == body.sender_servername.as_deref(), - &db, ) .await?; - db.flush()?; - Ok(get_keys::v1::Response { device_keys: result.device_keys, master_keys: result.master_keys, @@ -1954,16 +1918,13 @@ pub async fn get_keys_route( /// /// Claims one-time keys. pub async fn claim_keys_route( - db: DatabaseGuard, body: Ruma<claim_keys::v1::Request>, ) -> Result<claim_keys::v1::Response> { - if !db.globals.allow_federation() { + if !services().globals.allow_federation() { return Err(Error::bad_config("Federation is disabled.")); } - let result = claim_keys_helper(&body.one_time_keys, &db).await?; - - db.flush()?; + let result = claim_keys_helper(&body.one_time_keys).await?; Ok(claim_keys::v1::Response { one_time_keys: result.one_time_keys, @@ -1974,7 +1935,6 @@ pub async fn claim_keys_route( pub(crate) async fn fetch_required_signing_keys( event: &BTreeMap<String, CanonicalJsonValue>, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, - db: &Database, ) -> Result<()> { let signatures = event .get("signatures") @@ -1996,7 +1956,6 @@ pub(crate) async fn fetch_required_signing_keys( let signature_ids = signature_object.keys().cloned().collect::<Vec<_>>(); let fetch_res = fetch_signing_keys( - db, signature_server.as_str().try_into().map_err(|_| { Error::BadServerResponse("Invalid servername in signatures of server response pdu.") })?, @@ -2028,7 +1987,6 @@ fn get_server_keys_from_cache( servers: &mut BTreeMap<Box<ServerName>, BTreeMap<Box<ServerSigningKeyId>, QueryCriteria>>, room_version: &RoomVersionId, pub_key_map: &mut RwLockWriteGuard<'_, BTreeMap<String, BTreeMap<String, Base64>>>, - db: &Database, ) -> Result<()> { let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); @@ -2043,7 +2001,7 @@ fn get_server_keys_from_cache( let event_id = <&EventId>::try_from(event_id.as_str()) .expect("ruma's reference hashes are valid event ids"); - if let Some((time, tries)) = db + if let Some((time, tries)) = services() .globals .bad_event_ratelimiter .read() @@ -2092,7 +2050,7 @@ fn get_server_keys_from_cache( trace!("Loading signing keys for {}", origin); - let result: BTreeMap<_, _> = db + let result: BTreeMap<_, _> = services() .globals .signing_keys_for(origin)? .into_iter() @@ -2114,7 +2072,6 @@ pub(crate) async fn fetch_join_signing_keys( event: &create_join_event::v2::Response, room_version: &RoomVersionId, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, - db: &Database, ) -> Result<()> { let mut servers: BTreeMap<Box<ServerName>, BTreeMap<Box<ServerSigningKeyId>, QueryCriteria>> = BTreeMap::new(); @@ -2127,10 +2084,10 @@ pub(crate) async fn fetch_join_signing_keys( // Try to fetch keys, failure is okay // Servers we couldn't find in the cache will be added to `servers` for pdu in &event.room_state.state { - let _ = get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm, db); + let _ = get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm); } for pdu in &event.room_state.auth_chain { - let _ = get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm, db); + let _ = get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm); } drop(pkm); @@ -2141,12 +2098,11 @@ pub(crate) async fn fetch_join_signing_keys( return Ok(()); } - for server in db.globals.trusted_servers() { + for server in services().globals.trusted_servers() { trace!("Asking batch signing keys from trusted server {}", server); - if let Ok(keys) = db + if let Ok(keys) = services() .sending .send_federation_request( - &db.globals, server, get_remote_server_keys_batch::v2::Request { server_keys: servers.clone(), @@ -2164,7 +2120,7 @@ pub(crate) async fn fetch_join_signing_keys( // TODO: Check signature from trusted server? servers.remove(&k.server_name); - let result = db + let result = services() .globals .add_signing_key(&k.server_name, k.clone())? .into_iter() @@ -2184,9 +2140,8 @@ pub(crate) async fn fetch_join_signing_keys( .into_iter() .map(|(server, _)| async move { ( - db.sending + services().sending .send_federation_request( - &db.globals, &server, get_server_keys::v2::Request::new(), ) @@ -2198,7 +2153,7 @@ pub(crate) async fn fetch_join_signing_keys( while let Some(result) = futures.next().await { if let (Ok(get_keys_response), origin) = result { - let result: BTreeMap<_, _> = db + let result: BTreeMap<_, _> = services() .globals .add_signing_key(&origin, get_keys_response.server_key.deserialize().unwrap())? .into_iter() @@ -2216,8 +2171,8 @@ pub(crate) async fn fetch_join_signing_keys( } /// Returns Ok if the acl allows the server -fn acl_check(server_name: &ServerName, room_id: &RoomId, db: &Database) -> Result<()> { - let acl_event = match db +fn acl_check(server_name: &ServerName, room_id: &RoomId) -> Result<()> { + let acl_event = match services() .rooms .room_state_get(room_id, &StateEventType::RoomServerAcl, "")? { |