diff options
Diffstat (limited to 'src/api/ruma_wrapper/axum.rs')
-rw-r--r-- | src/api/ruma_wrapper/axum.rs | 146 |
1 files changed, 105 insertions, 41 deletions
diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index ed28f9d..bbd4861 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -3,18 +3,16 @@ use std::{collections::BTreeMap, iter::FromIterator, str}; use axum::{ async_trait, body::{Full, HttpBody}, - extract::{ - rejection::TypedHeaderRejectionReason, FromRequest, Path, RequestParts, TypedHeader, - }, + extract::{rejection::TypedHeaderRejectionReason, FromRequest, Path, TypedHeader}, headers::{ authorization::{Bearer, Credentials}, Authorization, }, response::{IntoResponse, Response}, - BoxError, + BoxError, RequestExt, RequestPartsExt, }; -use bytes::{BufMut, Bytes, BytesMut}; -use http::StatusCode; +use bytes::{Buf, BufMut, Bytes, BytesMut}; +use http::{Request, StatusCode}; use ruma::{ api::{client::error::ErrorKind, AuthScheme, IncomingRequest, OutgoingResponse}, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, UserId, @@ -26,28 +24,45 @@ use super::{Ruma, RumaResponse}; use crate::{services, Error, Result}; #[async_trait] -impl<T, B> FromRequest<B> for Ruma<T> +impl<T, S, B> FromRequest<S, B> for Ruma<T> where T: IncomingRequest, - B: HttpBody + Send, + B: HttpBody + Send + 'static, B::Data: Send, B::Error: Into<BoxError>, { type Rejection = Error; - async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> { + async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> { #[derive(Deserialize)] struct QueryParams { access_token: Option<String>, user_id: Option<String>, } + let (mut parts, mut body) = match req.with_limited_body() { + Ok(limited_req) => { + let (parts, body) = limited_req.into_parts(); + let body = to_bytes(body) + .await + .map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?; + (parts, body) + } + Err(original_req) => { + let (parts, body) = original_req.into_parts(); + let body = to_bytes(body) + .await + .map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?; + (parts, body) + } + }; + let metadata = T::METADATA; - let auth_header = Option::<TypedHeader<Authorization<Bearer>>>::from_request(req).await?; - let path_params = Path::<Vec<String>>::from_request(req).await?; + let auth_header: Option<TypedHeader<Authorization<Bearer>>> = parts.extract().await?; + let path_params: Path<Vec<String>> = parts.extract().await?; - let query = req.uri().query().unwrap_or_default(); - let query_params: QueryParams = match ruma::serde::urlencoded::from_str(query) { + let query = parts.uri.query().unwrap_or_default(); + let query_params: QueryParams = match serde_html_form::from_str(query) { Ok(params) => params, Err(e) => { error!(%query, "Failed to deserialize query parameters: {}", e); @@ -63,10 +78,6 @@ where None => query_params.access_token.as_deref(), }; - let mut body = Bytes::from_request(req) - .await - .map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?; - let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok(); let appservices = services().appservice.all().unwrap(); @@ -138,24 +149,24 @@ where } } AuthScheme::ServerSignatures => { - let TypedHeader(Authorization(x_matrix)) = - TypedHeader::<Authorization<XMatrix>>::from_request(req) - .await - .map_err(|e| { - warn!("Missing or invalid Authorization header: {}", e); - - let msg = match e.reason() { - TypedHeaderRejectionReason::Missing => { - "Missing Authorization header." - } - TypedHeaderRejectionReason::Error(_) => { - "Invalid X-Matrix signatures." - } - _ => "Unknown header-related error", - }; - - Error::BadRequest(ErrorKind::Forbidden, msg) - })?; + let TypedHeader(Authorization(x_matrix)) = parts + .extract::<TypedHeader<Authorization<XMatrix>>>() + .await + .map_err(|e| { + warn!("Missing or invalid Authorization header: {}", e); + + let msg = match e.reason() { + TypedHeaderRejectionReason::Missing => { + "Missing Authorization header." + } + TypedHeaderRejectionReason::Error(_) => { + "Invalid X-Matrix signatures." + } + _ => "Unknown header-related error", + }; + + Error::BadRequest(ErrorKind::Forbidden, msg) + })?; let origin_signatures = BTreeMap::from_iter([( x_matrix.key.clone(), @@ -170,11 +181,11 @@ where let mut request_map = BTreeMap::from_iter([ ( "method".to_owned(), - CanonicalJsonValue::String(req.method().to_string()), + CanonicalJsonValue::String(parts.method.to_string()), ), ( "uri".to_owned(), - CanonicalJsonValue::String(req.uri().to_string()), + CanonicalJsonValue::String(parts.uri.to_string()), ), ( "origin".to_owned(), @@ -224,7 +235,7 @@ where x_matrix.origin, e, request_map ); - if req.uri().to_string().contains('@') { + if parts.uri.to_string().contains('@') { warn!( "Request uri contained '@' character. Make sure your \ reverse proxy gives Conduit the raw uri (apache: use \ @@ -243,8 +254,8 @@ where } }; - let mut http_request = http::Request::builder().uri(req.uri()).method(req.method()); - *http_request.headers_mut().unwrap() = req.headers().clone(); + let mut http_request = http::Request::builder().uri(parts.uri).method(parts.method); + *http_request.headers_mut().unwrap() = parts.headers; if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body { let user_id = sender_user.clone().unwrap_or_else(|| { @@ -281,7 +292,8 @@ where debug!("{:?}", http_request); let body = T::try_from_http_request(http_request, &path_params).map_err(|e| { - warn!("{:?}\n{:?}", e, json_body); + warn!("try_from_http_request failed: {:?}", e); + debug!("JSON body: {:?}", json_body); Error::BadRequest(ErrorKind::BadJson, "Failed to deserialize request.") })?; @@ -361,3 +373,55 @@ impl<T: OutgoingResponse> IntoResponse for RumaResponse<T> { } } } + +// copied from hyper under the following license: +// Copyright (c) 2014-2021 Sean McArthur + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. +pub(crate) async fn to_bytes<T>(body: T) -> Result<Bytes, T::Error> +where + T: HttpBody, +{ + futures_util::pin_mut!(body); + + // If there's only 1 chunk, we can just return Buf::to_bytes() + let mut first = if let Some(buf) = body.data().await { + buf? + } else { + return Ok(Bytes::new()); + }; + + let second = if let Some(buf) = body.data().await { + buf? + } else { + return Ok(first.copy_to_bytes(first.remaining())); + }; + + // With more than 1 buf, we gotta flatten into a Vec first. + let cap = first.remaining() + second.remaining() + body.size_hint().lower() as usize; + let mut vec = Vec::with_capacity(cap); + vec.put(first); + vec.put(second); + + while let Some(buf) = body.data().await { + vec.put(buf?); + } + + Ok(vec.into()) +} |