summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJonas Platte <jplatte+git@posteo.de>2022-01-22 18:38:39 +0100
committerJonas Platte <jplatte+git@posteo.de>2022-02-12 12:56:08 +0100
commitc8951a1d9cc05a8c138be06f520a78b4cbb053c7 (patch)
tree8aab4bb1949a28983ce3ee971fd6215ee0403098
parent5fa9190117805ff1040c69b65a3b9caacb6c965b (diff)
downloadconduit-c8951a1d9cc05a8c138be06f520a78b4cbb053c7.zip
Use axum-server for direct TLS support
-rw-r--r--Cargo.lock28
-rw-r--r--Cargo.toml2
-rw-r--r--src/config.rs8
-rw-r--r--src/main.rs29
4 files changed, 57 insertions, 10 deletions
diff --git a/Cargo.lock b/Cargo.lock
index f84c982..41105b3 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -59,6 +59,12 @@ dependencies = [
]
[[package]]
+name = "arc-swap"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c5d78ce20460b82d3fa150275ed9d55e21064fc7951177baacf86a145c4a4b1f"
+
+[[package]]
name = "arrayref"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -163,6 +169,26 @@ dependencies = [
]
[[package]]
+name = "axum-server"
+version = "0.3.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f9cfd9dbe28ebde5c0460067ea27c6f3b1d514b699c4e0a5aab0fb63e452a8a8"
+dependencies = [
+ "arc-swap",
+ "bytes",
+ "futures-util",
+ "http",
+ "http-body",
+ "hyper",
+ "pin-project-lite",
+ "rustls",
+ "rustls-pemfile",
+ "tokio",
+ "tokio-rustls",
+ "tower-service",
+]
+
+[[package]]
name = "base64"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -365,6 +391,7 @@ name = "conduit"
version = "0.3.0"
dependencies = [
"axum",
+ "axum-server",
"base64 0.13.0",
"bytes",
"clap",
@@ -375,7 +402,6 @@ dependencies = [
"heed",
"hmac",
"http",
- "hyper",
"image",
"jsonwebtoken",
"lru-cache",
diff --git a/Cargo.toml b/Cargo.toml
index 5fb75dc..6dedfa8 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -15,7 +15,7 @@ edition = "2021"
[dependencies]
# Web framework
axum = { version = "0.4.4", features = ["headers"], optional = true }
-hyper = "0.14.16"
+axum-server = { version = "0.3.3", features = ["tls-rustls"] }
tower = { version = "0.4.11", features = ["util"] }
tower-http = { version = "0.2.1", features = ["add-extension", "cors", "compression-full", "sensitive-headers", "trace", "util"] }
diff --git a/src/config.rs b/src/config.rs
index 48ac981..155704b 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -17,6 +17,8 @@ pub struct Config {
pub address: IpAddr,
#[serde(default = "default_port")]
pub port: u16,
+ pub tls: Option<TlsConfig>,
+
pub server_name: Box<ServerName>,
#[serde(default = "default_database_backend")]
pub database_backend: String,
@@ -69,6 +71,12 @@ pub struct Config {
pub catchall: BTreeMap<String, IgnoredAny>,
}
+#[derive(Clone, Debug, Deserialize)]
+pub struct TlsConfig {
+ pub certs: String,
+ pub key: String,
+}
+
const DEPRECATED_KEYS: &[&str] = &["cache_capacity"];
impl Config {
diff --git a/src/main.rs b/src/main.rs
index 40122cf..22ddf3e 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -7,7 +7,7 @@
#![allow(clippy::suspicious_else_formatting)]
#![deny(clippy::dbg_macro)]
-use std::{future::Future, net::SocketAddr, sync::Arc, time::Duration};
+use std::{future::Future, io, net::SocketAddr, sync::Arc, time::Duration};
use axum::{
extract::{FromRequest, MatchedPath},
@@ -15,6 +15,7 @@ use axum::{
routing::{get, on, MethodFilter},
Router,
};
+use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle};
use figment::{
providers::{Env, Format, Toml},
Figment,
@@ -117,8 +118,8 @@ async fn main() {
}
}
-async fn run_server(config: &Config, db: Arc<RwLock<Database>>) -> hyper::Result<()> {
- let listen_addr = SocketAddr::from((config.address, config.port));
+async fn run_server(config: &Config, db: Arc<RwLock<Database>>) -> io::Result<()> {
+ let addr = SocketAddr::from((config.address, config.port));
let x_requested_with = HeaderName::from_static("x-requested-with");
@@ -157,10 +158,20 @@ async fn run_server(config: &Config, db: Arc<RwLock<Database>>) -> hyper::Result
)
.add_extension(db.clone());
- axum::Server::bind(&listen_addr)
- .serve(routes().layer(middlewares).into_make_service())
- .with_graceful_shutdown(shutdown_signal())
- .await?;
+ let app = routes().layer(middlewares).into_make_service();
+ let handle = ServerHandle::new();
+
+ tokio::spawn(shutdown_signal(handle.clone()));
+
+ match &config.tls {
+ Some(tls) => {
+ let conf = RustlsConfig::from_pem_file(&tls.certs, &tls.key).await?;
+ bind_rustls(addr, conf).handle(handle).serve(app).await?;
+ }
+ None => {
+ bind(addr).handle(handle).serve(app).await?;
+ }
+ }
// After serve exits and before exiting, shutdown the DB
Database::on_shutdown(db).await;
@@ -312,7 +323,7 @@ fn routes() -> Router {
.ruma_route(server_server::claim_keys_route)
}
-async fn shutdown_signal() {
+async fn shutdown_signal(handle: ServerHandle) {
let ctrl_c = async {
signal::ctrl_c()
.await
@@ -334,6 +345,8 @@ async fn shutdown_signal() {
_ = ctrl_c => {},
_ = terminate => {},
}
+
+ handle.graceful_shutdown(Some(Duration::from_secs(30)));
}
trait RouterExt {