Diffstat (limited to 'src/database/mod.rs')
-rw-r--r--  src/database/mod.rs  |  147
1 file changed, 136 insertions, 11 deletions
diff --git a/src/database/mod.rs b/src/database/mod.rs
index 78bb358..e247d9f 100644
--- a/src/database/mod.rs
+++ b/src/database/mod.rs
@@ -1,7 +1,10 @@
pub mod abstraction;
pub mod key_value;
-use crate::{services, utils, Config, Error, PduEvent, Result, Services, SERVICES};
+use crate::{
+ service::rooms::timeline::PduCount, services, utils, Config, Error, PduEvent, Result, Services,
+ SERVICES,
+};
use abstraction::{KeyValueDatabaseEngine, KvTree};
use directories::ProjectDirs;
use lru_cache::LruCache;
@@ -15,6 +18,7 @@ use ruma::{
CanonicalJsonValue, EventId, OwnedDeviceId, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId,
UserId,
};
+use serde::Deserialize;
use std::{
collections::{BTreeMap, HashMap, HashSet},
fs::{self, remove_dir_all},
@@ -22,7 +26,9 @@ use std::{
mem::size_of,
path::Path,
sync::{Arc, Mutex, RwLock},
+ time::Duration,
};
+use tokio::time::interval;
use tracing::{debug, error, info, warn};
@@ -77,6 +83,8 @@ pub struct KeyValueDatabase {
pub(super) aliasid_alias: Arc<dyn KvTree>, // AliasId = RoomId + Count
pub(super) publicroomids: Arc<dyn KvTree>,
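+ /// Value: the ids of the users participating in the thread.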
+ pub(super) threadid_userids: Arc<dyn KvTree>, // ThreadId = RoomId + Count
+
pub(super) tokenids: Arc<dyn KvTree>, // TokenId = ShortRoomId + Token + PduIdCount
/// Participating servers in a room.
@@ -125,6 +133,8 @@ pub struct KeyValueDatabase {
pub(super) eventid_outlierpdu: Arc<dyn KvTree>,
pub(super) softfailedeventids: Arc<dyn KvTree>,
+ /// ShortEventId + ShortEventId -> ().
+ pub(super) tofrom_relation: Arc<dyn KvTree>,
/// RoomId + EventId -> Parent PDU EventId.
pub(super) referencedevents: Arc<dyn KvTree>,
@@ -161,7 +171,7 @@ pub struct KeyValueDatabase {
pub(super) shortstatekey_cache: Mutex<LruCache<u64, (StateEventType, String)>>,
pub(super) our_real_users_cache: RwLock<HashMap<OwnedRoomId, Arc<HashSet<OwnedUserId>>>>,
pub(super) appservice_in_room_cache: RwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>,
- pub(super) lasttimelinecount_cache: Mutex<HashMap<OwnedRoomId, u64>>,
+ pub(super) lasttimelinecount_cache: Mutex<HashMap<OwnedRoomId, PduCount>>,
}
impl KeyValueDatabase {
@@ -257,6 +267,10 @@ impl KeyValueDatabase {
}
};
+ if config.registration_token == Some(String::new()) {
+ return Err(Error::bad_config("Registration token is empty"));
+ }
+
if config.max_request_size < 1024 {
error!(?config.max_request_size, "Max request size is less than 1KB. Please increase it.");
}
@@ -299,6 +313,8 @@ impl KeyValueDatabase {
aliasid_alias: builder.open_tree("aliasid_alias")?,
publicroomids: builder.open_tree("publicroomids")?,
+ threadid_userids: builder.open_tree("threadid_userids")?,
+
tokenids: builder.open_tree("tokenids")?,
roomserverids: builder.open_tree("roomserverids")?,
@@ -339,6 +355,7 @@ impl KeyValueDatabase {
eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?,
softfailedeventids: builder.open_tree("softfailedeventids")?,
+ tofrom_relation: builder.open_tree("tofrom_relation")?,
referencedevents: builder.open_tree("referencedevents")?,
roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?,
roomusertype_roomuserdataid: builder.open_tree("roomusertype_roomuserdataid")?,
@@ -408,7 +425,7 @@ impl KeyValueDatabase {
}
// If the database has any data, perform data migrations before starting
- let latest_database_version = 12;
+ let latest_database_version = 13;
if services().users.count()? > 0 {
// MIGRATIONS
@@ -577,8 +594,8 @@ impl KeyValueDatabase {
services().rooms.state_compressor.save_state_from_diff(
current_sstatehash,
- statediffnew,
- statediffremoved,
+ Arc::new(statediffnew),
+ Arc::new(statediffremoved),
2, // every state change is 2 event changes on average
states_parents,
)?;
@@ -800,10 +817,17 @@ impl KeyValueDatabase {
}
if services().globals.database_version()? < 12 {
- for username in services().users.list_local_users().unwrap() {
- let user =
- UserId::parse_with_server_name(username, services().globals.server_name())
- .unwrap();
+ for username in services().users.list_local_users()? {
+ let user = match UserId::parse_with_server_name(
+ username.clone(),
+ services().globals.server_name(),
+ ) {
+ Ok(u) => u,
+ Err(e) => {
+ warn!("Invalid username {username}: {e}");
+ continue;
+ }
+ };
let raw_rules_list = services()
.account_data
@@ -870,6 +894,52 @@ impl KeyValueDatabase {
warn!("Migration: 11 -> 12 finished");
}
+ // This migration can be reused as-is anytime the server-default rules are updated.
+ if services().globals.database_version()? < 13 {
+ for username in services().users.list_local_users()? {
+ let user = match UserId::parse_with_server_name(
+ username.clone(),
+ services().globals.server_name(),
+ ) {
+ Ok(u) => u,
+ Err(e) => {
+ warn!("Invalid username {username}: {e}");
+ continue;
+ }
+ };
+
+ let raw_rules_list = services()
+ .account_data
+ .get(
+ None,
+ &user,
+ GlobalAccountDataEventType::PushRules.to_string().into(),
+ )
+ .unwrap()
+ .expect("Username is invalid");
+
+ let mut account_data =
+ serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap();
+
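+ // Re-apply the current server-default push rules on top of the user's stored ruleset.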
+ let user_default_rules = ruma::push::Ruleset::server_default(&user);
+ account_data
+ .content
+ .global
+ .update_with_server_default(user_default_rules);
+
+ services().account_data.update(
+ None,
+ &user,
+ GlobalAccountDataEventType::PushRules.to_string().into(),
+ &serde_json::to_value(account_data).expect("to json value always works"),
+ )?;
+ }
+
+ services().globals.bump_database_version(13)?;
+
+ warn!("Migration: 12 -> 13 finished");
+ }
+
assert_eq!(
services().globals.database_version().unwrap(),
latest_database_version
@@ -919,6 +989,9 @@ impl KeyValueDatabase {
services().sending.start_handler();
Self::start_cleanup_task().await;
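+ // If allowed, periodically poll https://conduit.rs for update announcements (see start_check_for_updates_task).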
+ if services().globals.allow_check_for_updates() {
+ Self::start_check_for_updates_task();
+ }
Ok(())
}
@@ -935,9 +1008,61 @@ impl KeyValueDatabase {
}
#[tracing::instrument]
- pub async fn start_cleanup_task() {
- use tokio::time::interval;
+ pub fn start_check_for_updates_task() {
+ tokio::spawn(async move {
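+ // Check once per hour; tokio's interval yields its first tick immediately, so one check also runs right after startup.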
+ let timer_interval = Duration::from_secs(60 * 60);
+ let mut i = interval(timer_interval);
+ loop {
+ i.tick().await;
+ let _ = Self::try_handle_updates().await;
+ }
+ });
+ }
+
+ async fn try_handle_updates() -> Result<()> {
+ let response = services()
+ .globals
+ .default_client()
+ .get("https://conduit.rs/check-for-updates/stable")
+ .send()
+ .await?;
+
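+ // Expected response body: {"updates": [{"id": 1, "date": "...", "message": "..."}, ...]}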
+ #[derive(Deserialize)]
+ struct CheckForUpdatesResponseEntry {
+ id: u64,
+ date: String,
+ message: String,
+ }
+ #[derive(Deserialize)]
+ struct CheckForUpdatesResponse {
+ updates: Vec<CheckForUpdatesResponseEntry>,
+ }
+
+ let response = serde_json::from_str::<CheckForUpdatesResponse>(&response.text().await?)
+ .map_err(|_| Error::BadServerResponse("Bad version check response"))?;
+
+ let mut last_update_id = services().globals.last_check_for_updates_id()?;
+ for update in response.updates {
+ last_update_id = last_update_id.max(update.id);
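+ // Only announce updates that are newer than the last id we already acknowledged.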
+ if update.id > services().globals.last_check_for_updates_id()? {
+ println!("{}", update.message);
+ services()
+ .admin
+ .send_message(RoomMessageEventContent::text_plain(format!(
+ "@room: The following is a message from the Conduit developers. It was sent on '{}':\n\n{}",
+ update.date, update.message
+ )))
+ }
+ }
+ services()
+ .globals
+ .update_check_for_updates_id(last_update_id)?;
+ Ok(())
+ }
+
+ #[tracing::instrument]
+ pub async fn start_cleanup_task() {
#[cfg(unix)]
use tokio::signal::unix::{signal, SignalKind};