Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 136 additions & 8 deletions ldk-server-client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@ use crate::error::LdkServerErrorCode::{

type StreamingClient = HyperClient<HttpsConnector<hyper::client::HttpConnector>, HyperBody>;

const GRPC_FRAME_HEADER_LEN: usize = 5;

// Applies to complete unary gRPC responses. The server applies the same cap to unary request
// bodies before protobuf decoding.
const MAX_GRPC_UNARY_RESPONSE_LEN: usize = 10 * 1024 * 1024;

// Applies to each server-streaming gRPC message. Graph RPCs use the unary client path and are not
// constrained by this limit.
const MAX_GRPC_STREAM_MESSAGE_LEN: usize = 4 * 1024 * 1024;

/// Client to access a hosted instance of LDK Server via gRPC.
///
/// The client requires the server's TLS certificate to be provided for verification.
Expand Down Expand Up @@ -450,10 +460,7 @@ impl LdkServerClient {
return Err(error);
}

// Read the response body
let payload = response.bytes().await.map_err(|e| {
LdkServerError::new(InternalError, format!("Failed to read response body: {}", e))
})?;
let payload = read_grpc_unary_response_body(response).await?;

let proto_bytes = decode_grpc_body(&payload)
.map_err(|e| LdkServerError::new(InternalError, e.message))?;
Expand Down Expand Up @@ -508,6 +515,42 @@ impl LdkServerClient {
}
}

async fn read_grpc_unary_response_body(
mut response: reqwest::Response,
) -> Result<Vec<u8>, LdkServerError> {
let capacity = if let Some(content_length) = response.content_length() {
check_grpc_unary_response_len(content_length)?;
content_length as usize
} else {
0
};

let mut payload = Vec::with_capacity(capacity);
while let Some(chunk) = response.chunk().await.map_err(|e| {
LdkServerError::new(InternalError, format!("Failed to read response body: {}", e))
})? {
let len = payload.len().checked_add(chunk.len()).ok_or_else(|| {
LdkServerError::new(InternalError, "gRPC unary response body length overflow")
})?;
check_grpc_unary_response_len(len as u64)?;
payload.extend_from_slice(&chunk);
}
Ok(payload)
}

fn check_grpc_unary_response_len(len: u64) -> Result<(), LdkServerError> {
if len > MAX_GRPC_UNARY_RESPONSE_LEN as u64 {
return Err(LdkServerError::new(
InternalError,
format!(
"gRPC unary response exceeds maximum size of {} bytes",
MAX_GRPC_UNARY_RESPONSE_LEN
),
));
}
Ok(())
}

/// Map a gRPC status code to an LdkServerError.
fn grpc_code_to_error(code: u32, message: String) -> LdkServerError {
match code {
Expand Down Expand Up @@ -568,19 +611,43 @@ impl<M: Message + Default> GrpcStream<M> {
pub async fn next_message(&mut self) -> Option<Result<M, LdkServerError>> {
loop {
// Try to decode a complete gRPC frame from the buffer
if self.buf.len() >= 5 {
if self.buf.len() >= GRPC_FRAME_HEADER_LEN {
if self.buf[0] != 0 {
return Some(Err(LdkServerError::new(
InternalError,
"gRPC stream compression is not supported",
)));
}
let msg_len =
u32::from_be_bytes([self.buf[1], self.buf[2], self.buf[3], self.buf[4]])
as usize;
if self.buf.len() >= 5 + msg_len {
let proto_bytes = &self.buf[5..5 + msg_len];
if msg_len > MAX_GRPC_STREAM_MESSAGE_LEN {
return Some(Err(LdkServerError::new(
InternalError,
format!(
"gRPC stream message exceeds maximum size of {} bytes",
MAX_GRPC_STREAM_MESSAGE_LEN
),
)));
}
let frame_len = match GRPC_FRAME_HEADER_LEN.checked_add(msg_len) {
Some(frame_len) => frame_len,
None => {
return Some(Err(LdkServerError::new(
InternalError,
"gRPC stream frame length overflow",
)));
},
};
if self.buf.len() >= frame_len {
let proto_bytes = &self.buf[GRPC_FRAME_HEADER_LEN..frame_len];
let result = M::decode(proto_bytes).map_err(|e| {
LdkServerError::new(
InternalError,
format!("Failed to decode gRPC stream message: {}", e),
)
});
self.buf.drain(..5 + msg_len);
self.buf.drain(..frame_len);
return Some(result);
}
}
Expand Down Expand Up @@ -691,6 +758,25 @@ mod tests {
assert_eq!(err.message, "gRPC stream became unavailable: server shutting down");
}

#[test]
fn test_grpc_unary_response_len_allows_limit() {
assert!(check_grpc_unary_response_len(MAX_GRPC_UNARY_RESPONSE_LEN as u64).is_ok());
}

#[test]
fn test_grpc_unary_response_len_rejects_oversized() {
let err =
check_grpc_unary_response_len(MAX_GRPC_UNARY_RESPONSE_LEN as u64 + 1).unwrap_err();
assert_eq!(err.error_code, InternalError);
assert_eq!(
err.message,
format!(
"gRPC unary response exceeds maximum size of {} bytes",
MAX_GRPC_UNARY_RESPONSE_LEN
)
);
}

#[tokio::test]
async fn test_event_stream_surfaces_terminal_grpc_status() {
let (mut sender, body) = Body::channel();
Expand All @@ -713,6 +799,48 @@ mod tests {
assert!(stream.next_message().await.is_none());
}

#[tokio::test]
async fn test_event_stream_rejects_oversized_frame_header() {
let (mut sender, body) = Body::channel();
sender.send_data(vec![0u8, 0xff, 0xff, 0xff, 0xff].into()).await.unwrap();
drop(sender);

let mut stream: EventStream = GrpcStream {
body,
buf: Vec::new(),
trailers_checked: false,
_marker: std::marker::PhantomData,
};

let result = stream.next_message().await.unwrap().unwrap_err();
assert_eq!(result.error_code, InternalError);
assert_eq!(
result.message,
format!(
"gRPC stream message exceeds maximum size of {} bytes",
MAX_GRPC_STREAM_MESSAGE_LEN
)
);
}

#[tokio::test]
async fn test_event_stream_rejects_compressed_frame() {
let (mut sender, body) = Body::channel();
sender.send_data(vec![1u8, 0, 0, 0, 0].into()).await.unwrap();
drop(sender);

let mut stream: EventStream = GrpcStream {
body,
buf: Vec::new(),
trailers_checked: false,
_marker: std::marker::PhantomData,
};

let result = stream.next_message().await.unwrap().unwrap_err();
assert_eq!(result.error_code, InternalError);
assert_eq!(result.message, "gRPC stream compression is not supported");
}

#[test]
fn test_grpc_code_to_error_all_known_codes() {
let cases = [
Expand Down