diff options
Diffstat (limited to 'src/api/server_server.rs')
-rw-r--r-- | src/api/server_server.rs | 379 |
1 files changed, 53 insertions, 326 deletions
diff --git a/src/api/server_server.rs b/src/api/server_server.rs index bacc1ac..9aa2beb 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -669,7 +669,7 @@ pub async fn send_transaction_message_route( } }; - acl_check(&sender_servername, &room_id)?; + services().rooms.event_handler.acl_check(&sender_servername, &room_id)?; let mutex = Arc::clone( services().globals @@ -727,7 +727,7 @@ pub async fn send_transaction_message_route( .event_ids .iter() .filter_map(|id| { - services().rooms.get_pdu_count(id).ok().flatten().map(|r| (id, r)) + services().rooms.timeline.get_pdu_count(id).ok().flatten().map(|r| (id, r)) }) .max_by_key(|(_, count)| *count) { @@ -744,7 +744,7 @@ pub async fn send_transaction_message_route( content: ReceiptEventContent(receipt_content), room_id: room_id.clone(), }; - services().rooms.edus.readreceipt_update( + services().rooms.edus.read_receipt.readreceipt_update( &user_id, &room_id, event, @@ -757,15 +757,15 @@ pub async fn send_transaction_message_route( } } Edu::Typing(typing) => { - if services().rooms.is_joined(&typing.user_id, &typing.room_id)? { + if services().rooms.state_cache.is_joined(&typing.user_id, &typing.room_id)? { if typing.typing { - services().rooms.edus.typing_add( + services().rooms.edus.typing.typing_add( &typing.user_id, &typing.room_id, 3000 + utils::millis_since_unix_epoch(), )?; } else { - services().rooms.edus.typing_remove( + services().rooms.edus.typing.typing_remove( &typing.user_id, &typing.room_id, )?; @@ -1031,7 +1031,7 @@ pub(crate) async fn get_auth_chain<'a>( let mut i = 0; for id in starting_events { - let short = services().rooms.get_or_create_shorteventid(&id)?; + let short = services().rooms.short.get_or_create_shorteventid(&id)?; let bucket_id = (short % NUM_BUCKETS as u64) as usize; buckets[bucket_id].insert((short, id.clone())); i += 1; @@ -1050,7 +1050,7 @@ pub(crate) async fn get_auth_chain<'a>( } let chunk_key: Vec<u64> = chunk.iter().map(|(short, _)| short).copied().collect(); - if let Some(cached) = services().rooms.get_auth_chain_from_cache(&chunk_key)? { + if let Some(cached) = services().rooms.auth_chain.get_auth_chain_from_cache(&chunk_key)? { hits += 1; full_auth_chain.extend(cached.iter().copied()); continue; @@ -1062,13 +1062,14 @@ pub(crate) async fn get_auth_chain<'a>( let mut misses2 = 0; let mut i = 0; for (sevent_id, event_id) in chunk { - if let Some(cached) = services().rooms.get_auth_chain_from_cache(&[sevent_id])? { + if let Some(cached) = services().rooms.auth_chain.get_auth_chain_from_cache(&[sevent_id])? { hits2 += 1; chunk_cache.extend(cached.iter().copied()); } else { misses2 += 1; let auth_chain = Arc::new(get_auth_chain_inner(room_id, &event_id)?); services().rooms + .auth_chain .cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?; println!( "cache missed event {} with auth chain len {}", @@ -1091,7 +1092,7 @@ pub(crate) async fn get_auth_chain<'a>( ); let chunk_cache = Arc::new(chunk_cache); services().rooms - .cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?; + .auth_chain.cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?; full_auth_chain.extend(chunk_cache.iter()); } @@ -1104,7 +1105,7 @@ pub(crate) async fn get_auth_chain<'a>( Ok(full_auth_chain .into_iter() - .filter_map(move |sid| services().rooms.get_eventid_from_short(sid).ok())) + .filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok())) } #[tracing::instrument(skip(event_id))] @@ -1116,14 +1117,14 @@ fn get_auth_chain_inner( let mut found = HashSet::new(); while let Some(event_id) = todo.pop() { - match services().rooms.get_pdu(&event_id) { + match services().rooms.timeline.get_pdu(&event_id) { Ok(Some(pdu)) => { if pdu.room_id != room_id { return Err(Error::BadRequest(ErrorKind::Forbidden, "Evil event in db")); } for auth_event in &pdu.auth_events { let sauthevent = services() - .rooms + .rooms.short .get_or_create_shorteventid(auth_event)?; if !found.contains(&sauthevent) { @@ -1162,7 +1163,7 @@ pub async fn get_event_route( .expect("server is authenticated"); let event = services() - .rooms + .rooms.timeline .get_pdu_json(&body.event_id)? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Event not found."))?; @@ -1174,7 +1175,7 @@ pub async fn get_event_route( let room_id = <&RoomId>::try_from(room_id_str) .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; - if !services().rooms.server_in_room(sender_servername, room_id)? { + if !services().rooms.state_cache.server_in_room(sender_servername, room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "Server is not in room", @@ -1203,21 +1204,21 @@ pub async fn get_missing_events_route( .as_ref() .expect("server is authenticated"); - if !services().rooms.server_in_room(sender_servername, &body.room_id)? { + if !services().rooms.state_cache.server_in_room(sender_servername, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "Server is not in room", )); } - acl_check(sender_servername, &body.room_id)?; + services().rooms.event_handler.acl_check(&sender_servername, &body.room_id)?; let mut queued_events = body.latest_events.clone(); let mut events = Vec::new(); let mut i = 0; while i < queued_events.len() && events.len() < u64::from(body.limit) as usize { - if let Some(pdu) = services().rooms.get_pdu_json(&queued_events[i])? { + if let Some(pdu) = services().rooms.timeline.get_pdu_json(&queued_events[i])? { let room_id_str = pdu .get("room_id") .and_then(|val| val.as_str()) @@ -1275,17 +1276,17 @@ pub async fn get_event_authorization_route( .as_ref() .expect("server is authenticated"); - if !services().rooms.server_in_room(sender_servername, &body.room_id)? { + if !services().rooms.state_cache.server_in_room(sender_servername, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "Server is not in room.", )); } - acl_check(sender_servername, &body.room_id)?; + services().rooms.event_handler.acl_check(&sender_servername, &body.room_id)?; let event = services() - .rooms + .rooms.timeline .get_pdu_json(&body.event_id)? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Event not found."))?; @@ -1301,7 +1302,7 @@ pub async fn get_event_authorization_route( Ok(get_event_authorization::v1::Response { auth_chain: auth_chain_ids - .filter_map(|id| services().rooms.get_pdu_json(&id).ok()?) + .filter_map(|id| services().rooms.timeline.get_pdu_json(&id).ok()?) .map(PduEvent::convert_to_outgoing_federation_event) .collect(), }) @@ -1322,17 +1323,17 @@ pub async fn get_room_state_route( .as_ref() .expect("server is authenticated"); - if !services().rooms.server_in_room(sender_servername, &body.room_id)? { + if !services().rooms.state_cache.server_in_room(sender_servername, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "Server is not in room.", )); } - acl_check(sender_servername, &body.room_id)?; + services().rooms.event_handler.acl_check(&sender_servername, &body.room_id)?; let shortstatehash = services() - .rooms + .rooms.state_accessor .pdu_shortstatehash(&body.event_id)? .ok_or(Error::BadRequest( ErrorKind::NotFound, @@ -1340,13 +1341,13 @@ pub async fn get_room_state_route( ))?; let pdus = services() - .rooms + .rooms.state_accessor .state_full_ids(shortstatehash) .await? .into_iter() .map(|(_, id)| { PduEvent::convert_to_outgoing_federation_event( - services().rooms.get_pdu_json(&id).unwrap().unwrap(), + services().rooms.timeline.get_pdu_json(&id).unwrap().unwrap(), ) }) .collect(); @@ -1357,7 +1358,7 @@ pub async fn get_room_state_route( Ok(get_room_state::v1::Response { auth_chain: auth_chain_ids .map(|id| { - services().rooms.get_pdu_json(&id).map(|maybe_json| { + services().rooms.timeline.get_pdu_json(&id).map(|maybe_json| { PduEvent::convert_to_outgoing_federation_event(maybe_json.unwrap()) }) }) @@ -1382,17 +1383,17 @@ pub async fn get_room_state_ids_route( .as_ref() .expect("server is authenticated"); - if !services().rooms.server_in_room(sender_servername, &body.room_id)? { + if !services().rooms.state_cache.server_in_room(sender_servername, &body.room_id)? { return Err(Error::BadRequest( ErrorKind::Forbidden, "Server is not in room.", )); } - acl_check(sender_servername, &body.room_id)?; + services().rooms.event_handler.acl_check(&sender_servername, &body.room_id)?; let shortstatehash = services() - .rooms + .rooms.state_accessor .pdu_shortstatehash(&body.event_id)? .ok_or(Error::BadRequest( ErrorKind::NotFound, @@ -1400,7 +1401,7 @@ pub async fn get_room_state_ids_route( ))?; let pdu_ids = services() - .rooms + .rooms.state_accessor .state_full_ids(shortstatehash) .await? .into_iter() @@ -1426,7 +1427,7 @@ pub async fn create_join_event_template_route( return Err(Error::bad_config("Federation is disabled.")); } - if !services().rooms.exists(&body.room_id)? { + if !services().rooms.metadata.exists(&body.room_id)? { return Err(Error::BadRequest( ErrorKind::NotFound, "Room is unknown to this server.", @@ -1438,7 +1439,7 @@ pub async fn create_join_event_template_route( .as_ref() .expect("server is authenticated"); - acl_check(sender_servername, &body.room_id)?; + services().rooms.event_handler.acl_check(&sender_servername, &body.room_id)?; let mutex_state = Arc::clone( services().globals @@ -1452,7 +1453,7 @@ pub async fn create_join_event_template_route( // TODO: Conduit does not implement restricted join rules yet, we always reject let join_rules_event = - services().rooms + services().rooms.state_accessor .room_state_get(&body.room_id, &StateEventType::RoomJoinRules, "")?; let join_rules_event_content: Option<RoomJoinRulesEventContent> = join_rules_event @@ -1477,8 +1478,8 @@ pub async fn create_join_event_template_route( } } - let room_version_id = services().rooms.state.get_room_version(&body.room_id); - if !body.ver.contains(room_version_id) { + let room_version_id = services().rooms.state.get_room_version(&body.room_id)?; + if !body.ver.contains(&room_version_id) { return Err(Error::BadRequest( ErrorKind::IncompatibleRoomVersion { room_version: room_version_id, @@ -1505,7 +1506,7 @@ pub async fn create_join_event_template_route( unsigned: None, state_key: Some(body.user_id.to_string()), redacts: None, - }, &body.user_id, &body.room_id, &state_lock); + }, &body.user_id, &body.room_id, &state_lock)?; drop(state_lock); @@ -1524,18 +1525,18 @@ async fn create_join_event( return Err(Error::bad_config("Federation is disabled.")); } - if !services().rooms.exists(room_id)? { + if !services().rooms.metadata.exists(room_id)? { return Err(Error::BadRequest( ErrorKind::NotFound, "Room is unknown to this server.", )); } - acl_check(sender_servername, room_id)?; + services().rooms.event_handler.acl_check(&sender_servername, room_id)?; // TODO: Conduit does not implement restricted join rules yet, we always reject let join_rules_event = services() - .rooms + .rooms.state_accessor .room_state_get(room_id, &StateEventType::RoomJoinRules, "")?; let join_rules_event_content: Option<RoomJoinRulesEventContent> = join_rules_event @@ -1562,8 +1563,8 @@ async fn create_join_event( // We need to return the state prior to joining, let's keep a reference to that here let shortstatehash = services() - .rooms - .current_shortstatehash(room_id)? + .rooms.state + .get_room_shortstatehash(room_id)? .ok_or(Error::BadRequest( ErrorKind::NotFound, "Pdu state not found.", @@ -1602,22 +1603,15 @@ async fn create_join_event( .or_default(), ); let mutex_lock = mutex.lock().await; - let pdu_id = services().rooms.event_handler.handle_incoming_pdu(&origin, &event_id, room_id, value, true, &pub_key_map) - .await - .map_err(|e| { - warn!("Error while handling incoming send join PDU: {}", e); - Error::BadRequest( - ErrorKind::InvalidParam, - "Error while handling incoming PDU.", - ) - })? + let pdu_id: Vec<u8> = services().rooms.event_handler.handle_incoming_pdu(&origin, &event_id, room_id, value, true, &pub_key_map) + .await? .ok_or(Error::BadRequest( ErrorKind::InvalidParam, "Could not accept incoming PDU as timeline event.", ))?; drop(mutex_lock); - let state_ids = services().rooms.state_full_ids(shortstatehash).await?; + let state_ids = services().rooms.state_accessor.state_full_ids(shortstatehash).await?; let auth_chain_ids = get_auth_chain( room_id, state_ids.iter().map(|(_, id)| id.clone()).collect(), @@ -1626,6 +1620,7 @@ async fn create_join_event( let servers = services() .rooms + .state_cache .room_servers(room_id) .filter_map(|r| r.ok()) .filter(|server| &**server != services().globals.server_name()); @@ -1634,12 +1629,12 @@ async fn create_join_event( Ok(RoomState { auth_chain: auth_chain_ids - .filter_map(|id| services().rooms.get_pdu_json(&id).ok().flatten()) + .filter_map(|id| services().rooms.timeline.get_pdu_json(&id).ok().flatten()) .map(PduEvent::convert_to_outgoing_federation_event) .collect(), state: state_ids .iter() - .filter_map(|(_, id)| services().rooms.get_pdu_json(id).ok().flatten()) + .filter_map(|(_, id)| services().rooms.timeline.get_pdu_json(id).ok().flatten()) .map(PduEvent::convert_to_outgoing_federation_event) .collect(), }) @@ -1692,7 +1687,7 @@ pub async fn create_invite_route( .as_ref() .expect("server is authenticated"); - acl_check(sender_servername, &body.room_id)?; + services().rooms.event_handler.acl_check(&sender_servername, &body.room_id)?; if !services().rooms.is_supported_version(&body.room_version) { return Err(Error::BadRequest( @@ -1767,8 +1762,8 @@ pub async fn create_invite_route( invite_state.push(pdu.to_stripped_state_event()); // If the room already exists, the remote server will notify us about the join via /send - if !services().rooms.exists(&pdu.room_id)? { - services().rooms.update_membership( + if !services().rooms.metadata.exists(&pdu.room_id)? { + services().rooms.state_cache.update_membership( &body.room_id, &invited_user, MembershipState::Invite, @@ -1931,274 +1926,6 @@ pub async fn claim_keys_route( }) } -#[tracing::instrument(skip_all)] -pub(crate) async fn fetch_required_signing_keys( - event: &BTreeMap<String, CanonicalJsonValue>, - pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, -) -> Result<()> { - let signatures = event - .get("signatures") - .ok_or(Error::BadServerResponse( - "No signatures in server response pdu.", - ))? - .as_object() - .ok_or(Error::BadServerResponse( - "Invalid signatures object in server response pdu.", - ))?; - - // We go through all the signatures we see on the value and fetch the corresponding signing - // keys - for (signature_server, signature) in signatures { - let signature_object = signature.as_object().ok_or(Error::BadServerResponse( - "Invalid signatures content object in server response pdu.", - ))?; - - let signature_ids = signature_object.keys().cloned().collect::<Vec<_>>(); - - let fetch_res = fetch_signing_keys( - signature_server.as_str().try_into().map_err(|_| { - Error::BadServerResponse("Invalid servername in signatures of server response pdu.") - })?, - signature_ids, - ) - .await; - - let keys = match fetch_res { - Ok(keys) => keys, - Err(_) => { - warn!("Signature verification failed: Could not fetch signing key.",); - continue; - } - }; - - pub_key_map - .write() - .map_err(|_| Error::bad_database("RwLock is poisoned."))? - .insert(signature_server.clone(), keys); - } - - Ok(()) -} - -// Gets a list of servers for which we don't have the signing key yet. We go over -// the PDUs and either cache the key or add it to the list that needs to be retrieved. -fn get_server_keys_from_cache( - pdu: &RawJsonValue, - servers: &mut BTreeMap<Box<ServerName>, BTreeMap<Box<ServerSigningKeyId>, QueryCriteria>>, - room_version: &RoomVersionId, - pub_key_map: &mut RwLockWriteGuard<'_, BTreeMap<String, BTreeMap<String, Base64>>>, -) -> Result<()> { - let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { - error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); - Error::BadServerResponse("Invalid PDU in server response") - })?; - - let event_id = format!( - "${}", - ruma::signatures::reference_hash(&value, room_version) - .expect("ruma can calculate reference hashes") - ); - let event_id = <&EventId>::try_from(event_id.as_str()) - .expect("ruma's reference hashes are valid event ids"); - - if let Some((time, tries)) = services() - .globals - .bad_event_ratelimiter - .read() - .unwrap() - .get(event_id) - { - // Exponential backoff - let mut min_elapsed_duration = Duration::from_secs(30) * (*tries) * (*tries); - if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { - min_elapsed_duration = Duration::from_secs(60 * 60 * 24); - } - - if time.elapsed() < min_elapsed_duration { - debug!("Backing off from {}", event_id); - return Err(Error::BadServerResponse("bad event, still backing off")); - } - } - - let signatures = value - .get("signatures") - .ok_or(Error::BadServerResponse( - "No signatures in server response pdu.", - ))? - .as_object() - .ok_or(Error::BadServerResponse( - "Invalid signatures object in server response pdu.", - ))?; - - for (signature_server, signature) in signatures { - let signature_object = signature.as_object().ok_or(Error::BadServerResponse( - "Invalid signatures content object in server response pdu.", - ))?; - - let signature_ids = signature_object.keys().cloned().collect::<Vec<_>>(); - - let contains_all_ids = - |keys: &BTreeMap<String, Base64>| signature_ids.iter().all(|id| keys.contains_key(id)); - - let origin = <&ServerName>::try_from(signature_server.as_str()).map_err(|_| { - Error::BadServerResponse("Invalid servername in signatures of server response pdu.") - })?; - - if servers.contains_key(origin) || pub_key_map.contains_key(origin.as_str()) { - continue; - } - - trace!("Loading signing keys for {}", origin); - - let result: BTreeMap<_, _> = services() - .globals - .signing_keys_for(origin)? - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)) - .collect(); - - if !contains_all_ids(&result) { - trace!("Signing key not loaded for {}", origin); - servers.insert(origin.to_owned(), BTreeMap::new()); - } - - pub_key_map.insert(origin.to_string(), result); - } - - Ok(()) -} - -pub(crate) async fn fetch_join_signing_keys( - event: &create_join_event::v2::Response, - room_version: &RoomVersionId, - pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, -) -> Result<()> { - let mut servers: BTreeMap<Box<ServerName>, BTreeMap<Box<ServerSigningKeyId>, QueryCriteria>> = - BTreeMap::new(); - - { - let mut pkm = pub_key_map - .write() - .map_err(|_| Error::bad_database("RwLock is poisoned."))?; - - // Try to fetch keys, failure is okay - // Servers we couldn't find in the cache will be added to `servers` - for pdu in &event.room_state.state { - let _ = get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm); - } - for pdu in &event.room_state.auth_chain { - let _ = get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm); - } - - drop(pkm); - } - - if servers.is_empty() { - // We had all keys locally - return Ok(()); - } - - for server in services().globals.trusted_servers() { - trace!("Asking batch signing keys from trusted server {}", server); - if let Ok(keys) = services() - .sending - .send_federation_request( - server, - get_remote_server_keys_batch::v2::Request { - server_keys: servers.clone(), - }, - ) - .await - { - trace!("Got signing keys: {:?}", keys); - let mut pkm = pub_key_map - .write() - .map_err(|_| Error::bad_database("RwLock is poisoned."))?; - for k in keys.server_keys { - let k = k.deserialize().unwrap(); - - // TODO: Check signature from trusted server? - servers.remove(&k.server_name); - - let result = services() - .globals - .add_signing_key(&k.server_name, k.clone())? - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)) - .collect::<BTreeMap<_, _>>(); - - pkm.insert(k.server_name.to_string(), result); - } - } - - if servers.is_empty() { - return Ok(()); - } - } - - let mut futures: FuturesUnordered<_> = servers - .into_iter() - .map(|(server, _)| async move { - ( - services().sending - .send_federation_request( - &server, - get_server_keys::v2::Request::new(), - ) - .await, - server, - ) - }) - .collect(); - - while let Some(result) = futures.next().await { - if let (Ok(get_keys_response), origin) = result { - let result: BTreeMap<_, _> = services() - .globals - .add_signing_key(&origin, get_keys_response.server_key.deserialize().unwrap())? - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)) - .collect(); - - pub_key_map - .write() - .map_err(|_| Error::bad_database("RwLock is poisoned."))? - .insert(origin.to_string(), result); - } - } - - Ok(()) -} - -/// Returns Ok if the acl allows the server -fn acl_check(server_name: &ServerName, room_id: &RoomId) -> Result<()> { - let acl_event = match services() - .rooms - .room_state_get(room_id, &StateEventType::RoomServerAcl, "")? - { - Some(acl) => acl, - None => return Ok(()), - }; - - let acl_event_content: RoomServerAclEventContent = - match serde_json::from_str(acl_event.content.get()) { - Ok(content) => content, - Err(_) => { - warn!("Invalid ACL event"); - return Ok(()); - } - }; - - if acl_event_content.is_allowed(server_name) { - Ok(()) - } else { - Err(Error::BadRequest( - ErrorKind::Forbidden, - "Server was denied by ACL", - )) - } -} - #[cfg(test)] mod tests { use super::{add_port_to_hostname, get_ip_with_port, FedDest}; |