diff options
Diffstat (limited to 'src/api/ruma_wrapper/axum.rs')
-rw-r--r-- | src/api/ruma_wrapper/axum.rs | 363 |
1 files changed, 363 insertions, 0 deletions
diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs new file mode 100644 index 0000000..ed28f9d --- /dev/null +++ b/src/api/ruma_wrapper/axum.rs @@ -0,0 +1,363 @@ +use std::{collections::BTreeMap, iter::FromIterator, str}; + +use axum::{ + async_trait, + body::{Full, HttpBody}, + extract::{ + rejection::TypedHeaderRejectionReason, FromRequest, Path, RequestParts, TypedHeader, + }, + headers::{ + authorization::{Bearer, Credentials}, + Authorization, + }, + response::{IntoResponse, Response}, + BoxError, +}; +use bytes::{BufMut, Bytes, BytesMut}; +use http::StatusCode; +use ruma::{ + api::{client::error::ErrorKind, AuthScheme, IncomingRequest, OutgoingResponse}, + CanonicalJsonValue, OwnedDeviceId, OwnedServerName, UserId, +}; +use serde::Deserialize; +use tracing::{debug, error, warn}; + +use super::{Ruma, RumaResponse}; +use crate::{services, Error, Result}; + +#[async_trait] +impl<T, B> FromRequest<B> for Ruma<T> +where + T: IncomingRequest, + B: HttpBody + Send, + B::Data: Send, + B::Error: Into<BoxError>, +{ + type Rejection = Error; + + async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { + #[derive(Deserialize)] + struct QueryParams { + access_token: Option<String>, + user_id: Option<String>, + } + + let metadata = T::METADATA; + let auth_header = Option::<TypedHeader<Authorization<Bearer>>>::from_request(req).await?; + let path_params = Path::<Vec<String>>::from_request(req).await?; + + let query = req.uri().query().unwrap_or_default(); + let query_params: QueryParams = match ruma::serde::urlencoded::from_str(query) { + Ok(params) => params, + Err(e) => { + error!(%query, "Failed to deserialize query parameters: {}", e); + return Err(Error::BadRequest( + ErrorKind::Unknown, + "Failed to read query parameters", + )); + } + }; + + let token = match &auth_header { + Some(TypedHeader(Authorization(bearer))) => Some(bearer.token()), + None => query_params.access_token.as_deref(), + }; + + let mut body = Bytes::from_request(req) + .await + .map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?; + + let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok(); + + let appservices = services().appservice.all().unwrap(); + let appservice_registration = appservices.iter().find(|(_id, registration)| { + registration + .get("as_token") + .and_then(|as_token| as_token.as_str()) + .map_or(false, |as_token| token == Some(as_token)) + }); + + let (sender_user, sender_device, sender_servername, from_appservice) = + if let Some((_id, registration)) = appservice_registration { + match metadata.authentication { + AuthScheme::AccessToken => { + let user_id = query_params.user_id.map_or_else( + || { + UserId::parse_with_server_name( + registration + .get("sender_localpart") + .unwrap() + .as_str() + .unwrap(), + services().globals.server_name(), + ) + .unwrap() + }, + |s| UserId::parse(s).unwrap(), + ); + + if !services().users.exists(&user_id).unwrap() { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "User does not exist.", + )); + } + + // TODO: Check if appservice is allowed to be that user + (Some(user_id), None, None, true) + } + AuthScheme::ServerSignatures => (None, None, None, true), + AuthScheme::None => (None, None, None, true), + } + } else { + match metadata.authentication { + AuthScheme::AccessToken => { + let token = match token { + Some(token) => token, + _ => { + return Err(Error::BadRequest( + ErrorKind::MissingToken, + "Missing access token.", + )) + } + }; + + match services().users.find_from_token(token).unwrap() { + None => { + return Err(Error::BadRequest( + ErrorKind::UnknownToken { soft_logout: false }, + "Unknown access token.", + )) + } + Some((user_id, device_id)) => ( + Some(user_id), + Some(OwnedDeviceId::from(device_id)), + None, + false, + ), + } + } + AuthScheme::ServerSignatures => { + let TypedHeader(Authorization(x_matrix)) = + TypedHeader::<Authorization<XMatrix>>::from_request(req) + .await + .map_err(|e| { + warn!("Missing or invalid Authorization header: {}", e); + + let msg = match e.reason() { + TypedHeaderRejectionReason::Missing => { + "Missing Authorization header." + } + TypedHeaderRejectionReason::Error(_) => { + "Invalid X-Matrix signatures." + } + _ => "Unknown header-related error", + }; + + Error::BadRequest(ErrorKind::Forbidden, msg) + })?; + + let origin_signatures = BTreeMap::from_iter([( + x_matrix.key.clone(), + CanonicalJsonValue::String(x_matrix.sig), + )]); + + let signatures = BTreeMap::from_iter([( + x_matrix.origin.as_str().to_owned(), + CanonicalJsonValue::Object(origin_signatures), + )]); + + let mut request_map = BTreeMap::from_iter([ + ( + "method".to_owned(), + CanonicalJsonValue::String(req.method().to_string()), + ), + ( + "uri".to_owned(), + CanonicalJsonValue::String(req.uri().to_string()), + ), + ( + "origin".to_owned(), + CanonicalJsonValue::String(x_matrix.origin.as_str().to_owned()), + ), + ( + "destination".to_owned(), + CanonicalJsonValue::String( + services().globals.server_name().as_str().to_owned(), + ), + ), + ( + "signatures".to_owned(), + CanonicalJsonValue::Object(signatures), + ), + ]); + + if let Some(json_body) = &json_body { + request_map.insert("content".to_owned(), json_body.clone()); + }; + + let keys_result = services() + .rooms + .event_handler + .fetch_signing_keys(&x_matrix.origin, vec![x_matrix.key.to_owned()]) + .await; + + let keys = match keys_result { + Ok(b) => b, + Err(e) => { + warn!("Failed to fetch signing keys: {}", e); + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Failed to fetch signing keys.", + )); + } + }; + + let pub_key_map = + BTreeMap::from_iter([(x_matrix.origin.as_str().to_owned(), keys)]); + + match ruma::signatures::verify_json(&pub_key_map, &request_map) { + Ok(()) => (None, None, Some(x_matrix.origin), false), + Err(e) => { + warn!( + "Failed to verify json request from {}: {}\n{:?}", + x_matrix.origin, e, request_map + ); + + if req.uri().to_string().contains('@') { + warn!( + "Request uri contained '@' character. Make sure your \ + reverse proxy gives Conduit the raw uri (apache: use \ + nocanon)" + ); + } + + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Failed to verify X-Matrix signatures.", + )); + } + } + } + AuthScheme::None => (None, None, None, false), + } + }; + + let mut http_request = http::Request::builder().uri(req.uri()).method(req.method()); + *http_request.headers_mut().unwrap() = req.headers().clone(); + + if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body { + let user_id = sender_user.clone().unwrap_or_else(|| { + UserId::parse_with_server_name("", services().globals.server_name()) + .expect("we know this is valid") + }); + + let uiaa_request = json_body + .get("auth") + .and_then(|auth| auth.as_object()) + .and_then(|auth| auth.get("session")) + .and_then(|session| session.as_str()) + .and_then(|session| { + services().uiaa.get_uiaa_request( + &user_id, + &sender_device.clone().unwrap_or_else(|| "".into()), + session, + ) + }); + + if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request { + for (key, value) in initial_request { + json_body.entry(key).or_insert(value); + } + } + + let mut buf = BytesMut::new().writer(); + serde_json::to_writer(&mut buf, json_body).expect("value serialization can't fail"); + body = buf.into_inner().freeze(); + } + + let http_request = http_request.body(&*body).unwrap(); + + debug!("{:?}", http_request); + + let body = T::try_from_http_request(http_request, &path_params).map_err(|e| { + warn!("{:?}\n{:?}", e, json_body); + Error::BadRequest(ErrorKind::BadJson, "Failed to deserialize request.") + })?; + + Ok(Ruma { + body, + sender_user, + sender_device, + sender_servername, + from_appservice, + json_body, + }) + } +} + +struct XMatrix { + origin: OwnedServerName, + key: String, // KeyName? + sig: String, +} + +impl Credentials for XMatrix { + const SCHEME: &'static str = "X-Matrix"; + + fn decode(value: &http::HeaderValue) -> Option<Self> { + debug_assert!( + value.as_bytes().starts_with(b"X-Matrix "), + "HeaderValue to decode should start with \"X-Matrix ..\", received = {value:?}", + ); + + let parameters = str::from_utf8(&value.as_bytes()["X-Matrix ".len()..]) + .ok()? + .trim_start(); + + let mut origin = None; + let mut key = None; + let mut sig = None; + + for entry in parameters.split_terminator(',') { + let (name, value) = entry.split_once('=')?; + + // It's not at all clear why some fields are quoted and others not in the spec, + // let's simply accept either form for every field. + let value = value + .strip_prefix('"') + .and_then(|rest| rest.strip_suffix('"')) + .unwrap_or(value); + + // FIXME: Catch multiple fields of the same name + match name { + "origin" => origin = Some(value.try_into().ok()?), + "key" => key = Some(value.to_owned()), + "sig" => sig = Some(value.to_owned()), + _ => debug!( + "Unexpected field `{}` in X-Matrix Authorization header", + name + ), + } + } + + Some(Self { + origin: origin?, + key: key?, + sig: sig?, + }) + } + + fn encode(&self) -> http::HeaderValue { + todo!() + } +} + +impl<T: OutgoingResponse> IntoResponse for RumaResponse<T> { + fn into_response(self) -> Response { + match self.0.try_into_http_response::<BytesMut>() { + Ok(res) => res.map(BytesMut::freeze).map(Full::new).into_response(), + Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(), + } + } +} |