diff --git a/Cargo.lock b/Cargo.lock index 75cc3b09..5b95595c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -545,6 +545,12 @@ version = "0.15.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1aaf95b3e5c8f23aa320147307562d361db0ae0d51242340f558153b4eb2439b" +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + [[package]] name = "encoding_rs" version = "0.8.35" @@ -1427,6 +1433,7 @@ dependencies = [ "tegen", "time", "tokio", + "tokio-socks", "toml", "url", "uuid", @@ -2089,6 +2096,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-socks" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d4770b8024672c1101b3f6733eab95b18007dbe0847a8afe341fcf79e06043f" +dependencies = [ + "either", + "futures-util", + "thiserror 1.0.69", + "tokio", +] + [[package]] name = "tokio-util" version = "0.7.13" diff --git a/Cargo.toml b/Cargo.toml index 1e0cb0cc..d51e635a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,9 +59,9 @@ htmlescape = "0.3.1" bincode = "1.3.3" base2048 = "2.0.2" revision = "0.10.0" +tokio-socks = "0.5.2" fake_user_agent = "0.2.2" - [dev-dependencies] lipsum = "0.9.0" sealed_test = "1.0.0" diff --git a/src/client.rs b/src/client.rs index 7499c8dc..f858d055 100644 --- a/src/client.rs +++ b/src/client.rs @@ -2,7 +2,6 @@ use arc_swap::ArcSwap; use cached::proc_macro::cached; use futures_lite::future::block_on; use futures_lite::{future::Boxed, FutureExt}; -use hyper::client::HttpConnector; use hyper::header::HeaderValue; use hyper::{body, body::Buf, header, Body, Client, Method, Request, Response, Uri}; use hyper_rustls::HttpsConnector; @@ -30,10 +29,17 @@ const REDDIT_SHORT_URL_BASE_HOST: &str = "redd.it"; const ALTERNATIVE_REDDIT_URL_BASE: &str = "https://www.reddit.com"; const ALTERNATIVE_REDDIT_URL_BASE_HOST: &str = "www.reddit.com"; -pub static HTTPS_CONNECTOR: LazyLock> = - LazyLock::new(|| hyper_rustls::HttpsConnectorBuilder::new().with_native_roots().https_only().enable_http2().build()); -pub static CLIENT: LazyLock>> = LazyLock::new(|| Client::builder().build::<_, Body>(HTTPS_CONNECTOR.clone())); +pub static HTTPS_CONNECTOR: LazyLock> = LazyLock::new(|| { + let proxy_connector = ProxyConnector::new(); + hyper_rustls::HttpsConnectorBuilder::new() + .with_native_roots() + .https_only() + .enable_http2() + .wrap_connector(proxy_connector) +}); + +pub static CLIENT: LazyLock>> = LazyLock::new(|| Client::builder().build::<_, Body>(HTTPS_CONNECTOR.clone())); pub static OAUTH_CLIENT: LazyLock> = LazyLock::new(|| { let client = block_on(Oauth::new()); @@ -521,6 +527,7 @@ pub async fn rate_limit_check() -> Result<(), String> { #[cfg(test)] use {crate::config::get_setting, sealed_test::prelude::*}; +use crate::proxy::ProxyConnector; #[tokio::test(flavor = "multi_thread")] async fn test_rate_limit_check() { diff --git a/src/lib.rs b/src/lib.rs index b8eb17e7..45425307 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,3 +11,4 @@ pub mod settings; pub mod subreddit; pub mod user; pub mod utils; +mod proxy; diff --git a/src/proxy.rs b/src/proxy.rs new file mode 100644 index 00000000..4d242bda --- /dev/null +++ b/src/proxy.rs @@ -0,0 +1,177 @@ +use base64::engine::general_purpose; +use base64::Engine; +use hyper::client::HttpConnector; +use hyper::service::Service; +use hyper::Uri; +use log::debug; +use std::env; +use std::error::Error; +use std::fmt; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tokio::net::TcpStream; +use tokio_socks::tcp::Socks5Stream; + +type BoxError = Box; +type BoxFuture = Pin> + Send>>; +type Credentials = (String, String); + +#[derive(Clone)] +pub enum ProxyConnector { + NoProxy(HttpConnector), + Socks(String), + Http(String), +} + +#[derive(Debug)] +pub struct ProxyError(String); + +impl fmt::Display for ProxyError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Proxy error: {}", self.0) + } +} + +impl Error for ProxyError {} + +impl Service for ProxyConnector { + type Response = TcpStream; + type Error = BoxError; + type Future = BoxFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + match self { + ProxyConnector::NoProxy(connector) => connector.poll_ready(cx).map_err(Into::into), + _ => Poll::Ready(Ok(())), + } + } + + fn call(&mut self, uri: Uri) -> Self::Future { + let this = self.clone(); + Box::pin(async move { + match this { + ProxyConnector::NoProxy(mut connector) => { + connector.call(uri).await.map_err(Into::into) + } + ProxyConnector::Socks(proxy_addr) => handle_socks_connection(&proxy_addr, &uri).await, + ProxyConnector::Http(proxy_addr) => handle_http_connection(&proxy_addr, &uri).await, + } + }) + } +} + +impl ProxyConnector { + pub fn new() -> Self { + if let Ok(socks_proxy) = env::var("SOCKS_PROXY") { + debug!("Using SOCKS proxy: {}", socks_proxy); + return ProxyConnector::Socks(socks_proxy); + } + + if let Ok(http_proxy) = env::var("HTTP_PROXY").or_else(|_| env::var("HTTPS_PROXY")) { + debug!("Using HTTP proxy: {}", http_proxy); + return ProxyConnector::Http(http_proxy); + } + + let mut connector = HttpConnector::new(); + connector.enforce_http(false); + ProxyConnector::NoProxy(connector) + } +} + +async fn handle_socks_connection(proxy_addr: &str, uri: &Uri) -> Result { + let (host, port, credentials) = parse_proxy_addr(proxy_addr)?; + let target_addr = get_target_addr(uri)?; + + let stream = match credentials { + Some((username, password)) => { + Socks5Stream::connect_with_password((host.as_str(), port), target_addr, &username, &password).await + } + None => Socks5Stream::connect((host.as_str(), port), target_addr).await, + }?; + + Ok(stream.into_inner()) +} + +async fn handle_http_connection(proxy_addr: &str, uri: &Uri) -> Result { + let (host, port, credentials) = parse_proxy_addr(proxy_addr)?; + let proxy_stream = TcpStream::connect((host.as_str(), port)).await?; + let target_addr = get_target_addr(uri)?; + + let connect_req = build_connect_request(&target_addr, credentials)?; + write_and_verify_connection(&proxy_stream, &connect_req).await?; + + Ok(proxy_stream) +} + +fn build_connect_request(target_addr: &str, credentials: Option) -> Result { + let mut req = format!( + "CONNECT {target_addr} HTTP/1.1\r\n\ + Host: {target_addr}\r\n\ + Connection: keep-alive\r\n" + ); + + if let Some((username, password)) = credentials { + let auth = general_purpose::STANDARD.encode(format!("{}:{}", username, password)); + req.push_str(&format!("Proxy-Authorization: Basic {}\r\n", auth)); + } + + req.push_str("\r\n"); + Ok(req) +} + +async fn write_and_verify_connection(proxy_stream: &TcpStream, connect_req: &str) -> Result<(), BoxError> { + proxy_stream.writable().await?; + proxy_stream.try_write(connect_req.as_bytes())?; + + let mut response = [0u8; 1024]; + proxy_stream.readable().await?; + let n = proxy_stream.try_read(&mut response)?; + + let response = String::from_utf8_lossy(&response[..n]); + if !response.starts_with("HTTP/1.1 200") { + return Err(Box::new(ProxyError(format!("Proxy CONNECT failed: {}", response)))); + } + + Ok(()) +} + +fn parse_proxy_addr(addr: &str) -> Result<(String, u16, Option), BoxError> { + let uri: Uri = addr.parse()?; + let host = uri.host().ok_or("Missing proxy host")?.to_string(); + let port = uri.port_u16().unwrap_or_else(|| { + if uri.scheme_str() == Some("https") { 443 } else { 80 } + }); + + let credentials = extract_credentials(uri.authority())?; + Ok((host, port, credentials)) +} + +fn extract_credentials(authority: Option<&hyper::http::uri::Authority>) -> Result, BoxError> { + let Some(authority) = authority else { + return Ok(None); + }; + + let Some(credentials) = authority.as_str().split('@').next() else { + return Ok(None); + }; + + if credentials == authority.as_str() { + return Ok(None); + } + + let creds: Vec<&str> = credentials.split(':').collect(); + if creds.len() == 2 { + Ok(Some((creds[0].to_string(), creds[1].to_string()))) + } else { + Ok(None) + } +} + +fn get_target_addr(uri: &Uri) -> Result { + let host = uri.host().ok_or("Missing target host")?; + let port = uri.port_u16().unwrap_or_else(|| { + if uri.scheme_str() == Some("https") { 443 } else { 80 } + }); + Ok(format!("{}:{}", host, port)) +} \ No newline at end of file