diff --git a/CHANGELOG.md b/CHANGELOG.md index c06a2eb..59f6fb5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 `Proxy::config_snapshot()` and `Proxy::update_config()` are now synchronous. - Integration tests: config hot-reload on keep-alive connections, Prometheus `/metrics` endpoint (counters, histogram buckets, gauge, metadata). +- **`header_up`** in `reverse_proxy` blocks — set or remove headers on upstream + requests (applied after default `Host` / `X-Forwarded-*`). +- Upstream placeholders for `header_up`: `{upstream_host}`, `{request.uri}`, + `{remote_ip}` (via `process_upstream_substitution`). - Dependencies: `arc-swap`, `metrics`, `metrics-exporter-prometheus` (optional) ### Changed diff --git a/README.md b/README.md index 9c6cef9..54ec16c 100644 --- a/README.md +++ b/README.md @@ -247,6 +247,25 @@ localhost:8080 { Timeout values support duration suffixes: `30s`, `5m`, `2h`, `1d`, or plain numbers (seconds). +Upstream request headers (`header_up`) can be set inside the `reverse_proxy` block. They are applied **after** the default `Host` and `X-Forwarded-*` headers, so explicit values override defaults: + +```caddy +localhost:8080 { + reverse_proxy https://api.example.com:443 { + connect_timeout 10s + read_timeout 30s + header_up Host {upstream_host} + header_up X-Original-Uri {request.uri} + header_up -Accept-Encoding + } +} +``` + +| Syntax | Action | +|--------|--------| +| `header_up Name value` | set/replace header on the upstream request | +| `header_up -Name` | remove header before forwarding | + #### `tls` Enable HTTPS on the frontend with TLS termination. Specify paths to the certificate chain and private key (PEM format). @@ -440,12 +459,18 @@ localhost:8080 { ### Placeholders -Use placeholders in header values: +Use placeholders in `header` and `header_up` values: - `{header.Name}` - Value of request header with that name - `{env.VAR}` - Value of environment variable - `{uuid}` - Random UUID +`header_up` also supports: + +- `{upstream_host}` - hostname:port of the `reverse_proxy` backend URL +- `{request.uri}` - path + query of the incoming client request +- `{remote_ip}` - client IP (`X-Forwarded-For` / `X-Real-IP`, else socket address) + ## Features ### Default Features diff --git a/benches/proxy_bench.rs b/benches/proxy_bench.rs index 1b2beb3..a29798b 100644 --- a/benches/proxy_bench.rs +++ b/benches/proxy_bench.rs @@ -324,6 +324,7 @@ fn bench_directive_operations(c: &mut Criterion) { to: "http://backend:9001".to_string(), connect_timeout: None, read_timeout: None, + header_up: vec![], }, Directive::UriReplace { find: "/old".to_string(), @@ -426,6 +427,7 @@ fn create_simple_config() -> Config { to: "http://127.0.0.1:9001".to_string(), connect_timeout: None, read_timeout: None, + header_up: vec![], }, ], tls: None, @@ -462,6 +464,7 @@ fn create_medium_config() -> Config { to: "http://backend:9001".to_string(), connect_timeout: None, read_timeout: None, + header_up: vec![], }, ], tls: None, @@ -475,6 +478,7 @@ fn create_medium_config() -> Config { to: "http://backend:9002".to_string(), connect_timeout: None, read_timeout: None, + header_up: vec![], }], tls: None, }, @@ -508,6 +512,7 @@ fn create_complex_config() -> Config { to: "http://user-service:8001".to_string(), connect_timeout: None, read_timeout: None, + header_up: vec![], }, ], }, @@ -522,6 +527,7 @@ fn create_complex_config() -> Config { to: "http://order-service:8002".to_string(), connect_timeout: None, read_timeout: None, + header_up: vec![], }, ], }, @@ -541,6 +547,7 @@ fn create_complex_config() -> Config { to: "http://api-service:8000".to_string(), connect_timeout: None, read_timeout: None, + header_up: vec![], }, ], }, @@ -578,6 +585,7 @@ fn create_multi_site_config(count: usize) -> Config { to: format!("http://backend:{}", 9000 + i), connect_timeout: None, read_timeout: None, + header_up: vec![], }, ], tls: None, diff --git a/src/auth/headers.rs b/src/auth/headers.rs index b0ac926..e3a367c 100644 --- a/src/auth/headers.rs +++ b/src/auth/headers.rs @@ -80,6 +80,35 @@ pub fn process_header_substitution(value: &str, req: &Request) -> anyhow:: Ok(result) } +/// Process header value substitutions for upstream (`header_up`) operations. +/// +/// Supports all the placeholders of [`process_header_substitution`] plus three +/// extra placeholders that only make sense for outbound (upstream) headers: +/// +/// - `{upstream_host}` — hostname:port of the `reverse_proxy` backend +/// - `{request.uri}` — the full path + query of the incoming request +/// - `{remote_ip}` — client IP from `remote_addr` (or `X-Forwarded-For` / `X-Real-IP`) +/// +/// The order of substitution is: base placeholders first (`{header.*}`, `{env.*}`, +/// `{uuid}`), then the upstream-specific ones. This lets you write e.g. +/// `header_up Host {upstream_host}` while still being able to use `{header.X-Foo}`. +pub fn process_upstream_substitution( + value: &str, + req: &Request, + upstream_host: &str, + request_uri: &str, + remote_ip: &str, +) -> anyhow::Result { + // Base substitutions ({header.*}, {env.*}, {uuid}). + let mut result = process_header_substitution(value, req)?; + + result = result.replace("{upstream_host}", upstream_host); + result = result.replace("{request.uri}", request_uri); + result = result.replace("{remote_ip}", remote_ip); + + Ok(result) +} + /// Extract remote IP address from request headers /// /// Looks for the X-Forwarded-For or X-Real-IP headers to determine the @@ -188,4 +217,21 @@ mod tests { let ip = extract_remote_ip(&req); assert!(ip.is_none()); } + + #[test] + fn test_process_upstream_substitution() { + let req = make_request_with_header("X-Trace", "abc"); + let result = process_upstream_substitution( + "host={upstream_host} uri={request.uri} ip={remote_ip} trace={header.X-Trace}", + &req, + "api.example.com:443", + "/v1/items?limit=10", + "203.0.113.7", + ) + .unwrap(); + assert_eq!( + result, + "host=api.example.com:443 uri=/v1/items?limit=10 ip=203.0.113.7 trace=abc" + ); + } } diff --git a/src/auth/mod.rs b/src/auth/mod.rs index ec156bc..2973e9b 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -25,5 +25,5 @@ pub mod headers; pub mod validator; // Re-export commonly used functions for convenience -pub use headers::process_header_substitution; +pub use headers::{process_header_substitution, process_upstream_substitution}; pub use validator::validate_token; diff --git a/src/config/mod.rs b/src/config/mod.rs index 11339b4..884934b 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -3,4 +3,4 @@ mod models; mod parser; pub use address::{extract_hostname, resolve_listen_addr, tls_redirect_port}; -pub use models::{Config, Directive, SiteConfig, TlsConfig}; +pub use models::{Config, Directive, HeaderDirective, SiteConfig, TlsConfig}; diff --git a/src/config/models.rs b/src/config/models.rs index 17c693d..0a408c4 100644 --- a/src/config/models.rs +++ b/src/config/models.rs @@ -34,6 +34,8 @@ pub enum Directive { to: String, connect_timeout: Option, read_timeout: Option, + #[cfg_attr(feature = "api", serde(default))] + header_up: Vec, }, HandlePath { pattern: String, @@ -63,3 +65,12 @@ pub enum Directive { body: String, }, } + +/// A single header operation within a `header_up` block. +/// `value = None` means the header should be removed. +#[derive(Debug, Clone)] +#[cfg_attr(feature = "api", derive(Serialize, Deserialize))] +pub struct HeaderDirective { + pub name: String, + pub value: Option, +} diff --git a/src/config/parser.rs b/src/config/parser.rs index 3d25531..94b7687 100644 --- a/src/config/parser.rs +++ b/src/config/parser.rs @@ -1,4 +1,5 @@ use crate::config::address::{extract_hostname, resolve_listen_addr}; +use crate::config::models::HeaderDirective; use crate::config::{Config, Directive, SiteConfig}; use crate::error::ProxyError; use std::collections::HashMap; @@ -12,6 +13,8 @@ struct PendingBlock { // Timeout settings for reverse_proxy blocks (in seconds) connect_timeout: Option, read_timeout: Option, + // header_up operations collected inside a reverse_proxy block + header_up: Vec, } /// Parse a human-readable duration string into seconds. @@ -107,6 +110,7 @@ impl FromStr for Config { args, connect_timeout: None, read_timeout: None, + header_up: vec![], }); directive_stack.push(vec![]); continue; @@ -138,6 +142,7 @@ impl FromStr for Config { to, connect_timeout: block_info.connect_timeout, read_timeout: block_info.read_timeout, + header_up: block_info.header_up, } } _ => { @@ -188,7 +193,7 @@ impl FromStr for Config { let directive_name = parts[0]; let args = parts[1..].to_vec(); - // Special handling: timeout settings inside a reverse_proxy block + // Special handling: timeout and header_up settings inside a reverse_proxy block if let Some(block) = block_stack.last_mut() { if block.directive_type == "reverse_proxy" { match directive_name { @@ -218,9 +223,44 @@ impl FromStr for Config { })?); continue; } + "header_up" => { + let raw_name = args.first().cloned().ok_or_else(|| { + ProxyError::Parse(format!( + "Missing header name for header_up on line {}", + line_num + 1 + )) + })?; + let directive = if let Some(name) = raw_name.strip_prefix('-') { + if name.is_empty() { + return Err(ProxyError::Parse(format!( + "Missing header name after '-' in header_up on line {}", + line_num + 1 + ))); + } + HeaderDirective { + name: name.to_string(), + value: None, + } + } else { + let value = args[1..].join(" "); + if value.is_empty() { + return Err(ProxyError::Parse(format!( + "Missing value for header_up {} on line {}", + raw_name, + line_num + 1 + ))); + } + HeaderDirective { + name: raw_name.to_string(), + value: Some(value), + } + }; + block.header_up.push(directive); + continue; + } _ => { return Err(ProxyError::Parse(format!( - "Unexpected directive '{}' inside reverse_proxy block on line {}. Only connect_timeout and read_timeout are allowed.", + "Unexpected directive '{}' inside reverse_proxy block on line {}. Allowed: connect_timeout, read_timeout, header_up.", directive_name, line_num + 1 ))); } @@ -265,6 +305,7 @@ impl FromStr for Config { to: to.to_string(), connect_timeout: None, read_timeout: None, + header_up: vec![], } } "uri_replace" => { @@ -447,6 +488,7 @@ mod tests { to, connect_timeout, read_timeout, + .. } => { assert_eq!(to, "http://backend:9001"); assert_eq!(*connect_timeout, None); @@ -473,6 +515,7 @@ mod tests { to, connect_timeout, read_timeout, + .. } => { assert_eq!(to, "http://backend:9001"); assert_eq!(*connect_timeout, Some(10)); @@ -505,6 +548,33 @@ mod tests { } } + #[test] + fn test_parse_reverse_proxy_with_header_up() { + let config = r#"localhost:8080 { + reverse_proxy https://api.example.com:443 { + header_up Host {upstream_host} + header_up X-Original-Uri {request.uri} + header_up -Accept-Encoding + } +}"#; + let result: Config = config.parse().unwrap(); + let site = result.sites.get("localhost:8080").unwrap(); + + match &site.directives[0] { + Directive::ReverseProxy { to, header_up, .. } => { + assert_eq!(to, "https://api.example.com:443"); + assert_eq!(header_up.len(), 3); + assert_eq!(header_up[0].name, "Host"); + assert_eq!(header_up[0].value.as_deref(), Some("{upstream_host}")); + assert_eq!(header_up[1].name, "X-Original-Uri"); + assert_eq!(header_up[1].value.as_deref(), Some("{request.uri}")); + assert_eq!(header_up[2].name, "Accept-Encoding"); + assert!(header_up[2].value.is_none()); + } + _ => panic!("Expected ReverseProxy directive"), + } + } + #[test] fn test_parse_reverse_proxy_block_rejects_unknown_directive() { let config = r#"localhost:8080 { diff --git a/src/proxy/directives.rs b/src/proxy/directives.rs index 77c7a0f..c48ce09 100644 --- a/src/proxy/directives.rs +++ b/src/proxy/directives.rs @@ -2,7 +2,8 @@ use hyper::body::Incoming; use hyper::Request; use tracing::info; -use crate::auth::process_header_substitution; +use crate::auth::{process_header_substitution, process_upstream_substitution}; +use crate::config::HeaderDirective; use crate::proxy::ActionResult; @@ -12,16 +13,21 @@ pub fn handle_reverse_proxy( path: &str, connect_timeout: Option, read_timeout: Option, + header_up: Vec, ) -> ActionResult { info!( - " Proxying to: {} (connect_timeout: {:?}, read_timeout: {:?})", - to, connect_timeout, read_timeout + " Proxying to: {} (connect_timeout: {:?}, read_timeout: {:?}, header_up: {} ops)", + to, + connect_timeout, + read_timeout, + header_up.len() ); ActionResult::ReverseProxy { backend_url: to.to_string(), path_to_send: path.to_string(), connect_timeout, read_timeout, + header_up, } } @@ -75,6 +81,61 @@ pub fn handle_header( Ok(()) } +/// Apply `header_up` directives to the outbound (upstream) request. +/// +/// Runs after default Host / X-Forwarded-* headers so explicit `header_up` can override them. +pub fn apply_header_up( + directives: &[HeaderDirective], + req: &mut Request, + upstream_host: &str, + request_uri: &str, + remote_ip: &str, +) { + use hyper::header::{HeaderName, HeaderValue}; + + for directive in directives { + match HeaderName::from_bytes(directive.name.as_bytes()) { + Ok(header_name) => match &directive.value { + Some(val) => { + match process_upstream_substitution( + val, + req, + upstream_host, + request_uri, + remote_ip, + ) { + Ok(processed) => match HeaderValue::from_str(&processed) { + Ok(header_value) => { + req.headers_mut().insert(header_name, header_value); + info!(" Applied header_up: {} = {}", directive.name, processed); + } + Err(e) => { + info!( + " Failed to apply header_up {}: invalid value: {}", + directive.name, e + ); + } + }, + Err(e) => { + info!(" Failed to apply header_up {}: {}", directive.name, e); + } + } + } + None => { + req.headers_mut().remove(&header_name); + info!(" Removed header_up: {}", directive.name); + } + }, + Err(e) => { + info!( + " Failed to apply header_up {}: invalid header name: {}", + directive.name, e + ); + } + } + } +} + /// Handle uri_replace directive - replace substring in path pub fn handle_uri_replace(find: &str, replace: &str, path: &mut String) { *path = path.replace(find, replace); @@ -249,4 +310,52 @@ mod tests { _ => panic!("Expected Redirect action"), } } + + #[test] + fn test_apply_header_up_set_and_remove() { + use bytes::Bytes; + use http_body_util::Empty; + + let mut req = Request::builder() + .header("Accept-Encoding", "gzip") + .body(Empty::::new()) + .unwrap(); + + let directives = vec![ + HeaderDirective { + name: "Host".to_string(), + value: Some("{upstream_host}".to_string()), + }, + HeaderDirective { + name: "X-Original-Uri".to_string(), + value: Some("{request.uri}".to_string()), + }, + HeaderDirective { + name: "Accept-Encoding".to_string(), + value: None, + }, + ]; + + apply_header_up( + &directives, + &mut req, + "api.example.com:443", + "/api/test?q=1", + "10.0.0.1", + ); + + assert_eq!( + req.headers().get("Host").unwrap().to_str().unwrap(), + "api.example.com:443" + ); + assert_eq!( + req.headers() + .get("X-Original-Uri") + .unwrap() + .to_str() + .unwrap(), + "/api/test?q=1" + ); + assert!(req.headers().get("Accept-Encoding").is_none()); + } } diff --git a/src/proxy/handler.rs b/src/proxy/handler.rs index 98da365..f5e72bd 100644 --- a/src/proxy/handler.rs +++ b/src/proxy/handler.rs @@ -23,8 +23,8 @@ use crate::proxy::access_log::{ensure_request_id, final_request_id}; use crate::proxy::ActionResult; use crate::proxy::directives::{ - handle_header, handle_method, handle_redirect, handle_respond, handle_reverse_proxy, - handle_strip_prefix, handle_uri_replace, + apply_header_up, handle_header, handle_method, handle_redirect, handle_respond, + handle_reverse_proxy, handle_strip_prefix, handle_uri_replace, }; /// Unified response body type - can handle both streaming (`Incoming`) and buffered (`Full`) @@ -110,12 +110,14 @@ pub fn process_directives( to, connect_timeout, read_timeout, + header_up, } => { return Ok(handle_reverse_proxy( to, &modified_path, *connect_timeout, *read_timeout, + header_up.clone(), )); } } @@ -272,6 +274,7 @@ pub async fn proxy( path_to_send, connect_timeout: _, read_timeout, + header_up, } => { // Add protocol if missing let backend_with_proto = @@ -286,8 +289,26 @@ pub async fn proxy( parts.path_and_query = Some(path_to_send.parse()?); let new_uri = Uri::from_parts(parts)?; + // Capture the original request URI (path + query) before we overwrite it — + // needed for the {request.uri} placeholder in header_up. + let original_request_uri = req + .uri() + .path_and_query() + .map(|pq| pq.as_str().to_string()) + .unwrap_or_default(); + *req.uri_mut() = new_uri.clone(); + // upstream_host is the authority of the backend URL, used for {upstream_host}. + let upstream_host = new_uri + .authority() + .map(|a| a.as_str().to_string()) + .unwrap_or_default(); + + // remote_ip: prefer X-Forwarded-For / X-Real-IP, fall back to the socket peer. + let remote_ip = crate::auth::headers::extract_remote_ip(&req) + .unwrap_or_else(|| remote_addr.ip().to_string()); + // Save original host for X-Forwarded headers let original_host_header = req.headers().get(hyper::header::HOST).cloned(); @@ -322,6 +343,14 @@ pub async fn proxy( req.headers_mut().remove(header::CONNECTION); req.headers_mut().remove("accept-encoding"); + apply_header_up( + &header_up, + &mut req, + &upstream_host, + &original_request_uri, + &remote_ip, + ); + // Forward request to backend with configurable timeout (default 30s) let backend_timeout = read_timeout.unwrap_or(30); match timeout(Duration::from_secs(backend_timeout), client.request(req)).await { diff --git a/src/proxy/types.rs b/src/proxy/types.rs index 3ae5e4b..13919a1 100644 --- a/src/proxy/types.rs +++ b/src/proxy/types.rs @@ -1,3 +1,5 @@ +use crate::config::HeaderDirective; + /// Result of directive processing #[derive(Debug, Clone)] pub enum ActionResult { @@ -10,6 +12,7 @@ pub enum ActionResult { path_to_send: String, connect_timeout: Option, read_timeout: Option, + header_up: Vec, }, Redirect { status: u16, diff --git a/tests/header_up_integration.rs b/tests/header_up_integration.rs new file mode 100644 index 0000000..1b6b177 --- /dev/null +++ b/tests/header_up_integration.rs @@ -0,0 +1,123 @@ +//! Integration tests for `header_up` on upstream requests. + +use std::collections::HashMap; +use std::convert::Infallible; + +use bytes::Bytes; +use http_body_util::{BodyExt, Full}; +use hyper::service::service_fn; +use hyper::{Request, Response}; +use hyper_util::rt::TokioIo; +use tiny_proxy::config::{Directive, HeaderDirective, SiteConfig}; +use tiny_proxy::{Config, Proxy}; +use tokio::net::TcpListener; + +async fn get_random_port_addr() -> std::net::SocketAddr { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + drop(listener); + addr +} + +/// Backend that echoes the `Host` and `X-Original-Uri` request headers. +async fn echo_upstream_headers( + req: Request, +) -> Result>, Infallible> { + let host = req + .headers() + .get("host") + .and_then(|h| h.to_str().ok()) + .unwrap_or(""); + let uri = req + .headers() + .get("x-original-uri") + .and_then(|h| h.to_str().ok()) + .unwrap_or(""); + let has_accept_encoding = req.headers().contains_key("accept-encoding"); + + let body = format!("host={host}|uri={uri}|ae={has_accept_encoding}"); + Ok(Response::new(Full::new(Bytes::from(body)))) +} + +async fn start_echo_backend() -> std::net::SocketAddr { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + tokio::spawn(async move { + loop { + let (stream, _) = match listener.accept().await { + Ok(v) => v, + Err(_) => continue, + }; + let io = TokioIo::new(stream); + tokio::spawn(async move { + let service = service_fn(echo_upstream_headers); + let _ = hyper::server::conn::http1::Builder::new() + .serve_connection(io, service) + .await; + }); + } + }); + + addr +} + +fn proxy_config(proxy_addr: std::net::SocketAddr, backend_addr: std::net::SocketAddr) -> Config { + let host = format!("127.0.0.1:{}", proxy_addr.port()); + let mut sites = HashMap::new(); + sites.insert( + host.clone(), + SiteConfig { + address: host, + directives: vec![Directive::ReverseProxy { + to: format!("http://{}", backend_addr), + connect_timeout: None, + read_timeout: None, + header_up: vec![ + HeaderDirective { + name: "Host".to_string(), + value: Some("api.example.com".to_string()), + }, + HeaderDirective { + name: "X-Original-Uri".to_string(), + value: Some("{request.uri}".to_string()), + }, + HeaderDirective { + name: "Accept-Encoding".to_string(), + value: None, + }, + ], + }], + tls: None, + }, + ); + Config { sites } +} + +#[tokio::test] +async fn test_header_up_reaches_backend() { + let backend_addr = start_echo_backend().await; + let proxy_addr = get_random_port_addr().await; + let proxy_host = format!("127.0.0.1:{}", proxy_addr.port()); + + let proxy = Proxy::new(proxy_config(proxy_addr, backend_addr)); + tokio::spawn(async move { + proxy.start_with_addr(proxy_addr).await.unwrap(); + }); + + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + let client = hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new()) + .build::<_, Full>(hyper_util::client::legacy::connect::HttpConnector::new()); + + let uri = format!("http://{proxy_host}/items?limit=3") + .parse() + .unwrap(); + let response = client.get(uri).await.expect("proxy request should succeed"); + let body = response.into_body().collect().await.unwrap().to_bytes(); + + assert_eq!( + std::str::from_utf8(&body).unwrap(), + "host=api.example.com|uri=/items?limit=3|ae=false" + ); +} diff --git a/tests/tls_integration.rs b/tests/tls_integration.rs index 7490fed..aa8041a 100644 --- a/tests/tls_integration.rs +++ b/tests/tls_integration.rs @@ -192,6 +192,7 @@ async fn test_tls_x_forwarded_proto_https() { to: format!("http://127.0.0.1:{}", backend_addr.port()), connect_timeout: None, read_timeout: None, + header_up: vec![], }], tls: Some(tiny_proxy::config::TlsConfig { cert_path: cert_file.path().to_str().unwrap().to_string(),