Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions dist/efs-utils.conf
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ optimize_readahead = true
# By default, we enable the feature to fallback to mount with mount target ip address when dns name cannot be resolved
fall_back_to_mount_target_ip_address_enabled = true

# When enabled, efs-utils calls DescribeMountTargets to determine the mount target address family (IPv4/IPv6)
# and constrains DNS resolution and efs-proxy connections to that family. Disable to skip the API call
# and let the OS resolver decide (equivalent to the pre-3.x behavior).
# Requires IAM actions: elasticfilesystem:DescribeMountTargets, ec2:DescribeAvailabilityZones.
# If those actions are unavailable, efs-utils falls back to AF_UNSPEC automatically.
dynamic_address_family_enabled = true

# By default, we use IMDSv2 to get the instance metadata, set this to true if you want to disable IMDSv2 usage
disable_fetch_ec2_metadata_token = false

Expand Down
2 changes: 2 additions & 0 deletions src/efs_utils_common/mount_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def mount_with_proxy(
mountpoint,
options,
fallback_ip_address=None,
address_family=None,
):
"""
This function is responsible for launching a efs-proxy process and attaching a NFS mount to that process
Expand All @@ -291,6 +292,7 @@ def mount_with_proxy(
options,
fallback_ip_address=fallback_ip_address,
efs_proxy_enabled=efs_proxy_enabled,
address_family=address_family,
) as tunnel_proc:
mount_completed = threading.Event()
t = threading.Thread(
Expand Down
4 changes: 2 additions & 2 deletions src/efs_utils_common/network_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ def get_ipv6_addresses(hostname):
return []


def dns_name_can_be_resolved(dns_name):
def dns_name_can_be_resolved(dns_name, family=socket.AF_UNSPEC):
try:
addr_info = socket.getaddrinfo(dns_name, None, socket.AF_UNSPEC)
addr_info = socket.getaddrinfo(dns_name, None, family)
return len(addr_info) > 0
except socket.gaierror:
return False
Expand Down
9 changes: 9 additions & 0 deletions src/efs_utils_common/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ def write_stunnel_config_file(
cert_details=None,
fallback_ip_address=None,
efs_proxy_enabled=True,
address_family=None,
):
"""
Serializes stunnel configuration to a file. Unfortunately this does not conform to Python's config file format, so we have to
Expand Down Expand Up @@ -424,6 +425,12 @@ def write_stunnel_config_file(
efs_config["fs_id"] = fs_id
efs_config["region"] = region
efs_config["efs_utils_version"] = VERSION
if address_family is not None:
if address_family == socket.AF_INET6:
efs_config["address_family"] = "ipv6"
elif address_family == socket.AF_INET:
efs_config["address_family"] = "ipv4"
# AF_UNSPEC: omit the key, proxy uses its own default (unspec)

stunnel_config = "\n".join(
serialize_stunnel_config(global_config)
Expand Down Expand Up @@ -693,6 +700,7 @@ def bootstrap_proxy(
state_file_dir=STATE_FILE_DIR,
fallback_ip_address=None,
efs_proxy_enabled=True,
address_family=None,
):
"""
Generates a TLS private key and client-side certificate, a stunnel configuration file, and a state file
Expand Down Expand Up @@ -809,6 +817,7 @@ def bootstrap_proxy(
cert_details=cert_details,
fallback_ip_address=fallback_ip_address,
efs_proxy_enabled=efs_proxy_enabled,
address_family=address_family,
)
if efs_proxy_enabled:
tunnel_args = [_efs_proxy_bin(), stunnel_config_file]
Expand Down
10 changes: 9 additions & 1 deletion src/mount_efs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import logging
import platform
import re
import socket
import sys

from efs_utils_common.cloudwatch import bootstrap_cloudwatch_logging
Expand Down Expand Up @@ -69,6 +70,7 @@
from efs_utils_common.proxy import get_init_system
from mount_efs.dns_resolver import (
get_dns_name_and_fallback_mount_target_ip_address,
get_mount_target_address_family,
match_device,
)

Expand Down Expand Up @@ -178,8 +180,13 @@ def main():
init_system = get_init_system()
check_network_status(fs_id, init_system)

address_family = (
get_mount_target_address_family(config, options, fs_id)
if "mounttargetip" not in options
else socket.AF_UNSPEC
)
dns_name, fallback_ip_address = get_dns_name_and_fallback_mount_target_ip_address(
config, fs_id, options
config, fs_id, options, address_family=address_family
)

if check_if_platform_is_mac() and "notls" not in options:
Expand Down Expand Up @@ -208,6 +215,7 @@ def main():
mountpoint,
options,
fallback_ip_address=fallback_ip_address,
address_family=address_family,
)


Expand Down
31 changes: 29 additions & 2 deletions src/mount_efs/dns_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def get_target_az(config, options):
return None


def get_dns_name_and_fallback_mount_target_ip_address(config, fs_id, options):
def get_dns_name_and_fallback_mount_target_ip_address(config, fs_id, options, address_family=socket.AF_UNSPEC):
def _validate_replacement_field_count(format_str, expected_ct):
if format_str.count("{") != expected_ct or format_str.count("}") != expected_ct:
raise ValueError(
Expand Down Expand Up @@ -124,7 +124,7 @@ def _validate_replacement_field_count(format_str, expected_ct):
ip_address=ip_address, fallback_message=fallback_message
)

if dns_name_can_be_resolved(dns_name):
if dns_name_can_be_resolved(dns_name, family=address_family):
return dns_name, None

logging.info(
Expand Down Expand Up @@ -179,6 +179,33 @@ def get_fallback_mount_target_ip_address(config, options, fs_id, dns_name):
)


def get_mount_target_address_family(config, options, fs_id):
"""Return socket.AF_INET, socket.AF_INET6, or socket.AF_UNSPEC based on the actual mount target IP type.
Falls back to socket.AF_UNSPEC if the API call fails or the feature is disabled."""
if not get_boolean_config_item_value(config, CONFIG_SECTION, "dynamic_address_family_enabled", default_value=True):
return socket.AF_UNSPEC

try:
efs_client = get_botocore_client(config, "efs", options)
if efs_client is None:
return socket.AF_UNSPEC

az_name = get_target_az(config, options)
ec2_client = get_botocore_client(config, "ec2", options)
if ec2_client is None:
return socket.AF_UNSPEC

mount_target = get_mount_target_in_az(efs_client, ec2_client, fs_id, az_name)
if "Ipv6Address" in mount_target and "IpAddress" not in mount_target:
return socket.AF_INET6

return socket.AF_INET
except Exception:
logging.info("Failed to determine mount target address family, defaulting to AF_UNSPEC")
logging.debug("get_mount_target_address_family exception detail", exc_info=True)
return socket.AF_UNSPEC


def check_if_fall_back_to_mount_target_ip_address_is_enabled(config):
return get_boolean_config_item_value(
config,
Expand Down
8 changes: 8 additions & 0 deletions src/proxy/src/config_parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ fn default_log_format() -> Option<String> {
Some("file".to_string())
}

fn default_address_family() -> String {
"unspec".to_string()
}

#[derive(Default, Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct ProxyConfig {
#[serde(alias = "fips", deserialize_with = "deserialize_bool")]
Expand Down Expand Up @@ -188,6 +192,10 @@ pub struct EfsConfig {
/// efs-utils version string for channel init
#[serde(alias = "efs_utils_version", default)]
pub efs_utils_version: String,

/// Address family to use when resolving the mount target hostname ("ipv4" or "ipv6")
#[serde(default = "default_address_family")]
pub address_family: String,
}
#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)]
pub struct ReadBypassConfig {
Expand Down
84 changes: 78 additions & 6 deletions src/proxy/src/connections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@ use futures::future::{self, BoxFuture};
use log::{debug, info, warn};
use s2n_tls_tokio::TlsStream;
use std::sync::Arc;
use std::{collections::HashMap, time::Duration};
use std::{collections::HashMap, net::SocketAddr, time::Duration};
use tokio::task::JoinHandle;
use tokio::time::{timeout_at, Instant};
use tokio::{
io::AsyncWriteExt,
io::{AsyncRead, AsyncWrite},
net::lookup_host,
net::TcpStream,
sync::mpsc,
};
Expand Down Expand Up @@ -334,14 +335,22 @@ pub fn get_bind_response_string(bind_response: &BindResponse) -> String {
#[derive(Clone)]
pub struct PlainTextPartitionFinder {
pub mount_target_addr: String,
pub address_family: String,
}

#[async_trait]
impl PartitionFinder<TcpStream> for PlainTextPartitionFinder {
async fn create_connect_future(&self) -> BoxFuture<'static, Result<TcpStream, ConnectError>> {
let mount_target_address = self.mount_target_addr.clone();
let address_family = self.address_family.clone();
Box::pin(async move {
match TcpStream::connect(mount_target_address).await {
let stream = if address_family == "unspec" {
TcpStream::connect(&mount_target_address).await
} else {
let addr = resolve_addr(&mount_target_address, &address_family).await?;
TcpStream::connect(addr).await
};
match stream {
Ok(tcp_stream) => Ok(configure_stream(tcp_stream)),
Err(e) => Err(ConnectError::IoError(e)),
}
Expand All @@ -351,11 +360,15 @@ impl PartitionFinder<TcpStream> for PlainTextPartitionFinder {

pub struct TlsPartitionFinder {
tls_config: Arc<tokio::sync::Mutex<TlsConfig>>,
pub address_family: String,
}

impl TlsPartitionFinder {
pub fn new(tls_config: Arc<tokio::sync::Mutex<TlsConfig>>) -> Self {
TlsPartitionFinder { tls_config }
pub fn new(tls_config: Arc<tokio::sync::Mutex<TlsConfig>>, address_family: String) -> Self {
TlsPartitionFinder {
tls_config,
address_family,
}
}
}

Expand All @@ -364,11 +377,40 @@ impl PartitionFinder<TlsStream<TcpStream>> for TlsPartitionFinder {
async fn create_connect_future(
&self,
) -> BoxFuture<'static, Result<TlsStream<TcpStream>, ConnectError>> {
let tls_config_copy = self.tls_config.lock().await.clone();
Box::pin(establish_tls_stream(tls_config_copy))
let mut tls_config_copy = self.tls_config.lock().await.clone();
let address_family = self.address_family.clone();
Box::pin(async move {
if address_family != "unspec" {
let addr = resolve_addr(&tls_config_copy.remote_addr, &address_family).await?;
tls_config_copy.remote_addr = addr.to_string();
}
establish_tls_stream(tls_config_copy).await
})
}
}

/// Resolve `addr` (host:port) to a `SocketAddr` matching the requested address family.
/// `address_family`: `"ipv4"` selects the first IPv4 result, `"ipv6"` selects the first IPv6 result.
/// Returns an error if no address of the requested family is found.
/// Callers must not pass `"unspec"` — use `TcpStream::connect(hostname)` directly instead.
async fn resolve_addr(addr: &str, address_family: &str) -> Result<SocketAddr, ConnectError> {
let addrs: Vec<SocketAddr> = lookup_host(addr)
.await
.map_err(ConnectError::IoError)?
.collect();
let result = match address_family {
"ipv4" => addrs.iter().find(|a| a.is_ipv4()).copied(),
"ipv6" => addrs.iter().find(|a| a.is_ipv6()).copied(),
_ => None,
};
result.ok_or_else(|| {
ConnectError::IoError(std::io::Error::other(format!(
"no {} address found for {}",
address_family, addr
)))
})
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -399,6 +441,7 @@ mod tests {
let error = tokio::spawn(async move {
let partition_finder = PlainTextPartitionFinder {
mount_target_addr: format!("127.0.0.1:{}", port.clone()),
address_family: "unspec".to_string(),
};
partition_finder
.establish_connection(create_deadline(test_single_connection_timeout), PROXY_ID)
Expand All @@ -421,6 +464,7 @@ mod tests {

let partition_finder = PlainTextPartitionFinder {
mount_target_addr: format!("127.0.0.1:{}", port.clone()),
address_family: "unspec".to_string(),
};
partition_finder
.inner_establish_multiplex_connection(
Expand Down Expand Up @@ -453,6 +497,7 @@ mod tests {
let task = tokio::spawn(async move {
let partition_finder = PlainTextPartitionFinder {
mount_target_addr: format!("127.0.0.1:{}", port.clone()),
address_family: "unspec".to_string(),
};
partition_finder
.inner_establish_multiplex_connection(
Expand Down Expand Up @@ -565,6 +610,7 @@ mod tests {
});
let tls_partition_finder = TlsPartitionFinder {
tls_config: tls_config_ptr.clone(),
address_family: "unspec".to_string(),
};
let _ = kill(nix::unistd::Pid::this(), Signal::SIGHUP);
rx.await.unwrap();
Expand All @@ -573,4 +619,30 @@ mod tests {
tls_partition_finder.tls_config.lock().await.client_cert
);
}

#[tokio::test]
async fn test_resolve_addr_ipv4_success() {
// "127.0.0.1:2049" is a numeric address — lookup_host returns it without DNS.
let result = resolve_addr("127.0.0.1:2049", "ipv4").await;
assert!(result.is_ok());
assert!(result.unwrap().is_ipv4());
}

#[tokio::test]
async fn test_resolve_addr_ipv6_success() {
// "[::1]:2049" is a numeric IPv6 address — lookup_host returns it without DNS.
let result = resolve_addr("[::1]:2049", "ipv6").await;
assert!(result.is_ok());
assert!(result.unwrap().is_ipv6());
}

#[tokio::test]
async fn test_resolve_addr_ipv6_only_with_ipv4_requested_returns_error() {
// "[::1]:2049" is a numeric IPv6 address — guaranteed to resolve to only an IPv6 SocketAddr.
let result = resolve_addr("[::1]:2049", "ipv4").await;
assert!(
matches!(result, Err(ConnectError::IoError(_))),
"expected error when no IPv4 address available"
);
}
}
6 changes: 5 additions & 1 deletion src/proxy/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ async fn main() {
let controller = Controller::new(
&proxy_config.nested_config.listen_addr,
proxy_config.clone(),
Arc::new(TlsPartitionFinder::new(tls_config)),
Arc::new(TlsPartitionFinder::new(
tls_config,
proxy_config.nested_config.address_family.clone(),
)),
status_reporter,
cw_publisher.clone(),
)
Expand All @@ -112,6 +115,7 @@ async fn main() {
proxy_config.clone(),
Arc::new(PlainTextPartitionFinder {
mount_target_addr: proxy_config.nested_config.mount_target_addr.clone(),
address_family: proxy_config.nested_config.address_family.clone(),
}),
status_reporter,
cw_publisher.clone(),
Expand Down
Loading
Loading