Browse Source

add tor support

merge-requests/52/head
Aiden McClelland 5 years ago
parent
commit
6f8adc86a5
  1. 13
      Cargo.lock
  2. 3
      Cargo.toml
  3. 37
      src/database.rs
  4. 49
      src/database/globals.rs
  5. 23
      src/server_server.rs
  6. 28
      src/utils.rs

13
Cargo.lock generated

@ -1508,6 +1508,7 @@ dependencies = [ @@ -1508,6 +1508,7 @@ dependencies = [
"serde_urlencoded",
"tokio",
"tokio-native-tls",
"tokio-socks",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
@ -2336,6 +2337,18 @@ dependencies = [ @@ -2336,6 +2337,18 @@ dependencies = [
"webpki",
]
[[package]]
name = "tokio-socks"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51165dfa029d2a65969413a6cc96f354b86b464498702f174a4efa13608fd8c0"
dependencies = [
"either",
"futures-util",
"thiserror",
"tokio",
]
[[package]]
name = "tokio-util"
version = "0.6.3"

3
Cargo.toml

@ -71,9 +71,10 @@ tracing-opentelemetry = "0.11.0" @@ -71,9 +71,10 @@ tracing-opentelemetry = "0.11.0"
opentelemetry-jaeger = "0.11.0"
[features]
default = ["conduit_bin"]
default = ["conduit_bin", "tor"]
conduit_bin = [] # TODO: add rocket to this when it is optional
tls_vendored = ["reqwest/native-tls-vendored"]
tor = ["reqwest/socks"]
[[bin]]
name = "conduit"

37
src/database.rs

@ -17,9 +17,11 @@ use log::info; @@ -17,9 +17,11 @@ use log::info;
use rocket::futures::{self, channel::mpsc};
use ruma::{DeviceId, ServerName, UserId};
use serde::Deserialize;
use std::collections::HashMap;
use std::fs::remove_dir_all;
use std::sync::{Arc, RwLock};
use std::{
collections::HashMap,
fs::remove_dir_all,
sync::{Arc, RwLock},
};
use tokio::sync::Semaphore;
#[derive(Clone, Deserialize)]
@ -40,6 +42,10 @@ pub struct Config { @@ -40,6 +42,10 @@ pub struct Config {
allow_federation: bool,
#[serde(default = "false_fn")]
pub allow_jaeger: bool,
#[cfg(feature = "tor")]
#[serde(default)]
#[serde(flatten)]
tor_federation: TorFederation,
jwt_secret: Option<String>,
}
@ -63,6 +69,31 @@ fn default_max_concurrent_requests() -> u16 { @@ -63,6 +69,31 @@ fn default_max_concurrent_requests() -> u16 {
4
}
#[cfg(feature = "tor")]
#[derive(Clone, Deserialize)]
#[serde(rename = "snake_case")]
#[serde(tag = "tor_federation")]
pub enum TorFederation {
Disabled,
Enabled {
#[serde(deserialize_with = "crate::utils::deserialize_from_str")]
tor_proxy: reqwest::Url,
tor_only: bool,
},
}
#[cfg(feature = "tor")]
impl TorFederation {
pub fn enabled(&self) -> bool {
matches!(self, &TorFederation::Enabled { .. })
}
}
#[cfg(feature = "tor")]
impl Default for TorFederation {
fn default() -> Self {
TorFederation::Disabled
}
}
#[derive(Clone)]
pub struct Database {
pub globals: globals::Globals,

49
src/database/globals.rs

@ -1,10 +1,11 @@ @@ -1,10 +1,11 @@
use crate::{database::Config, utils, Error, Result};
use log::error;
use ruma::ServerName;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::RwLock;
use std::time::Duration;
use std::{
collections::HashMap,
sync::{Arc, RwLock},
time::Duration,
};
use trust_dns_resolver::TokioAsyncResolver;
pub const COUNTER: &str = "c";
@ -57,12 +58,39 @@ impl Globals { @@ -57,12 +58,39 @@ impl Globals {
}
};
let reqwest_client = reqwest::Client::builder()
let mut reqwest_client = reqwest::Client::builder();
reqwest_client = reqwest_client
.connect_timeout(Duration::from_secs(30))
.timeout(Duration::from_secs(60 * 3))
.pool_max_idle_per_host(1)
.build()
.unwrap();
.pool_max_idle_per_host(1);
#[cfg(feature = "tor")]
{
use crate::database::TorFederation;
if let TorFederation::Enabled {
tor_proxy,
tor_only,
} = config.tor_federation.clone()
{
let proxy = if tor_only {
reqwest::Proxy::all(tor_proxy).unwrap()
} else {
reqwest::Proxy::custom(move |url| {
if url
.host_str()
.map_or(false, |host| host.ends_with(".onion"))
{
Some(tor_proxy.clone())
} else {
None
}
})
};
reqwest_client = reqwest_client.proxy(proxy);
}
}
let reqwest_client = reqwest_client.build().unwrap();
let jwt_decoding_key = config
.jwt_secret
@ -129,6 +157,11 @@ impl Globals { @@ -129,6 +157,11 @@ impl Globals {
self.config.allow_federation
}
#[cfg(feature = "tor")]
pub fn tor_federation_enabled(&self) -> bool {
self.config.tor_federation.enabled()
}
pub fn dns_resolver(&self) -> &TokioAsyncResolver {
&self.dns_resolver
}

23
src/server_server.rs

@ -3,6 +3,8 @@ use get_profile_information::v1::ProfileField; @@ -3,6 +3,8 @@ use get_profile_information::v1::ProfileField;
use http::header::{HeaderValue, AUTHORIZATION, HOST};
use log::{info, warn};
use regex::Regex;
#[cfg(feature = "conduit_bin")]
use rocket::{get, post, put};
use rocket::{response::content::Json, State};
use ruma::{
api::{
@ -29,10 +31,6 @@ use std::{ @@ -29,10 +31,6 @@ use std::{
net::{IpAddr, SocketAddr},
time::{Duration, SystemTime},
};
#[cfg(feature = "conduit_bin")]
use {
rocket::{get, post, put}
};
#[tracing::instrument(skip(globals))]
pub async fn send_request<T: OutgoingRequest>(
@ -231,7 +229,15 @@ async fn find_actual_destination( @@ -231,7 +229,15 @@ async fn find_actual_destination(
let mut host = None;
let destination_str = destination.as_str().to_owned();
let actual_destination = "https://".to_owned()
#[cfg(not(feature = "tor"))]
let protocol = "https://";
#[cfg(feature = "tor")]
let protocol = if globals.tor_federation_enabled() && destination_str.ends_with(".onion") {
"http://"
} else {
"https://"
};
let actual_destination = protocol.to_owned()
+ &match get_ip_with_port(destination_str.clone()) {
Some(host_port) => {
// 1: IP literal with provided or default port
@ -600,9 +606,7 @@ pub async fn send_transaction_message_route<'a>( @@ -600,9 +606,7 @@ pub async fn send_transaction_message_route<'a>(
let users = namespaces
.get("users")
.and_then(|users| users.as_sequence())
.map_or_else(
Vec::new,
|users| {
.map_or_else(Vec::new, |users| {
users
.iter()
.map(|users| {
@ -613,8 +617,7 @@ pub async fn send_transaction_message_route<'a>( @@ -613,8 +617,7 @@ pub async fn send_transaction_message_route<'a>(
})
.filter_map(|o| o)
.collect::<Vec<_>>()
},
);
});
let aliases = namespaces
.get("aliases")
.and_then(|users| users.get("regex"))

28
src/utils.rs

@ -6,6 +6,7 @@ use sled::IVec; @@ -6,6 +6,7 @@ use sled::IVec;
use std::{
cmp,
convert::TryInto,
str::FromStr,
time::{SystemTime, UNIX_EPOCH},
};
@ -112,3 +113,30 @@ pub fn to_canonical_object<T: serde::Serialize>( @@ -112,3 +113,30 @@ pub fn to_canonical_object<T: serde::Serialize>(
))),
}
}
#[allow(dead_code)]
pub fn deserialize_from_str<
'de,
D: serde::de::Deserializer<'de>,
T: FromStr<Err = E>,
E: std::fmt::Display,
>(
deserializer: D,
) -> std::result::Result<T, D::Error> {
struct Visitor<T: FromStr<Err = E>, E>(std::marker::PhantomData<T>);
impl<'de, T: FromStr<Err = Err>, Err: std::fmt::Display> serde::de::Visitor<'de>
for Visitor<T, Err>
{
type Value = T;
fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(formatter, "a parsable string")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
v.parse().map_err(|e| serde::de::Error::custom(e))
}
}
deserializer.deserialize_str(Visitor(std::marker::PhantomData))
}

Loading…
Cancel
Save