diff options
Diffstat (limited to 'src/database/mod.rs')
-rw-r--r-- | src/database/mod.rs | 147 |
1 files changed, 136 insertions, 11 deletions
diff --git a/src/database/mod.rs b/src/database/mod.rs index 78bb358..e247d9f 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,7 +1,10 @@ pub mod abstraction; pub mod key_value; -use crate::{services, utils, Config, Error, PduEvent, Result, Services, SERVICES}; +use crate::{ + service::rooms::timeline::PduCount, services, utils, Config, Error, PduEvent, Result, Services, + SERVICES, +}; use abstraction::{KeyValueDatabaseEngine, KvTree}; use directories::ProjectDirs; use lru_cache::LruCache; @@ -15,6 +18,7 @@ use ruma::{ CanonicalJsonValue, EventId, OwnedDeviceId, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, UserId, }; +use serde::Deserialize; use std::{ collections::{BTreeMap, HashMap, HashSet}, fs::{self, remove_dir_all}, @@ -22,7 +26,9 @@ use std::{ mem::size_of, path::Path, sync::{Arc, Mutex, RwLock}, + time::Duration, }; +use tokio::time::interval; use tracing::{debug, error, info, warn}; @@ -77,6 +83,8 @@ pub struct KeyValueDatabase { pub(super) aliasid_alias: Arc<dyn KvTree>, // AliasId = RoomId + Count pub(super) publicroomids: Arc<dyn KvTree>, + pub(super) threadid_userids: Arc<dyn KvTree>, // ThreadId = RoomId + Count + pub(super) tokenids: Arc<dyn KvTree>, // TokenId = ShortRoomId + Token + PduIdCount /// Participating servers in a room. @@ -125,6 +133,8 @@ pub struct KeyValueDatabase { pub(super) eventid_outlierpdu: Arc<dyn KvTree>, pub(super) softfailedeventids: Arc<dyn KvTree>, + /// ShortEventId + ShortEventId -> (). + pub(super) tofrom_relation: Arc<dyn KvTree>, /// RoomId + EventId -> Parent PDU EventId. pub(super) referencedevents: Arc<dyn KvTree>, @@ -161,7 +171,7 @@ pub struct KeyValueDatabase { pub(super) shortstatekey_cache: Mutex<LruCache<u64, (StateEventType, String)>>, pub(super) our_real_users_cache: RwLock<HashMap<OwnedRoomId, Arc<HashSet<OwnedUserId>>>>, pub(super) appservice_in_room_cache: RwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>, - pub(super) lasttimelinecount_cache: Mutex<HashMap<OwnedRoomId, u64>>, + pub(super) lasttimelinecount_cache: Mutex<HashMap<OwnedRoomId, PduCount>>, } impl KeyValueDatabase { @@ -257,6 +267,10 @@ impl KeyValueDatabase { } }; + if config.registration_token == Some(String::new()) { + return Err(Error::bad_config("Registration token is empty")); + } + if config.max_request_size < 1024 { error!(?config.max_request_size, "Max request size is less than 1KB. Please increase it."); } @@ -299,6 +313,8 @@ impl KeyValueDatabase { aliasid_alias: builder.open_tree("aliasid_alias")?, publicroomids: builder.open_tree("publicroomids")?, + threadid_userids: builder.open_tree("threadid_userids")?, + tokenids: builder.open_tree("tokenids")?, roomserverids: builder.open_tree("roomserverids")?, @@ -339,6 +355,7 @@ impl KeyValueDatabase { eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?, softfailedeventids: builder.open_tree("softfailedeventids")?, + tofrom_relation: builder.open_tree("tofrom_relation")?, referencedevents: builder.open_tree("referencedevents")?, roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, roomusertype_roomuserdataid: builder.open_tree("roomusertype_roomuserdataid")?, @@ -408,7 +425,7 @@ impl KeyValueDatabase { } // If the database has any data, perform data migrations before starting - let latest_database_version = 12; + let latest_database_version = 13; if services().users.count()? > 0 { // MIGRATIONS @@ -577,8 +594,8 @@ impl KeyValueDatabase { services().rooms.state_compressor.save_state_from_diff( current_sstatehash, - statediffnew, - statediffremoved, + Arc::new(statediffnew), + Arc::new(statediffremoved), 2, // every state change is 2 event changes on average states_parents, )?; @@ -800,10 +817,17 @@ impl KeyValueDatabase { } if services().globals.database_version()? < 12 { - for username in services().users.list_local_users().unwrap() { - let user = - UserId::parse_with_server_name(username, services().globals.server_name()) - .unwrap(); + for username in services().users.list_local_users()? { + let user = match UserId::parse_with_server_name( + username.clone(), + services().globals.server_name(), + ) { + Ok(u) => u, + Err(e) => { + warn!("Invalid username {username}: {e}"); + continue; + } + }; let raw_rules_list = services() .account_data @@ -870,6 +894,52 @@ impl KeyValueDatabase { warn!("Migration: 11 -> 12 finished"); } + // This migration can be reused as-is anytime the server-default rules are updated. + if services().globals.database_version()? < 13 { + for username in services().users.list_local_users()? { + let user = match UserId::parse_with_server_name( + username.clone(), + services().globals.server_name(), + ) { + Ok(u) => u, + Err(e) => { + warn!("Invalid username {username}: {e}"); + continue; + } + }; + + let raw_rules_list = services() + .account_data + .get( + None, + &user, + GlobalAccountDataEventType::PushRules.to_string().into(), + ) + .unwrap() + .expect("Username is invalid"); + + let mut account_data = + serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap(); + + let user_default_rules = ruma::push::Ruleset::server_default(&user); + account_data + .content + .global + .update_with_server_default(user_default_rules); + + services().account_data.update( + None, + &user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + )?; + } + + services().globals.bump_database_version(13)?; + + warn!("Migration: 12 -> 13 finished"); + } + assert_eq!( services().globals.database_version().unwrap(), latest_database_version @@ -919,6 +989,9 @@ impl KeyValueDatabase { services().sending.start_handler(); Self::start_cleanup_task().await; + if services().globals.allow_check_for_updates() { + Self::start_check_for_updates_task(); + } Ok(()) } @@ -935,9 +1008,61 @@ impl KeyValueDatabase { } #[tracing::instrument] - pub async fn start_cleanup_task() { - use tokio::time::interval; + pub fn start_check_for_updates_task() { + tokio::spawn(async move { + let timer_interval = Duration::from_secs(60 * 60); + let mut i = interval(timer_interval); + loop { + i.tick().await; + let _ = Self::try_handle_updates().await; + } + }); + } + + async fn try_handle_updates() -> Result<()> { + let response = services() + .globals + .default_client() + .get("https://conduit.rs/check-for-updates/stable") + .send() + .await?; + + #[derive(Deserialize)] + struct CheckForUpdatesResponseEntry { + id: u64, + date: String, + message: String, + } + #[derive(Deserialize)] + struct CheckForUpdatesResponse { + updates: Vec<CheckForUpdatesResponseEntry>, + } + + let response = serde_json::from_str::<CheckForUpdatesResponse>(&response.text().await?) + .map_err(|_| Error::BadServerResponse("Bad version check response"))?; + + let mut last_update_id = services().globals.last_check_for_updates_id()?; + for update in response.updates { + last_update_id = last_update_id.max(update.id); + if update.id > services().globals.last_check_for_updates_id()? { + println!("{}", update.message); + services() + .admin + .send_message(RoomMessageEventContent::text_plain(format!( + "@room: The following is a message from the Conduit developers. It was sent on '{}':\n\n{}", + update.date, update.message + ))) + } + } + services() + .globals + .update_check_for_updates_id(last_update_id)?; + Ok(()) + } + + #[tracing::instrument] + pub async fn start_cleanup_task() { #[cfg(unix)] use tokio::signal::unix::{signal, SignalKind}; |