Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/mount_efs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1472,6 +1472,11 @@ def write_stunnel_config_file(
system_release_version = get_system_release_version()
global_config = dict(STUNNEL_GLOBAL_CONFIG)

if (config.has_option(CONFIG_SECTION, "csi_driver_version")):
global_config["csi_driver_version"] = config.get(
CONFIG_SECTION, "csi_driver_version"
)

if not efs_proxy_enabled and is_stunnel_option_supported(
stunnel_options, b"foreground", b"quiet", emit_warning_log=False
):
Expand Down
6 changes: 6 additions & 0 deletions src/proxy/src/config_parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ pub struct ProxyConfig {
/// This nested structure is required for backwards compatibility
#[serde(alias = "efs")]
pub nested_config: EfsConfig,

#[serde(alias = "csi_driver_version")]
pub csi_driver_version: Option<String>,
}

impl FromStr for ProxyConfig {
Expand Down Expand Up @@ -109,6 +112,7 @@ output = /var/log/amazon/efs/fs-12341234.home.ec2-user.efs.21036.efs-proxy.log
pid = /var/run/efs/fs-12341234.home.ec2-user.efs.21036+/stunnel.pid
port = 8081
initial_partition_ip = 127.0.0.1:2049
csi_driver_version = v9.9.9

[efs]
accept = 127.0.0.1:21036
Expand Down Expand Up @@ -136,6 +140,7 @@ checkHost = fs-12341234.efs.us-east-1.amazonaws.com
output: Some(String::from(
"/var/log/amazon/efs/fs-12341234.home.ec2-user.efs.21036.efs-proxy.log",
)),
csi_driver_version: Some("v9.9.9".to_string()),
nested_config: EfsConfig {
listen_addr: String::from("127.0.0.1:21036"),
mount_target_addr: String::from("fs-12341234.efs.us-east-1.amazonaws.com:2049"),
Expand Down Expand Up @@ -187,6 +192,7 @@ checkHost = fs-12341234.efs.us-east-1.amazonaws.com
),
debug: DEFAULT_LOG_LEVEL.to_string(),
output: None,
csi_driver_version: None,
nested_config: EfsConfig {
listen_addr: String::from("127.0.0.1:21036"),
mount_target_addr: String::from("fs-12341234.efs.us-east-1.amazonaws.com:2049"),
Expand Down
20 changes: 13 additions & 7 deletions src/proxy/src/connections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ pub trait PartitionFinder<S: ProxyStream> {
async fn establish_connection(
&self,
proxy_id: ProxyIdentifier,
csi_driver_version: Option<String>,
) -> Result<(S, Option<PartitionId>, Option<ScaleUpConfig>), ConnectError>;

async fn spawn_establish_connection_task(
Expand Down Expand Up @@ -269,10 +270,11 @@ impl PlainTextPartitionFinder {
async fn establish_plain_text_connection(
mount_target_addr: String,
proxy_id: ProxyIdentifier,
csi_driver_version: Option<String>,
) -> Result<(TcpStream, Result<BindClientResponse, RpcError>), ConnectError> {
timeout(Duration::from_secs(SINGLE_CONNECTION_TIMEOUT_SEC), async {
let mut tcp_stream = TcpStream::connect(mount_target_addr).await?;
let response = efs_rpc::bind_client_to_partition(proxy_id, &mut tcp_stream).await;
let response = efs_rpc::bind_client_to_partition(proxy_id, &mut tcp_stream, csi_driver_version).await;
Ok((configure_stream(tcp_stream), response))
})
.await
Expand All @@ -285,9 +287,10 @@ impl PartitionFinder<TcpStream> for PlainTextPartitionFinder {
async fn establish_connection(
&self,
proxy_id: ProxyIdentifier,
csi_driver_version: Option<String>,
) -> Result<(TcpStream, Option<PartitionId>, Option<ScaleUpConfig>), ConnectError> {
let (s, bind_result) =
Self::establish_plain_text_connection(self.mount_target_addr.clone(), proxy_id).await?;
Self::establish_plain_text_connection(self.mount_target_addr.clone(), proxy_id, csi_driver_version).await?;
match bind_result {
Ok(response) => {
debug!(
Expand All @@ -313,7 +316,7 @@ impl PartitionFinder<TcpStream> for PlainTextPartitionFinder {
proxy_id: ProxyIdentifier,
) -> JoinHandle<Result<(TcpStream, Result<BindClientResponse, RpcError>), ConnectError>> {
let addr = self.mount_target_addr.clone();
tokio::spawn(Self::establish_plain_text_connection(addr, proxy_id))
tokio::spawn(Self::establish_plain_text_connection(addr, proxy_id, None))
}
}

Expand All @@ -329,10 +332,11 @@ impl TlsPartitionFinder {
async fn establish_tls_connection(
tls_config: TlsConfig,
proxy_id: ProxyIdentifier,
csi_driver_version: Option<String>,
) -> Result<(TlsStream<TcpStream>, Result<BindClientResponse, RpcError>), ConnectError> {
timeout(Duration::from_secs(SINGLE_CONNECTION_TIMEOUT_SEC), async {
let mut tls_stream = establish_tls_stream(tls_config).await?;
let response = efs_rpc::bind_client_to_partition(proxy_id, &mut tls_stream).await;
let response = efs_rpc::bind_client_to_partition(proxy_id, &mut tls_stream, csi_driver_version).await;
Ok((tls_stream, response))
})
.await
Expand All @@ -345,6 +349,7 @@ impl PartitionFinder<TlsStream<TcpStream>> for TlsPartitionFinder {
async fn establish_connection(
&self,
proxy_id: ProxyIdentifier,
csi_driver_version: Option<String>,
) -> Result<
(
TlsStream<TcpStream>,
Expand All @@ -354,7 +359,7 @@ impl PartitionFinder<TlsStream<TcpStream>> for TlsPartitionFinder {
ConnectError,
> {
let tls_config_copy = self.tls_config.lock().await.clone();
let (s, bind_result) = Self::establish_tls_connection(tls_config_copy, proxy_id).await?;
let (s, bind_result) = Self::establish_tls_connection(tls_config_copy, proxy_id, csi_driver_version).await?;
let (bind_response, scale_up_config) = match bind_result {
Ok(response) => {
warn!(
Expand Down Expand Up @@ -383,7 +388,7 @@ impl PartitionFinder<TlsStream<TcpStream>> for TlsPartitionFinder {
Result<(TlsStream<TcpStream>, Result<BindClientResponse, RpcError>), ConnectError>,
> {
let tls_config_copy = self.tls_config.lock().await.clone();
tokio::spawn(Self::establish_tls_connection(tls_config_copy, proxy_id))
tokio::spawn(Self::establish_tls_connection(tls_config_copy, proxy_id, None))
}
}

Expand Down Expand Up @@ -416,7 +421,7 @@ mod tests {
let partition_finder = PlainTextPartitionFinder {
mount_target_addr: format!("127.0.0.1:{}", port.clone()),
};
partition_finder.establish_connection(PROXY_ID).await
partition_finder.establish_connection(PROXY_ID, None).await
})
.await
.expect("join err");
Expand Down Expand Up @@ -488,6 +493,7 @@ mod tests {
async fn establish_connection(
&self,
_proxy_id: ProxyIdentifier,
_csi_driver_version: Option<String>,
) -> Result<(TcpStream, Option<PartitionId>, Option<ScaleUpConfig>), ConnectError> {
unimplemented!()
}
Expand Down
5 changes: 4 additions & 1 deletion src/proxy/src/controller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,15 @@ pub struct Controller<S: ProxyStream> {
pub restart_count: u64,
pub scale_up_config: ScaleUpConfig,
pub status_reporter: StatusReporter,
pub csi_driver_version: Option<String>,
}

impl<S: ProxyStream> Controller<S> {
pub async fn new(
listen_addr: &str,
partition_finder: Arc<impl PartitionFinder<S> + Sync + Send + 'static>,
status_reporter: StatusReporter,
csi_driver_version: Option<String>,
) -> Self {
let Ok(listener) = TcpListener::bind(listen_addr).await else {
panic!("Failed to bind {}", listen_addr);
Expand All @@ -97,6 +99,7 @@ impl<S: ProxyStream> Controller<S> {
restart_count: 0,
scale_up_config: DEFAULT_SCALE_UP_CONFIG,
status_reporter,
csi_driver_version: csi_driver_version,
}
}

Expand Down Expand Up @@ -145,7 +148,7 @@ impl<S: ProxyStream> Controller<S> {
None => {
match self
.partition_finder
.establish_connection(self.proxy_id)
.establish_connection(self.proxy_id, self.csi_driver_version.clone())
.await
{
Ok((s, partition_id, scale_up_config)) => {
Expand Down
5 changes: 5 additions & 0 deletions src/proxy/src/efs_prot.x
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
const PROXY_ID_LENGTH = 16;
const PROXY_INCARNATION_LENGTH = 8;
const PARTITION_ID_LENGTH = 64;
const CSI_DRIVER_VERSION_LEN = 32;

enum OperationType {
OP_BIND_CLIENT_TO_PARTITION = 1
Expand All @@ -17,6 +18,10 @@ struct ProxyIdentifier {
opaque incarnation<PROXY_INCARNATION_LENGTH>;
};

struct ConnectionMetrics {
opaque csi_driver_version<CSI_DRIVER_VERSION_LEN>;
};

struct ScaleUpConfig {
int max_multiplexed_connections;
int scale_up_bytes_per_sec_threshold;
Expand Down
36 changes: 32 additions & 4 deletions src/proxy/src/efs_rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::efs_prot::{BindClientResponse, OperationType};
use crate::error::RpcError;
use crate::proxy_identifier::ProxyIdentifier;
use crate::rpc;
use log::info;

pub const EFS_PROGRAM_NUMBER: u32 = 100200;
pub const EFS_PROGRAM_VERSION: u32 = 1;
Expand All @@ -19,8 +20,9 @@ pub struct PartitionId {
pub async fn bind_client_to_partition(
proxy_id: ProxyIdentifier,
stream: &mut dyn ProxyStream,
csi_driver_version: Option<String>,
) -> Result<BindClientResponse, RpcError> {
let request = create_bind_client_to_partition_request(&proxy_id)?;
let request = create_bind_client_to_partition_request(&proxy_id, csi_driver_version)?;
stream.write_all(&request).await?;
stream.flush().await?;

Expand All @@ -32,13 +34,24 @@ pub async fn bind_client_to_partition(

pub fn create_bind_client_to_partition_request(
proxy_id: &ProxyIdentifier,
csi_driver_version: Option<String>
) -> Result<Vec<u8>, RpcError> {
let payload = efs_prot::ProxyIdentifier {
identifier: proxy_id.uuid.as_bytes().to_vec(),
incarnation: proxy_id.incarnation.to_be_bytes().to_vec(),
};
let mut payload_buf = Vec::new();
xdr_codec::pack(&payload, &mut payload_buf)?;
match csi_driver_version {
Some(version) => {
let connection_metrics = efs_prot::ConnectionMetrics {
csi_driver_version: version.as_bytes().to_vec(),
};
xdr_codec::pack(&connection_metrics, &mut payload_buf)?;
info!("CSI Driver Version from create bind client to partion: {}", version)
},
None => info!("CSI Driver Version fom create bind client to partion not provided."),
}

let call_body = onc_rpc::CallBody::new(
EFS_PROGRAM_NUMBER,
Expand Down Expand Up @@ -96,16 +109,31 @@ pub mod tests {
#[test]
fn test_request_serde() -> Result<(), RpcError> {
let proxy_id = ProxyIdentifier::new();
let request = create_bind_client_to_partition_request(&proxy_id)?;
let csi_driver_version = Some("v9.9.9".to_string());
let request = create_bind_client_to_partition_request(&proxy_id, csi_driver_version.clone())?;

let deserialized = onc_rpc::RpcMessage::try_from(request.as_slice())?;
let deserialized_proxy_id = parse_bind_client_to_partition_request(&deserialized)?;
let (deserialized_proxy_id, deserialized_metrics) = parse_bind_client_to_partition_request(&deserialized)?;

assert_eq!(proxy_id.uuid, deserialized_proxy_id.uuid);
assert_eq!(proxy_id.incarnation, deserialized_proxy_id.incarnation);
assert_eq!(deserialized_metrics.csi_driver_version, csi_driver_version.unwrap_or_default().as_bytes().to_vec());
Ok(())
}

#[test]
fn test_request_serde_with_no_driver_version() -> Result<(), RpcError> {
let proxy_id = ProxyIdentifier::new();
let request = create_bind_client_to_partition_request(&proxy_id, None)?;

let deserialized = onc_rpc::RpcMessage::try_from(request.as_slice())?;
let deserialized_proxy_id = parse_bind_client_to_partition_request_with_no_driver_version(&deserialized)?;

assert_eq!(proxy_id.uuid, deserialized_proxy_id.uuid);
assert_eq!(proxy_id.incarnation, deserialized_proxy_id.incarnation);
Ok(())
}

#[test]
fn test_response_serde() -> Result<(), RpcError> {
let partition_id = generate_partition_id();
Expand All @@ -129,7 +157,7 @@ pub mod tests {
#[test]
fn test_parse_bind_client_to_partition_response_missing_reply() -> Result<(), RpcError> {
// Create a call message, which will error when parsed as a response
let malformed_response = create_bind_client_to_partition_request(&ProxyIdentifier::new())?;
let malformed_response = create_bind_client_to_partition_request(&ProxyIdentifier::new(), None)?;
let deserialized = onc_rpc::RpcMessage::try_from(malformed_response.as_slice())?;

let result = parse_bind_client_to_partition_response(&deserialized);
Expand Down
2 changes: 2 additions & 0 deletions src/proxy/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ async fn main() {
&proxy_config.nested_config.listen_addr,
Arc::new(TlsPartitionFinder::new(tls_config)),
status_reporter,
proxy_config.csi_driver_version,
)
.await;
tokio::spawn(controller.run(sigterm_cancellation_token.clone()))
Expand All @@ -86,6 +87,7 @@ async fn main() {
mount_target_addr: proxy_config.nested_config.mount_target_addr.clone(),
}),
status_reporter,
proxy_config.csi_driver_version,
)
.await;
tokio::spawn(controller.run(sigterm_cancellation_token.clone()))
Expand Down
35 changes: 35 additions & 0 deletions src/proxy/src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,41 @@ pub fn generate_partition_id() -> efs_prot::PartitionId {

pub fn parse_bind_client_to_partition_request(
request: &onc_rpc::RpcMessage<&[u8], &[u8]>,
) -> Result<(ProxyIdentifier, efs_prot::ConnectionMetrics), RpcError> {
let call_body = request.call_body().expect("not a call rpc");

if EFS_PROGRAM_NUMBER != call_body.program()
|| EFS_PROGRAM_VERSION != call_body.program_version()
{
return Err(RpcError::GarbageArgs);
}

let mut payload = Cursor::new(call_body.payload());
let raw_proxy_id = xdr_codec::unpack::<_, efs_prot::ProxyIdentifier>(&mut payload)?;
let connection_metrics = xdr_codec::unpack::<_, efs_prot::ConnectionMetrics>(&mut payload)?;

Ok((
ProxyIdentifier {
uuid: uuid::Builder::from_bytes(
raw_proxy_id
.identifier
.try_into()
.expect("Failed not convert vec to sized array"),
)
.into_uuid(),
incarnation: i64::from_be_bytes(
raw_proxy_id
.incarnation
.try_into()
.expect("Failed to convert vec to sized array"),
),
},
connection_metrics
))
}

pub fn parse_bind_client_to_partition_request_with_no_driver_version(
request: &onc_rpc::RpcMessage<&[u8], &[u8]>,
) -> Result<ProxyIdentifier, RpcError> {
let call_body = request.call_body().expect("not a call rpc");

Expand Down