diff --git a/src/mount_efs/__init__.py b/src/mount_efs/__init__.py index 43d9b662..b796ea41 100755 --- a/src/mount_efs/__init__.py +++ b/src/mount_efs/__init__.py @@ -1472,6 +1472,11 @@ def write_stunnel_config_file( system_release_version = get_system_release_version() global_config = dict(STUNNEL_GLOBAL_CONFIG) + if (config.has_option(CONFIG_SECTION, "csi_driver_version")): + global_config["csi_driver_version"] = config.get( + CONFIG_SECTION, "csi_driver_version" + ) + if not efs_proxy_enabled and is_stunnel_option_supported( stunnel_options, b"foreground", b"quiet", emit_warning_log=False ): diff --git a/src/proxy/src/config_parser.rs b/src/proxy/src/config_parser.rs index 0a49fb14..927288af 100644 --- a/src/proxy/src/config_parser.rs +++ b/src/proxy/src/config_parser.rs @@ -41,6 +41,9 @@ pub struct ProxyConfig { /// This nested structure is required for backwards compatibility #[serde(alias = "efs")] pub nested_config: EfsConfig, + + #[serde(alias = "csi_driver_version")] + pub csi_driver_version: Option, } impl FromStr for ProxyConfig { @@ -109,6 +112,7 @@ output = /var/log/amazon/efs/fs-12341234.home.ec2-user.efs.21036.efs-proxy.log pid = /var/run/efs/fs-12341234.home.ec2-user.efs.21036+/stunnel.pid port = 8081 initial_partition_ip = 127.0.0.1:2049 +csi_driver_version = v9.9.9 [efs] accept = 127.0.0.1:21036 @@ -136,6 +140,7 @@ checkHost = fs-12341234.efs.us-east-1.amazonaws.com output: Some(String::from( "/var/log/amazon/efs/fs-12341234.home.ec2-user.efs.21036.efs-proxy.log", )), + csi_driver_version: Some("v9.9.9".to_string()), nested_config: EfsConfig { listen_addr: String::from("127.0.0.1:21036"), mount_target_addr: String::from("fs-12341234.efs.us-east-1.amazonaws.com:2049"), @@ -187,6 +192,7 @@ checkHost = fs-12341234.efs.us-east-1.amazonaws.com ), debug: DEFAULT_LOG_LEVEL.to_string(), output: None, + csi_driver_version: None, nested_config: EfsConfig { listen_addr: String::from("127.0.0.1:21036"), mount_target_addr: String::from("fs-12341234.efs.us-east-1.amazonaws.com:2049"), diff --git a/src/proxy/src/connections.rs b/src/proxy/src/connections.rs index aca91c39..4a644d23 100644 --- a/src/proxy/src/connections.rs +++ b/src/proxy/src/connections.rs @@ -34,6 +34,7 @@ pub trait PartitionFinder { async fn establish_connection( &self, proxy_id: ProxyIdentifier, + csi_driver_version: Option, ) -> Result<(S, Option, Option), ConnectError>; async fn spawn_establish_connection_task( @@ -269,10 +270,11 @@ impl PlainTextPartitionFinder { async fn establish_plain_text_connection( mount_target_addr: String, proxy_id: ProxyIdentifier, + csi_driver_version: Option, ) -> Result<(TcpStream, Result), ConnectError> { timeout(Duration::from_secs(SINGLE_CONNECTION_TIMEOUT_SEC), async { let mut tcp_stream = TcpStream::connect(mount_target_addr).await?; - let response = efs_rpc::bind_client_to_partition(proxy_id, &mut tcp_stream).await; + let response = efs_rpc::bind_client_to_partition(proxy_id, &mut tcp_stream, csi_driver_version).await; Ok((configure_stream(tcp_stream), response)) }) .await @@ -285,9 +287,10 @@ impl PartitionFinder for PlainTextPartitionFinder { async fn establish_connection( &self, proxy_id: ProxyIdentifier, + csi_driver_version: Option, ) -> Result<(TcpStream, Option, Option), ConnectError> { let (s, bind_result) = - Self::establish_plain_text_connection(self.mount_target_addr.clone(), proxy_id).await?; + Self::establish_plain_text_connection(self.mount_target_addr.clone(), proxy_id, csi_driver_version).await?; match bind_result { Ok(response) => { debug!( @@ -313,7 +316,7 @@ impl PartitionFinder for PlainTextPartitionFinder { proxy_id: ProxyIdentifier, ) -> JoinHandle), ConnectError>> { let addr = self.mount_target_addr.clone(); - tokio::spawn(Self::establish_plain_text_connection(addr, proxy_id)) + tokio::spawn(Self::establish_plain_text_connection(addr, proxy_id, None)) } } @@ -329,10 +332,11 @@ impl TlsPartitionFinder { async fn establish_tls_connection( tls_config: TlsConfig, proxy_id: ProxyIdentifier, + csi_driver_version: Option, ) -> Result<(TlsStream, Result), ConnectError> { timeout(Duration::from_secs(SINGLE_CONNECTION_TIMEOUT_SEC), async { let mut tls_stream = establish_tls_stream(tls_config).await?; - let response = efs_rpc::bind_client_to_partition(proxy_id, &mut tls_stream).await; + let response = efs_rpc::bind_client_to_partition(proxy_id, &mut tls_stream, csi_driver_version).await; Ok((tls_stream, response)) }) .await @@ -345,6 +349,7 @@ impl PartitionFinder> for TlsPartitionFinder { async fn establish_connection( &self, proxy_id: ProxyIdentifier, + csi_driver_version: Option, ) -> Result< ( TlsStream, @@ -354,7 +359,7 @@ impl PartitionFinder> for TlsPartitionFinder { ConnectError, > { let tls_config_copy = self.tls_config.lock().await.clone(); - let (s, bind_result) = Self::establish_tls_connection(tls_config_copy, proxy_id).await?; + let (s, bind_result) = Self::establish_tls_connection(tls_config_copy, proxy_id, csi_driver_version).await?; let (bind_response, scale_up_config) = match bind_result { Ok(response) => { warn!( @@ -383,7 +388,7 @@ impl PartitionFinder> for TlsPartitionFinder { Result<(TlsStream, Result), ConnectError>, > { let tls_config_copy = self.tls_config.lock().await.clone(); - tokio::spawn(Self::establish_tls_connection(tls_config_copy, proxy_id)) + tokio::spawn(Self::establish_tls_connection(tls_config_copy, proxy_id, None)) } } @@ -416,7 +421,7 @@ mod tests { let partition_finder = PlainTextPartitionFinder { mount_target_addr: format!("127.0.0.1:{}", port.clone()), }; - partition_finder.establish_connection(PROXY_ID).await + partition_finder.establish_connection(PROXY_ID, None).await }) .await .expect("join err"); @@ -488,6 +493,7 @@ mod tests { async fn establish_connection( &self, _proxy_id: ProxyIdentifier, + _csi_driver_version: Option, ) -> Result<(TcpStream, Option, Option), ConnectError> { unimplemented!() } diff --git a/src/proxy/src/controller.rs b/src/proxy/src/controller.rs index 1b00756c..870542d7 100644 --- a/src/proxy/src/controller.rs +++ b/src/proxy/src/controller.rs @@ -77,6 +77,7 @@ pub struct Controller { pub restart_count: u64, pub scale_up_config: ScaleUpConfig, pub status_reporter: StatusReporter, + pub csi_driver_version: Option, } impl Controller { @@ -84,6 +85,7 @@ impl Controller { listen_addr: &str, partition_finder: Arc + Sync + Send + 'static>, status_reporter: StatusReporter, + csi_driver_version: Option, ) -> Self { let Ok(listener) = TcpListener::bind(listen_addr).await else { panic!("Failed to bind {}", listen_addr); @@ -97,6 +99,7 @@ impl Controller { restart_count: 0, scale_up_config: DEFAULT_SCALE_UP_CONFIG, status_reporter, + csi_driver_version: csi_driver_version, } } @@ -145,7 +148,7 @@ impl Controller { None => { match self .partition_finder - .establish_connection(self.proxy_id) + .establish_connection(self.proxy_id, self.csi_driver_version.clone()) .await { Ok((s, partition_id, scale_up_config)) => { diff --git a/src/proxy/src/efs_prot.x b/src/proxy/src/efs_prot.x index d0faeb4f..e0773a31 100644 --- a/src/proxy/src/efs_prot.x +++ b/src/proxy/src/efs_prot.x @@ -5,6 +5,7 @@ const PROXY_ID_LENGTH = 16; const PROXY_INCARNATION_LENGTH = 8; const PARTITION_ID_LENGTH = 64; +const CSI_DRIVER_VERSION_LEN = 32; enum OperationType { OP_BIND_CLIENT_TO_PARTITION = 1 @@ -17,6 +18,10 @@ struct ProxyIdentifier { opaque incarnation; }; +struct ConnectionMetrics { + opaque csi_driver_version; +}; + struct ScaleUpConfig { int max_multiplexed_connections; int scale_up_bytes_per_sec_threshold; diff --git a/src/proxy/src/efs_rpc.rs b/src/proxy/src/efs_rpc.rs index 72fa789f..e3113b0f 100644 --- a/src/proxy/src/efs_rpc.rs +++ b/src/proxy/src/efs_rpc.rs @@ -7,6 +7,7 @@ use crate::efs_prot::{BindClientResponse, OperationType}; use crate::error::RpcError; use crate::proxy_identifier::ProxyIdentifier; use crate::rpc; +use log::info; pub const EFS_PROGRAM_NUMBER: u32 = 100200; pub const EFS_PROGRAM_VERSION: u32 = 1; @@ -19,8 +20,9 @@ pub struct PartitionId { pub async fn bind_client_to_partition( proxy_id: ProxyIdentifier, stream: &mut dyn ProxyStream, + csi_driver_version: Option, ) -> Result { - let request = create_bind_client_to_partition_request(&proxy_id)?; + let request = create_bind_client_to_partition_request(&proxy_id, csi_driver_version)?; stream.write_all(&request).await?; stream.flush().await?; @@ -32,6 +34,7 @@ pub async fn bind_client_to_partition( pub fn create_bind_client_to_partition_request( proxy_id: &ProxyIdentifier, + csi_driver_version: Option ) -> Result, RpcError> { let payload = efs_prot::ProxyIdentifier { identifier: proxy_id.uuid.as_bytes().to_vec(), @@ -39,6 +42,16 @@ pub fn create_bind_client_to_partition_request( }; let mut payload_buf = Vec::new(); xdr_codec::pack(&payload, &mut payload_buf)?; + match csi_driver_version { + Some(version) => { + let connection_metrics = efs_prot::ConnectionMetrics { + csi_driver_version: version.as_bytes().to_vec(), + }; + xdr_codec::pack(&connection_metrics, &mut payload_buf)?; + info!("CSI Driver Version from create bind client to partion: {}", version) + }, + None => info!("CSI Driver Version fom create bind client to partion not provided."), + } let call_body = onc_rpc::CallBody::new( EFS_PROGRAM_NUMBER, @@ -96,16 +109,31 @@ pub mod tests { #[test] fn test_request_serde() -> Result<(), RpcError> { let proxy_id = ProxyIdentifier::new(); - let request = create_bind_client_to_partition_request(&proxy_id)?; + let csi_driver_version = Some("v9.9.9".to_string()); + let request = create_bind_client_to_partition_request(&proxy_id, csi_driver_version.clone())?; let deserialized = onc_rpc::RpcMessage::try_from(request.as_slice())?; - let deserialized_proxy_id = parse_bind_client_to_partition_request(&deserialized)?; + let (deserialized_proxy_id, deserialized_metrics) = parse_bind_client_to_partition_request(&deserialized)?; assert_eq!(proxy_id.uuid, deserialized_proxy_id.uuid); assert_eq!(proxy_id.incarnation, deserialized_proxy_id.incarnation); + assert_eq!(deserialized_metrics.csi_driver_version, csi_driver_version.unwrap_or_default().as_bytes().to_vec()); Ok(()) } + #[test] + fn test_request_serde_with_no_driver_version() -> Result<(), RpcError> { + let proxy_id = ProxyIdentifier::new(); + let request = create_bind_client_to_partition_request(&proxy_id, None)?; + + let deserialized = onc_rpc::RpcMessage::try_from(request.as_slice())?; + let deserialized_proxy_id = parse_bind_client_to_partition_request_with_no_driver_version(&deserialized)?; + + assert_eq!(proxy_id.uuid, deserialized_proxy_id.uuid); + assert_eq!(proxy_id.incarnation, deserialized_proxy_id.incarnation); + Ok(()) + } + #[test] fn test_response_serde() -> Result<(), RpcError> { let partition_id = generate_partition_id(); @@ -129,7 +157,7 @@ pub mod tests { #[test] fn test_parse_bind_client_to_partition_response_missing_reply() -> Result<(), RpcError> { // Create a call message, which will error when parsed as a response - let malformed_response = create_bind_client_to_partition_request(&ProxyIdentifier::new())?; + let malformed_response = create_bind_client_to_partition_request(&ProxyIdentifier::new(), None)?; let deserialized = onc_rpc::RpcMessage::try_from(malformed_response.as_slice())?; let result = parse_bind_client_to_partition_response(&deserialized); diff --git a/src/proxy/src/main.rs b/src/proxy/src/main.rs index 92d4d1e4..a6fdf9b9 100644 --- a/src/proxy/src/main.rs +++ b/src/proxy/src/main.rs @@ -76,6 +76,7 @@ async fn main() { &proxy_config.nested_config.listen_addr, Arc::new(TlsPartitionFinder::new(tls_config)), status_reporter, + proxy_config.csi_driver_version, ) .await; tokio::spawn(controller.run(sigterm_cancellation_token.clone())) @@ -86,6 +87,7 @@ async fn main() { mount_target_addr: proxy_config.nested_config.mount_target_addr.clone(), }), status_reporter, + proxy_config.csi_driver_version, ) .await; tokio::spawn(controller.run(sigterm_cancellation_token.clone())) diff --git a/src/proxy/src/test_utils.rs b/src/proxy/src/test_utils.rs index 9d21b2f2..e3a5a6a9 100644 --- a/src/proxy/src/test_utils.rs +++ b/src/proxy/src/test_utils.rs @@ -102,6 +102,41 @@ pub fn generate_partition_id() -> efs_prot::PartitionId { pub fn parse_bind_client_to_partition_request( request: &onc_rpc::RpcMessage<&[u8], &[u8]>, +) -> Result<(ProxyIdentifier, efs_prot::ConnectionMetrics), RpcError> { + let call_body = request.call_body().expect("not a call rpc"); + + if EFS_PROGRAM_NUMBER != call_body.program() + || EFS_PROGRAM_VERSION != call_body.program_version() + { + return Err(RpcError::GarbageArgs); + } + + let mut payload = Cursor::new(call_body.payload()); + let raw_proxy_id = xdr_codec::unpack::<_, efs_prot::ProxyIdentifier>(&mut payload)?; + let connection_metrics = xdr_codec::unpack::<_, efs_prot::ConnectionMetrics>(&mut payload)?; + + Ok(( + ProxyIdentifier { + uuid: uuid::Builder::from_bytes( + raw_proxy_id + .identifier + .try_into() + .expect("Failed not convert vec to sized array"), + ) + .into_uuid(), + incarnation: i64::from_be_bytes( + raw_proxy_id + .incarnation + .try_into() + .expect("Failed to convert vec to sized array"), + ), + }, + connection_metrics + )) +} + +pub fn parse_bind_client_to_partition_request_with_no_driver_version( + request: &onc_rpc::RpcMessage<&[u8], &[u8]>, ) -> Result { let call_body = request.call_body().expect("not a call rpc");