summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--conduit-example.toml2
-rw-r--r--src/database.rs121
-rw-r--r--src/database/proxy.rs146
3 files changed, 151 insertions, 118 deletions
diff --git a/conduit-example.toml b/conduit-example.toml
index 66c105b..db0bbb7 100644
--- a/conduit-example.toml
+++ b/conduit-example.toml
@@ -41,3 +41,5 @@ trusted_servers = ["matrix.org"]
#workers = 4 # default: cpu core count * 2
address = "127.0.0.1" # This makes sure Conduit can only be reached using the reverse proxy
+
+proxy = "none" # more examples can be found at src/database/proxy.rs:6
diff --git a/src/database.rs b/src/database.rs
index 64b5ee3..0ea4d78 100644
--- a/src/database.rs
+++ b/src/database.rs
@@ -6,6 +6,7 @@ pub mod appservice;
pub mod globals;
pub mod key_backups;
pub mod media;
+pub mod proxy;
pub mod pusher;
pub mod rooms;
pub mod sending;
@@ -28,6 +29,8 @@ use std::{
};
use tokio::sync::Semaphore;
+use self::proxy::ProxyConfig;
+
#[derive(Clone, Debug, Deserialize)]
pub struct Config {
server_name: Box<ServerName>,
@@ -85,124 +88,6 @@ pub type Engine = abstraction::SledEngine;
#[cfg(feature = "rocksdb")]
pub type Engine = abstraction::RocksDbEngine;
-#[derive(Clone, Debug, Deserialize)]
-#[serde(rename_all = "snake_case")]
-pub enum ProxyConfig {
- None,
- Global {
- #[serde(deserialize_with = "crate::utils::deserialize_from_str")]
- url: reqwest::Url,
- },
- ByDomain(Vec<PartialProxyConfig>),
-}
-impl ProxyConfig {
- pub fn to_proxy(&self) -> Result<Option<reqwest::Proxy>> {
- Ok(match self.clone() {
- ProxyConfig::None => None,
- ProxyConfig::Global { url } => Some(reqwest::Proxy::all(url)?),
- ProxyConfig::ByDomain(proxies) => Some(reqwest::Proxy::custom(move |url| {
- proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching proxy
- })),
- })
- }
-}
-impl Default for ProxyConfig {
- fn default() -> Self {
- ProxyConfig::None
- }
-}
-
-#[derive(Clone, Debug, Deserialize)]
-pub struct PartialProxyConfig {
- #[serde(deserialize_with = "crate::utils::deserialize_from_str")]
- url: reqwest::Url,
- #[serde(default)]
- include: Vec<WildCardedDomain>,
- #[serde(default)]
- exclude: Vec<WildCardedDomain>,
-}
-impl PartialProxyConfig {
- pub fn for_url(&self, url: &reqwest::Url) -> Option<&reqwest::Url> {
- let domain = url.domain()?;
- let mut included_because = None; // most specific reason it was included
- let mut excluded_because = None; // most specific reason it was excluded
- if self.include.is_empty() {
- // treat empty include list as `*`
- included_because = Some(&WildCardedDomain::WildCard)
- }
- for wc_domain in &self.include {
- if wc_domain.matches(domain) {
- match included_because {
- Some(prev) if !wc_domain.more_specific_than(prev) => (),
- _ => included_because = Some(wc_domain),
- }
- }
- }
- for wc_domain in &self.exclude {
- if wc_domain.matches(domain) {
- match excluded_because {
- Some(prev) if !wc_domain.more_specific_than(prev) => (),
- _ => excluded_because = Some(wc_domain),
- }
- }
- }
- match (included_because, excluded_because) {
- (Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), // included for a more specific reason than excluded
- (Some(_), None) => Some(&self.url),
- _ => None,
- }
- }
-}
-
-/// A domain name, that optionally allows a * as its first subdomain.
-#[derive(Clone, Debug)]
-pub enum WildCardedDomain {
- WildCard,
- WildCarded(String),
- Exact(String),
-}
-impl WildCardedDomain {
- pub fn matches(&self, domain: &str) -> bool {
- match self {
- WildCardedDomain::WildCard => true,
- WildCardedDomain::WildCarded(d) => domain.ends_with(d),
- WildCardedDomain::Exact(d) => domain == d,
- }
- }
- pub fn more_specific_than(&self, other: &Self) -> bool {
- match (self, other) {
- (WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false,
- (_, WildCardedDomain::WildCard) => true,
- (WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a),
- (WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => {
- a != b && a.ends_with(b)
- }
- _ => false,
- }
- }
-}
-impl std::str::FromStr for WildCardedDomain {
- type Err = std::convert::Infallible;
- fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
- // maybe do some domain validation?
- Ok(if s.starts_with("*.") {
- WildCardedDomain::WildCarded(s[1..].to_owned())
- } else if s == "*" {
- WildCardedDomain::WildCarded("".to_owned())
- } else {
- WildCardedDomain::Exact(s.to_owned())
- })
- }
-}
-impl<'de> serde::de::Deserialize<'de> for WildCardedDomain {
- fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
- where
- D: serde::de::Deserializer<'de>,
- {
- crate::utils::deserialize_from_str(deserializer)
- }
-}
-
pub struct Database {
pub globals: globals::Globals,
pub users: users::Users,
diff --git a/src/database/proxy.rs b/src/database/proxy.rs
new file mode 100644
index 0000000..78e9d2b
--- /dev/null
+++ b/src/database/proxy.rs
@@ -0,0 +1,146 @@
+use reqwest::{Proxy, Url};
+use serde::Deserialize;
+
+use crate::Result;
+
+/// ## Examples:
+/// - No proxy (default):
+/// ```toml
+/// proxy ="none"
+/// ```
+/// - Global proxy
+/// ```toml
+/// [proxy]
+/// global = { url = "socks5h://localhost:9050" }
+/// ```
+/// - Proxy some domains
+/// ```toml
+/// [proxy]
+/// [[proxy.by_domain]]
+/// url = "socks5h://localhost:9050"
+/// include = ["*.onion", "matrix.myspecial.onion"]
+/// exclude = ["*.myspecial.onion"]
+/// ```
+/// ## Include vs. Exclude
+/// If include is an empty list, it is assumed to be `["*"]`.
+///
+/// If a domain matches both the exclude and include list, the proxy will only be used if it was
+/// included because of a more specific rule than it was excluded. In the above example, the proxy
+/// would be used for `ordinary.onion`, `matrix.myspecial.onion`, but not `hello.myspecial.onion`.
+#[derive(Clone, Debug, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum ProxyConfig {
+ None,
+ Global {
+ #[serde(deserialize_with = "crate::utils::deserialize_from_str")]
+ url: Url,
+ },
+ ByDomain(Vec<PartialProxyConfig>),
+}
+impl ProxyConfig {
+ pub fn to_proxy(&self) -> Result<Option<Proxy>> {
+ Ok(match self.clone() {
+ ProxyConfig::None => None,
+ ProxyConfig::Global { url } => Some(Proxy::all(url)?),
+ ProxyConfig::ByDomain(proxies) => Some(Proxy::custom(move |url| {
+ proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching proxy
+ })),
+ })
+ }
+}
+impl Default for ProxyConfig {
+ fn default() -> Self {
+ ProxyConfig::None
+ }
+}
+
+#[derive(Clone, Debug, Deserialize)]
+pub struct PartialProxyConfig {
+ #[serde(deserialize_with = "crate::utils::deserialize_from_str")]
+ url: Url,
+ #[serde(default)]
+ include: Vec<WildCardedDomain>,
+ #[serde(default)]
+ exclude: Vec<WildCardedDomain>,
+}
+impl PartialProxyConfig {
+ pub fn for_url(&self, url: &Url) -> Option<&Url> {
+ let domain = url.domain()?;
+ let mut included_because = None; // most specific reason it was included
+ let mut excluded_because = None; // most specific reason it was excluded
+ if self.include.is_empty() {
+ // treat empty include list as `*`
+ included_because = Some(&WildCardedDomain::WildCard)
+ }
+ for wc_domain in &self.include {
+ if wc_domain.matches(domain) {
+ match included_because {
+ Some(prev) if !wc_domain.more_specific_than(prev) => (),
+ _ => included_because = Some(wc_domain),
+ }
+ }
+ }
+ for wc_domain in &self.exclude {
+ if wc_domain.matches(domain) {
+ match excluded_because {
+ Some(prev) if !wc_domain.more_specific_than(prev) => (),
+ _ => excluded_because = Some(wc_domain),
+ }
+ }
+ }
+ match (included_because, excluded_because) {
+ (Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), // included for a more specific reason than excluded
+ (Some(_), None) => Some(&self.url),
+ _ => None,
+ }
+ }
+}
+
+/// A domain name, that optionally allows a * as its first subdomain.
+#[derive(Clone, Debug)]
+pub enum WildCardedDomain {
+ WildCard,
+ WildCarded(String),
+ Exact(String),
+}
+impl WildCardedDomain {
+ pub fn matches(&self, domain: &str) -> bool {
+ match self {
+ WildCardedDomain::WildCard => true,
+ WildCardedDomain::WildCarded(d) => domain.ends_with(d),
+ WildCardedDomain::Exact(d) => domain == d,
+ }
+ }
+ pub fn more_specific_than(&self, other: &Self) -> bool {
+ match (self, other) {
+ (WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false,
+ (_, WildCardedDomain::WildCard) => true,
+ (WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a),
+ (WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => {
+ a != b && a.ends_with(b)
+ }
+ _ => false,
+ }
+ }
+}
+impl std::str::FromStr for WildCardedDomain {
+ type Err = std::convert::Infallible;
+ fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
+ // maybe do some domain validation?
+ Ok(if s.starts_with("*.") {
+ WildCardedDomain::WildCarded(s[1..].to_owned())
+ } else if s == "*" {
+ WildCardedDomain::WildCarded("".to_owned())
+ } else {
+ WildCardedDomain::Exact(s.to_owned())
+ })
+ }
+}
+impl<'de> serde::de::Deserialize<'de> for WildCardedDomain {
+ fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
+ where
+ D: serde::de::Deserializer<'de>,
+ {
+ crate::utils::deserialize_from_str(deserializer)
+ }
+}