summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Cargo.lock54
-rw-r--r--Cargo.toml2
-rw-r--r--src/api/ruma_wrapper/axum.rs141
-rw-r--r--src/main.rs7
4 files changed, 131 insertions, 73 deletions
diff --git a/Cargo.lock b/Cargo.lock
index 9f62c18..4148394 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -89,9 +89,9 @@ checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "axum"
-version = "0.5.17"
+version = "0.6.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "acee9fd5073ab6b045a275b3e709c163dd36c90685219cb21804a147b58dba43"
+checksum = "f8175979259124331c1d7bf6586ee7e0da434155e4b2d48ec2c8386281d8df39"
dependencies = [
"async-trait",
"axum-core",
@@ -108,22 +108,22 @@ dependencies = [
"mime",
"percent-encoding",
"pin-project-lite",
+ "rustversion",
"serde",
"serde_json",
+ "serde_path_to_error",
"serde_urlencoded",
"sync_wrapper",
- "tokio",
"tower",
- "tower-http 0.3.5",
"tower-layer",
"tower-service",
]
[[package]]
name = "axum-core"
-version = "0.2.9"
+version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "37e5939e02c56fecd5c017c37df4238c0a839fa76b7f97acdd7efb804fd181cc"
+checksum = "759fa577a247914fd3f7f76d62972792636412fbfd634cd452f6a385a74d2d2c"
dependencies = [
"async-trait",
"bytes",
@@ -131,6 +131,7 @@ dependencies = [
"http",
"http-body",
"mime",
+ "rustversion",
"tower-layer",
"tower-service",
]
@@ -407,7 +408,7 @@ dependencies = [
"tikv-jemallocator",
"tokio",
"tower",
- "tower-http 0.4.1",
+ "tower-http",
"tracing",
"tracing-flame",
"tracing-opentelemetry",
@@ -1449,9 +1450,9 @@ checksum = "2532096657941c2fea9c289d370a250971c689d4f143798ff67113ec042024a5"
[[package]]
name = "matchit"
-version = "0.5.0"
+version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "73cbba799671b762df5a175adf59ce145165747bb891505c43d09aefbbf38beb"
+checksum = "b87248edafb776e59e6ee64a79086f65890d3510f2c656c000bf2a7e8a0aea40"
[[package]]
name = "memchr"
@@ -2364,6 +2365,12 @@ dependencies = [
]
[[package]]
+name = "rustversion"
+version = "1.0.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4f3208ce4d8448b3f3e7d168a73f5e0c43a61e32930de3bceeccedb388b6bf06"
+
+[[package]]
name = "ryu"
version = "1.0.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2468,6 +2475,15 @@ dependencies = [
]
[[package]]
+name = "serde_path_to_error"
+version = "0.1.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f7f05c1d5476066defcdfacce1f52fc3cae3af1d3089727100c02ae92e5abbe0"
+dependencies = [
+ "serde",
+]
+
+[[package]]
name = "serde_spanned"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -2954,7 +2970,6 @@ dependencies = [
"futures-util",
"pin-project",
"pin-project-lite",
- "tokio",
"tower-layer",
"tower-service",
"tracing",
@@ -2962,25 +2977,6 @@ dependencies = [
[[package]]
name = "tower-http"
-version = "0.3.5"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f873044bf02dd1e8239e9c1293ea39dad76dc594ec16185d0a1bf31d8dc8d858"
-dependencies = [
- "bitflags 1.3.2",
- "bytes",
- "futures-core",
- "futures-util",
- "http",
- "http-body",
- "http-range-header",
- "pin-project-lite",
- "tower",
- "tower-layer",
- "tower-service",
-]
-
-[[package]]
-name = "tower-http"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8bd22a874a2d0b70452d5597b12c537331d49060824a95f49f108994f94aa4c"
diff --git a/Cargo.toml b/Cargo.toml
index 9698caf..424007c 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -19,7 +19,7 @@ rust-version = "1.70.0"
[dependencies]
# Web framework
-axum = { version = "0.5.16", default-features = false, features = ["form", "headers", "http1", "http2", "json", "matched-path"], optional = true }
+axum = { version = "0.6.18", default-features = false, features = ["form", "headers", "http1", "http2", "json", "matched-path"], optional = true }
axum-server = { version = "0.5.1", features = ["tls-rustls"] }
tower = { version = "0.4.13", features = ["util"] }
tower-http = { version = "0.4.1", features = ["add-extension", "cors", "sensitive-headers", "trace", "util"] }
diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs
index 2d2af70..069e12b 100644
--- a/src/api/ruma_wrapper/axum.rs
+++ b/src/api/ruma_wrapper/axum.rs
@@ -3,18 +3,16 @@ use std::{collections::BTreeMap, iter::FromIterator, str};
use axum::{
async_trait,
body::{Full, HttpBody},
- extract::{
- rejection::TypedHeaderRejectionReason, FromRequest, Path, RequestParts, TypedHeader,
- },
+ extract::{rejection::TypedHeaderRejectionReason, FromRequest, Path, TypedHeader},
headers::{
authorization::{Bearer, Credentials},
Authorization,
},
response::{IntoResponse, Response},
- BoxError,
+ BoxError, RequestExt, RequestPartsExt,
};
-use bytes::{BufMut, Bytes, BytesMut};
-use http::StatusCode;
+use bytes::{Buf, BufMut, Bytes, BytesMut};
+use http::{Request, StatusCode};
use ruma::{
api::{client::error::ErrorKind, AuthScheme, IncomingRequest, OutgoingResponse},
CanonicalJsonValue, OwnedDeviceId, OwnedServerName, UserId,
@@ -26,27 +24,44 @@ use super::{Ruma, RumaResponse};
use crate::{services, Error, Result};
#[async_trait]
-impl<T, B> FromRequest<B> for Ruma<T>
+impl<T, S, B> FromRequest<S, B> for Ruma<T>
where
T: IncomingRequest,
- B: HttpBody + Send,
+ B: HttpBody + Send + 'static,
B::Data: Send,
B::Error: Into<BoxError>,
{
type Rejection = Error;
- async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
+ async fn from_request(req: Request<B>, _state: &S) -> Result<Self, Self::Rejection> {
#[derive(Deserialize)]
struct QueryParams {
access_token: Option<String>,
user_id: Option<String>,
}
+ let (mut parts, mut body) = match req.with_limited_body() {
+ Ok(limited_req) => {
+ let (parts, body) = limited_req.into_parts();
+ let body = to_bytes(body)
+ .await
+ .map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?;
+ (parts, body)
+ }
+ Err(original_req) => {
+ let (parts, body) = original_req.into_parts();
+ let body = to_bytes(body)
+ .await
+ .map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?;
+ (parts, body)
+ }
+ };
+
let metadata = T::METADATA;
- let auth_header = Option::<TypedHeader<Authorization<Bearer>>>::from_request(req).await?;
- let path_params = Path::<Vec<String>>::from_request(req).await?;
+ let auth_header: Option<TypedHeader<Authorization<Bearer>>> = parts.extract().await?;
+ let path_params: Path<Vec<String>> = parts.extract().await?;
- let query = req.uri().query().unwrap_or_default();
+ let query = parts.uri.query().unwrap_or_default();
let query_params: QueryParams = match serde_html_form::from_str(query) {
Ok(params) => params,
Err(e) => {
@@ -63,10 +78,6 @@ where
None => query_params.access_token.as_deref(),
};
- let mut body = Bytes::from_request(req)
- .await
- .map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?;
-
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
let appservices = services().appservice.all().unwrap();
@@ -138,24 +149,24 @@ where
}
}
AuthScheme::ServerSignatures => {
- let TypedHeader(Authorization(x_matrix)) =
- TypedHeader::<Authorization<XMatrix>>::from_request(req)
- .await
- .map_err(|e| {
- warn!("Missing or invalid Authorization header: {}", e);
-
- let msg = match e.reason() {
- TypedHeaderRejectionReason::Missing => {
- "Missing Authorization header."
- }
- TypedHeaderRejectionReason::Error(_) => {
- "Invalid X-Matrix signatures."
- }
- _ => "Unknown header-related error",
- };
-
- Error::BadRequest(ErrorKind::Forbidden, msg)
- })?;
+ let TypedHeader(Authorization(x_matrix)) = parts
+ .extract::<TypedHeader<Authorization<XMatrix>>>()
+ .await
+ .map_err(|e| {
+ warn!("Missing or invalid Authorization header: {}", e);
+
+ let msg = match e.reason() {
+ TypedHeaderRejectionReason::Missing => {
+ "Missing Authorization header."
+ }
+ TypedHeaderRejectionReason::Error(_) => {
+ "Invalid X-Matrix signatures."
+ }
+ _ => "Unknown header-related error",
+ };
+
+ Error::BadRequest(ErrorKind::Forbidden, msg)
+ })?;
let origin_signatures = BTreeMap::from_iter([(
x_matrix.key.clone(),
@@ -170,11 +181,11 @@ where
let mut request_map = BTreeMap::from_iter([
(
"method".to_owned(),
- CanonicalJsonValue::String(req.method().to_string()),
+ CanonicalJsonValue::String(parts.method.to_string()),
),
(
"uri".to_owned(),
- CanonicalJsonValue::String(req.uri().to_string()),
+ CanonicalJsonValue::String(parts.uri.to_string()),
),
(
"origin".to_owned(),
@@ -224,7 +235,7 @@ where
x_matrix.origin, e, request_map
);
- if req.uri().to_string().contains('@') {
+ if parts.uri.to_string().contains('@') {
warn!(
"Request uri contained '@' character. Make sure your \
reverse proxy gives Conduit the raw uri (apache: use \
@@ -243,8 +254,8 @@ where
}
};
- let mut http_request = http::Request::builder().uri(req.uri()).method(req.method());
- *http_request.headers_mut().unwrap() = req.headers().clone();
+ let mut http_request = http::Request::builder().uri(parts.uri).method(parts.method);
+ *http_request.headers_mut().unwrap() = parts.headers;
if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body {
let user_id = sender_user.clone().unwrap_or_else(|| {
@@ -364,3 +375,55 @@ impl<T: OutgoingResponse> IntoResponse for RumaResponse<T> {
}
}
}
+
+// copied from hyper under the following license:
+// Copyright (c) 2014-2021 Sean McArthur
+
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+pub(crate) async fn to_bytes<T>(body: T) -> Result<Bytes, T::Error>
+where
+ T: HttpBody,
+{
+ futures_util::pin_mut!(body);
+
+ // If there's only 1 chunk, we can just return Buf::to_bytes()
+ let mut first = if let Some(buf) = body.data().await {
+ buf?
+ } else {
+ return Ok(Bytes::new());
+ };
+
+ let second = if let Some(buf) = body.data().await {
+ buf?
+ } else {
+ return Ok(first.copy_to_bytes(first.remaining()));
+ };
+
+ // With more than 1 buf, we gotta flatten into a Vec first.
+ let cap = first.remaining() + second.remaining() + body.size_hint().lower() as usize;
+ let mut vec = Vec::with_capacity(cap);
+ vec.put(first);
+ vec.put(second);
+
+ while let Some(buf) = body.data().await {
+ vec.put(buf?);
+ }
+
+ Ok(vec.into())
+}
diff --git a/src/main.rs b/src/main.rs
index f9f88f4..e0f84d9 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -10,8 +10,7 @@
use std::{future::Future, io, net::SocketAddr, sync::atomic, time::Duration};
use axum::{
- extract::{DefaultBodyLimit, FromRequest, MatchedPath},
- handler::Handler,
+ extract::{DefaultBodyLimit, FromRequestParts, MatchedPath},
response::IntoResponse,
routing::{get, on, MethodFilter},
Router,
@@ -421,7 +420,7 @@ fn routes() -> Router {
"/_matrix/client/v3/rooms/:room_id/initialSync",
get(initial_sync),
)
- .fallback(not_found.into_service())
+ .fallback(not_found)
}
async fn shutdown_signal(handle: ServerHandle) {
@@ -505,7 +504,7 @@ macro_rules! impl_ruma_handler {
Fut: Future<Output = Result<Req::OutgoingResponse, E>>
+ Send,
E: IntoResponse,
- $( $ty: FromRequest<axum::body::Body> + Send + 'static, )*
+ $( $ty: FromRequestParts<()> + Send + 'static, )*
{
fn add_to_router(self, mut router: Router) -> Router {
let meta = Req::METADATA;