summaryrefslogtreecommitdiff
path: root/src/database
diff options
context:
space:
mode:
authorTimo Kösters <timo@koesters.xyz>2022-10-05 20:33:55 +0200
committerNyaaori <+@nyaaori.cat>2022-10-10 14:02:01 +0200
commit33a2b2b7729bb40253fd174d99ad773869b5ecfe (patch)
treeebd0e70da4680452cd30940a64064f8708d0b432 /src/database
parent44fe6d1554eaa0a15314686974ab01f48c836588 (diff)
downloadconduit-33a2b2b7729bb40253fd174d99ad773869b5ecfe.zip
37 errors left
Diffstat (limited to 'src/database')
-rw-r--r--src/database/key_value/rooms/search.rs19
-rw-r--r--src/database/key_value/rooms/state.rs2
-rw-r--r--src/database/key_value/rooms/state_cache.rs488
-rw-r--r--src/database/key_value/rooms/timeline.rs43
4 files changed, 538 insertions, 14 deletions
diff --git a/src/database/key_value/rooms/search.rs b/src/database/key_value/rooms/search.rs
index dfbdbc6..41df544 100644
--- a/src/database/key_value/rooms/search.rs
+++ b/src/database/key_value/rooms/search.rs
@@ -54,19 +54,20 @@ impl service::rooms::search::Data for KeyValueDatabase {
.map(|(key, _)| key[key.len() - size_of::<u64>()..].to_vec())
});
- Ok(utils::common_elements(iterators, |a, b| {
+ let common_elements = match utils::common_elements(iterators, |a, b| {
// We compare b with a because we reversed the iterator earlier
b.cmp(a)
- })
- .map(|iter| {
- (
- Box::new(iter.map(move |id| {
+ }) {
+ Some(it) => it,
+ None => return Ok(None),
+ };
+
+ let mapped = common_elements.map(move |id| {
let mut pduid = prefix_clone.clone();
pduid.extend_from_slice(&id);
pduid
- })),
- words,
- )
- }))
+ });
+
+ Ok(Some((Box::new(mapped), words)))
}
}
diff --git a/src/database/key_value/rooms/state.rs b/src/database/key_value/rooms/state.rs
index b2822b3..90ac0d5 100644
--- a/src/database/key_value/rooms/state.rs
+++ b/src/database/key_value/rooms/state.rs
@@ -49,7 +49,7 @@ impl service::rooms::state::Data for KeyValueDatabase {
fn set_forward_extremities<'a>(
&self,
room_id: &RoomId,
- event_ids: &mut dyn Iterator<Item = &'a EventId>,
+ event_ids: Vec<Box<EventId>>,
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> {
let mut prefix = room_id.as_bytes().to_vec();
diff --git a/src/database/key_value/rooms/state_cache.rs b/src/database/key_value/rooms/state_cache.rs
index 5f05485..4043bc4 100644
--- a/src/database/key_value/rooms/state_cache.rs
+++ b/src/database/key_value/rooms/state_cache.rs
@@ -1,6 +1,9 @@
-use ruma::{UserId, RoomId, events::{AnyStrippedStateEvent, AnySyncStateEvent}, serde::Raw};
+use std::{collections::HashSet, sync::Arc};
-use crate::{service, database::KeyValueDatabase, services, Result};
+use regex::Regex;
+use ruma::{UserId, RoomId, events::{AnyStrippedStateEvent, AnySyncStateEvent}, serde::Raw, ServerName};
+
+use crate::{service, database::KeyValueDatabase, services, Result, Error, utils};
impl service::rooms::state_cache::Data for KeyValueDatabase {
fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
@@ -75,4 +78,485 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
Ok(())
}
+
+ fn update_joined_count(&self, room_id: &RoomId) -> Result<()> {
+ let mut joinedcount = 0_u64;
+ let mut invitedcount = 0_u64;
+ let mut joined_servers = HashSet::new();
+ let mut real_users = HashSet::new();
+
+ for joined in self.room_members(room_id).filter_map(|r| r.ok()) {
+ joined_servers.insert(joined.server_name().to_owned());
+ if joined.server_name() == services().globals.server_name()
+ && !services().users.is_deactivated(&joined).unwrap_or(true)
+ {
+ real_users.insert(joined);
+ }
+ joinedcount += 1;
+ }
+
+ for invited in self.room_members_invited(room_id).filter_map(|r| r.ok()) {
+ joined_servers.insert(invited.server_name().to_owned());
+ invitedcount += 1;
+ }
+
+ self.roomid_joinedcount
+ .insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?;
+
+ self.roomid_invitedcount
+ .insert(room_id.as_bytes(), &invitedcount.to_be_bytes())?;
+
+ self.our_real_users_cache
+ .write()
+ .unwrap()
+ .insert(room_id.to_owned(), Arc::new(real_users));
+
+ self.appservice_in_room_cache
+ .write()
+ .unwrap()
+ .remove(room_id);
+
+ Ok(())
+ }
+
+ #[tracing::instrument(skip(self, room_id))]
+ fn get_our_real_users(&self, room_id: &RoomId) -> Result<Arc<HashSet<Box<UserId>>>> {
+ let maybe = self
+ .our_real_users_cache
+ .read()
+ .unwrap()
+ .get(room_id)
+ .cloned();
+ if let Some(users) = maybe {
+ Ok(users)
+ } else {
+ self.update_joined_count(room_id)?;
+ Ok(Arc::clone(
+ self.our_real_users_cache
+ .read()
+ .unwrap()
+ .get(room_id)
+ .unwrap(),
+ ))
+ }
+ }
+
+ #[tracing::instrument(skip(self, room_id, appservice))]
+ fn appservice_in_room(
+ &self,
+ room_id: &RoomId,
+ appservice: &(String, serde_yaml::Value),
+ ) -> Result<bool> {
+ let maybe = self
+ .appservice_in_room_cache
+ .read()
+ .unwrap()
+ .get(room_id)
+ .and_then(|map| map.get(&appservice.0))
+ .copied();
+
+ if let Some(b) = maybe {
+ Ok(b)
+ } else if let Some(namespaces) = appservice.1.get("namespaces") {
+ let users = namespaces
+ .get("users")
+ .and_then(|users| users.as_sequence())
+ .map_or_else(Vec::new, |users| {
+ users
+ .iter()
+ .filter_map(|users| Regex::new(users.get("regex")?.as_str()?).ok())
+ .collect::<Vec<_>>()
+ });
+
+ let bridge_user_id = appservice
+ .1
+ .get("sender_localpart")
+ .and_then(|string| string.as_str())
+ .and_then(|string| {
+ UserId::parse_with_server_name(string, services().globals.server_name()).ok()
+ });
+
+ let in_room = bridge_user_id
+ .map_or(false, |id| self.is_joined(&id, room_id).unwrap_or(false))
+ || self.room_members(room_id).any(|userid| {
+ userid.map_or(false, |userid| {
+ users.iter().any(|r| r.is_match(userid.as_str()))
+ })
+ });
+
+ self.appservice_in_room_cache
+ .write()
+ .unwrap()
+ .entry(room_id.to_owned())
+ .or_default()
+ .insert(appservice.0.clone(), in_room);
+
+ Ok(in_room)
+ } else {
+ Ok(false)
+ }
+ }
+
+ /// Makes a user forget a room.
+ #[tracing::instrument(skip(self))]
+ fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> {
+ let mut userroom_id = user_id.as_bytes().to_vec();
+ userroom_id.push(0xff);
+ userroom_id.extend_from_slice(room_id.as_bytes());
+
+ let mut roomuser_id = room_id.as_bytes().to_vec();
+ roomuser_id.push(0xff);
+ roomuser_id.extend_from_slice(user_id.as_bytes());
+
+ self.userroomid_leftstate.remove(&userroom_id)?;
+ self.roomuserid_leftcount.remove(&roomuser_id)?;
+
+ Ok(())
+ }
+
+ /// Returns an iterator of all servers participating in this room.
+ #[tracing::instrument(skip(self))]
+ fn room_servers<'a>(
+ &'a self,
+ room_id: &RoomId,
+ ) -> Box<dyn Iterator<Item = Result<Box<ServerName>>> + 'a> {
+ let mut prefix = room_id.as_bytes().to_vec();
+ prefix.push(0xff);
+
+ Box::new(self.roomserverids.scan_prefix(prefix).map(|(key, _)| {
+ ServerName::parse(
+ utils::string_from_bytes(
+ key.rsplit(|&b| b == 0xff)
+ .next()
+ .expect("rsplit always returns an element"),
+ )
+ .map_err(|_| {
+ Error::bad_database("Server name in roomserverids is invalid unicode.")
+ })?,
+ )
+ .map_err(|_| Error::bad_database("Server name in roomserverids is invalid."))
+ }))
+ }
+
+ #[tracing::instrument(skip(self))]
+ fn server_in_room<'a>(&'a self, server: &ServerName, room_id: &RoomId) -> Result<bool> {
+ let mut key = server.as_bytes().to_vec();
+ key.push(0xff);
+ key.extend_from_slice(room_id.as_bytes());
+
+ self.serverroomids.get(&key).map(|o| o.is_some())
+ }
+
+ /// Returns an iterator of all rooms a server participates in (as far as we know).
+ #[tracing::instrument(skip(self))]
+ fn server_rooms<'a>(
+ &'a self,
+ server: &ServerName,
+ ) -> Box<dyn Iterator<Item = Result<Box<RoomId>>> + 'a> {
+ let mut prefix = server.as_bytes().to_vec();
+ prefix.push(0xff);
+
+ Box::new(self.serverroomids.scan_prefix(prefix).map(|(key, _)| {
+ RoomId::parse(
+ utils::string_from_bytes(
+ key.rsplit(|&b| b == 0xff)
+ .next()
+ .expect("rsplit always returns an element"),
+ )
+ .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid unicode."))?,
+ )
+ .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid."))
+ }))
+ }
+
+ /// Returns an iterator over all joined members of a room.
+ #[tracing::instrument(skip(self))]
+ fn room_members<'a>(
+ &'a self,
+ room_id: &RoomId,
+ ) -> Box<dyn Iterator<Item = Result<Box<UserId>>> + 'a> {
+ let mut prefix = room_id.as_bytes().to_vec();
+ prefix.push(0xff);
+
+ Box::new(self.roomuserid_joined.scan_prefix(prefix).map(|(key, _)| {
+ UserId::parse(
+ utils::string_from_bytes(
+ key.rsplit(|&b| b == 0xff)
+ .next()
+ .expect("rsplit always returns an element"),
+ )
+ .map_err(|_| {
+ Error::bad_database("User ID in roomuserid_joined is invalid unicode.")
+ })?,
+ )
+ .map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid."))
+ }))
+ }
+
+ #[tracing::instrument(skip(self))]
+ fn room_joined_count(&self, room_id: &RoomId) -> Result<Option<u64>> {
+ self.roomid_joinedcount
+ .get(room_id.as_bytes())?
+ .map(|b| {
+ utils::u64_from_bytes(&b)
+ .map_err(|_| Error::bad_database("Invalid joinedcount in db."))
+ })
+ .transpose()
+ }
+
+ #[tracing::instrument(skip(self))]
+ fn room_invited_count(&self, room_id: &RoomId) -> Result<Option<u64>> {
+ self.roomid_invitedcount
+ .get(room_id.as_bytes())?
+ .map(|b| {
+ utils::u64_from_bytes(&b)
+ .map_err(|_| Error::bad_database("Invalid joinedcount in db."))
+ })
+ .transpose()
+ }
+
+ /// Returns an iterator over all User IDs who ever joined a room.
+ #[tracing::instrument(skip(self))]
+ fn room_useroncejoined<'a>(
+ &'a self,
+ room_id: &RoomId,
+ ) -> Box<dyn Iterator<Item = Result<Box<UserId>>> + 'a> {
+ let mut prefix = room_id.as_bytes().to_vec();
+ prefix.push(0xff);
+
+ Box::new(self.roomuseroncejoinedids
+ .scan_prefix(prefix)
+ .map(|(key, _)| {
+ UserId::parse(
+ utils::string_from_bytes(
+ key.rsplit(|&b| b == 0xff)
+ .next()
+ .expect("rsplit always returns an element"),
+ )
+ .map_err(|_| {
+ Error::bad_database("User ID in room_useroncejoined is invalid unicode.")
+ })?,
+ )
+ .map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid."))
+ }))
+ }
+
+ /// Returns an iterator over all invited members of a room.
+ #[tracing::instrument(skip(self))]
+ fn room_members_invited<'a>(
+ &'a self,
+ room_id: &RoomId,
+ ) -> Box<dyn Iterator<Item = Result<Box<UserId>>> + 'a> {
+ let mut prefix = room_id.as_bytes().to_vec();
+ prefix.push(0xff);
+
+ Box::new(self.roomuserid_invitecount
+ .scan_prefix(prefix)
+ .map(|(key, _)| {
+ UserId::parse(
+ utils::string_from_bytes(
+ key.rsplit(|&b| b == 0xff)
+ .next()
+ .expect("rsplit always returns an element"),
+ )
+ .map_err(|_| {
+ Error::bad_database("User ID in roomuserid_invited is invalid unicode.")
+ })?,
+ )
+ .map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid."))
+ }))
+ }
+
+ #[tracing::instrument(skip(self))]
+ fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
+ let mut key = room_id.as_bytes().to_vec();
+ key.push(0xff);
+ key.extend_from_slice(user_id.as_bytes());
+
+ self.roomuserid_invitecount
+ .get(&key)?
+ .map_or(Ok(None), |bytes| {
+ Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
+ Error::bad_database("Invalid invitecount in db.")
+ })?))
+ })
+ }
+
+ #[tracing::instrument(skip(self))]
+ fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
+ let mut key = room_id.as_bytes().to_vec();
+ key.push(0xff);
+ key.extend_from_slice(user_id.as_bytes());
+
+ self.roomuserid_leftcount
+ .get(&key)?
+ .map(|bytes| {
+ utils::u64_from_bytes(&bytes)
+ .map_err(|_| Error::bad_database("Invalid leftcount in db."))
+ })
+ .transpose()
+ }
+
+ /// Returns an iterator over all rooms this user joined.
+ #[tracing::instrument(skip(self))]
+ fn rooms_joined<'a>(
+ &'a self,
+ user_id: &UserId,
+ ) -> Box<dyn Iterator<Item = Result<Box<RoomId>>> + 'a> {
+ Box::new(self.userroomid_joined
+ .scan_prefix(user_id.as_bytes().to_vec())
+ .map(|(key, _)| {
+ RoomId::parse(
+ utils::string_from_bytes(
+ key.rsplit(|&b| b == 0xff)
+ .next()
+ .expect("rsplit always returns an element"),
+ )
+ .map_err(|_| {
+ Error::bad_database("Room ID in userroomid_joined is invalid unicode.")
+ })?,
+ )
+ .map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid."))
+ }))
+ }
+
+ /// Returns an iterator over all rooms a user was invited to.
+ #[tracing::instrument(skip(self))]
+ fn rooms_invited<'a>(
+ &'a self,
+ user_id: &UserId,
+ ) -> Box<dyn Iterator<Item = Result<(Box<RoomId>, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a> {
+ let mut prefix = user_id.as_bytes().to_vec();
+ prefix.push(0xff);
+
+ Box::new(self.userroomid_invitestate
+ .scan_prefix(prefix)
+ .map(|(key, state)| {
+ let room_id = RoomId::parse(
+ utils::string_from_bytes(
+ key.rsplit(|&b| b == 0xff)
+ .next()
+ .expect("rsplit always returns an element"),
+ )
+ .map_err(|_| {
+ Error::bad_database("Room ID in userroomid_invited is invalid unicode.")
+ })?,
+ )
+ .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?;
+
+ let state = serde_json::from_slice(&state)
+ .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?;
+
+ Ok((room_id, state))
+ }))
+ }
+
+ #[tracing::instrument(skip(self))]
+ fn invite_state(
+ &self,
+ user_id: &UserId,
+ room_id: &RoomId,
+ ) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
+ let mut key = user_id.as_bytes().to_vec();
+ key.push(0xff);
+ key.extend_from_slice(room_id.as_bytes());
+
+ self.userroomid_invitestate
+ .get(&key)?
+ .map(|state| {
+ let state = serde_json::from_slice(&state)
+ .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?;
+
+ Ok(state)
+ })
+ .transpose()
+ }
+
+ #[tracing::instrument(skip(self))]
+ fn left_state(
+ &self,
+ user_id: &UserId,
+ room_id: &RoomId,
+ ) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
+ let mut key = user_id.as_bytes().to_vec();
+ key.push(0xff);
+ key.extend_from_slice(room_id.as_bytes());
+
+ self.userroomid_leftstate
+ .get(&key)?
+ .map(|state| {
+ let state = serde_json::from_slice(&state)
+ .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?;
+
+ Ok(state)
+ })
+ .transpose()
+ }
+
+ /// Returns an iterator over all rooms a user left.
+ #[tracing::instrument(skip(self))]
+ fn rooms_left<'a>(
+ &'a self,
+ user_id: &UserId,
+ ) -> Box<dyn Iterator<Item = Result<(Box<RoomId>, Vec<Raw<AnySyncStateEvent>>)>> + 'a> {
+ let mut prefix = user_id.as_bytes().to_vec();
+ prefix.push(0xff);
+
+ Box::new(self.userroomid_leftstate
+ .scan_prefix(prefix)
+ .map(|(key, state)| {
+ let room_id = RoomId::parse(
+ utils::string_from_bytes(
+ key.rsplit(|&b| b == 0xff)
+ .next()
+ .expect("rsplit always returns an element"),
+ )
+ .map_err(|_| {
+ Error::bad_database("Room ID in userroomid_invited is invalid unicode.")
+ })?,
+ )
+ .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?;
+
+ let state = serde_json::from_slice(&state)
+ .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?;
+
+ Ok((room_id, state))
+ }))
+ }
+
+ #[tracing::instrument(skip(self))]
+ fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
+ let mut userroom_id = user_id.as_bytes().to_vec();
+ userroom_id.push(0xff);
+ userroom_id.extend_from_slice(room_id.as_bytes());
+
+ Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some())
+ }
+
+ #[tracing::instrument(skip(self))]
+ fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
+ let mut userroom_id = user_id.as_bytes().to_vec();
+ userroom_id.push(0xff);
+ userroom_id.extend_from_slice(room_id.as_bytes());
+
+ Ok(self.userroomid_joined.get(&userroom_id)?.is_some())
+ }
+
+ #[tracing::instrument(skip(self))]
+ fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
+ let mut userroom_id = user_id.as_bytes().to_vec();
+ userroom_id.push(0xff);
+ userroom_id.extend_from_slice(room_id.as_bytes());
+
+ Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some())
+ }
+
+ #[tracing::instrument(skip(self))]
+ fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
+ let mut userroom_id = user_id.as_bytes().to_vec();
+ userroom_id.push(0xff);
+ userroom_id.extend_from_slice(room_id.as_bytes());
+
+ Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some())
+ }
}
diff --git a/src/database/key_value/rooms/timeline.rs b/src/database/key_value/rooms/timeline.rs
index 0b7286b..1723186 100644
--- a/src/database/key_value/rooms/timeline.rs
+++ b/src/database/key_value/rooms/timeline.rs
@@ -187,13 +187,29 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
.map_err(|_| Error::bad_database("PDU has invalid count bytes."))
}
+ fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) -> Result<()> {
+ self.pduid_pdu.insert(
+ pdu_id,
+ &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"))?;
+
+ self.lasttimelinecount_cache
+ .lock()
+ .unwrap()
+ .insert(pdu.room_id.clone(), count);
+
+ self.eventid_pduid
+ .insert(pdu.event_id.as_bytes(), &pdu_id)?;
+ self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?;
+
+ Ok(())
+ }
+
/// Removes a pdu and creates a new one with the same id.
fn replace_pdu(&self, pdu_id: &[u8], pdu: &PduEvent) -> Result<()> {
if self.pduid_pdu.get(pdu_id)?.is_some() {
self.pduid_pdu.insert(
pdu_id,
- &serde_json::to_vec(pdu).expect("PduEvent::to_vec always works"),
- )?;
+ &serde_json::to_vec(pdu).expect("CanonicalJsonObject is always a valid"))?;
Ok(())
} else {
Err(Error::BadRequest(
@@ -306,4 +322,27 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
Ok((pdu_id, pdu))
})))
}
+
+ fn increment_notification_counts(&self, room_id: &RoomId, notifies: Vec<Box<UserId>>, highlights: Vec<Box<UserId>>) -> Result<()> {
+ let notifies_batch = Vec::new();
+ let highlights_batch = Vec::new();
+ for user in notifies {
+ let mut userroom_id = user.as_bytes().to_vec();
+ userroom_id.push(0xff);
+ userroom_id.extend_from_slice(room_id.as_bytes());
+ notifies_batch.push(userroom_id);
+ }
+ for user in highlights {
+ let mut userroom_id = user.as_bytes().to_vec();
+ userroom_id.push(0xff);
+ userroom_id.extend_from_slice(room_id.as_bytes());
+ highlights_batch.push(userroom_id);
+ }
+
+ self.userroomid_notificationcount
+ .increment_batch(&mut notifies_batch.into_iter())?;
+ self.userroomid_highlightcount
+ .increment_batch(&mut highlights_batch.into_iter())?;
+ Ok(())
+ }
}