Skip to content

Commit

Permalink
assert grpc calls
Browse files Browse the repository at this point in the history
  • Loading branch information
alexsnaps committed Aug 11, 2023
1 parent 1dca120 commit 1797355
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 27 deletions.
47 changes: 46 additions & 1 deletion src/expect_interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ impl<'a> ExpectGetBufferBytes<'a> {
}
}

pub fn returning(&mut self, buffer_data: Option<&str>) -> &mut Tester {
pub fn returning(&mut self, buffer_data: Option<&[u8]>) -> &mut Tester {
self.tester
.get_expect_handle()
.staged
Expand Down Expand Up @@ -150,3 +150,48 @@ impl<'a> ExpectHttpCall<'a> {
self.tester
}
}

pub struct ExpectGrpcCall<'a> {
tester: &'a mut Tester,
service: Option<&'a str>,
service_name: Option<&'a str>,
method_name: Option<&'a str>,
initial_metadata: Option<&'a [u8]>,
request: Option<&'a [u8]>,
timeout: Option<u64>,
}

impl<'a> ExpectGrpcCall<'a> {
pub fn expecting(
tester: &'a mut Tester,
service: Option<&'a str>,
service_name: Option<&'a str>,
method_name: Option<&'a str>,
initial_metadata: Option<&'a [u8]>,
request: Option<&'a [u8]>,
timeout: Option<u64>,
) -> Self {
Self {
tester,
service,
service_name,
method_name,
initial_metadata,
request,
timeout,
}
}

pub fn returning(&mut self, token_id: Option<u32>) -> &mut Tester {
self.tester.get_expect_handle().staged.set_expect_grpc_call(
self.service,
self.service_name,
self.method_name,
self.initial_metadata,
self.request,
self.timeout,
token_id,
);
self.tester
}
}
98 changes: 88 additions & 10 deletions src/expectations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,20 @@ impl ExpectHandle {
self.staged = Expect::new(allow_unexpected);
}

pub fn assert_stage(&self) {
pub fn assert_stage(&self) -> Option<String> {
if self.staged.expect_count > 0 {
panic!(
return Some(format!(
"Error: failed to consume all expectations - total remaining: {}",
self.staged.expect_count
);
));
} else if self.staged.expect_count < 0 {
panic!(
return Some(format!(
"Error: expectations failed to account for all host calls by {} \n\
if this is intended, please use --allow-unexpected (-a) mode",
-1 * self.staged.expect_count
);
));
}
None
}

pub fn print_staged(&self) {
Expand Down Expand Up @@ -86,6 +87,15 @@ pub struct Expect {
Option<Duration>,
Option<u32>,
)>,
grpc_call: Vec<(
Option<String>,
Option<String>,
Option<String>,
Option<Bytes>,
Option<Bytes>,
Option<Duration>,
Option<u32>,
)>,
}

impl Expect {
Expand All @@ -106,6 +116,7 @@ impl Expect {
add_header_map_value: vec![],
send_local_response: vec![],
http_call: vec![],
grpc_call: vec![],
}
}

Expand Down Expand Up @@ -190,13 +201,11 @@ impl Expect {
pub fn set_expect_get_buffer_bytes(
&mut self,
buffer_type: Option<i32>,
buffer_data: Option<&str>,
buffer_data: Option<&[u8]>,
) {
self.expect_count += 1;
self.get_buffer_bytes.push((
buffer_type,
buffer_data.map(|data| data.as_bytes().to_vec()),
));
self.get_buffer_bytes
.push((buffer_type, buffer_data.map(|data| data.to_vec())));
}

pub fn get_expect_get_buffer_bytes(&mut self, buffer_type: i32) -> Option<Bytes> {
Expand Down Expand Up @@ -571,4 +580,73 @@ impl Expect {
}
}
}

pub fn set_expect_grpc_call(
&mut self,
service: Option<&str>,
service_name: Option<&str>,
method_name: Option<&str>,
initial_metadata: Option<&[u8]>,
request: Option<&[u8]>,
timeout: Option<u64>,
token_id: Option<u32>,
) {
self.expect_count += 1;
self.grpc_call.push((
service.map(ToString::to_string),
service_name.map(ToString::to_string),
method_name.map(ToString::to_string),
initial_metadata.map(|s| s.to_vec()),
request.map(|s| s.to_vec()),
timeout.map(Duration::from_millis),
token_id,
));
}

pub fn get_expect_grpc_call(
&mut self,
service: String,
service_name: String,
method: String,
initial_metadata: &[u8],
request: &[u8],
timeout: i32,
) -> Option<u32> {
match self.grpc_call.len() {
0 => {
if !self.allow_unexpected {
self.expect_count -= 1;
}
set_status(ExpectStatus::Unexpected);
None
}
_ => {
self.expect_count -= 1;
let (
expected_service,
expected_service_name,
expected_method,
expected_initial_metadata,
expected_request,
expected_duration,
result,
) = self.grpc_call.remove(0);

let expected = expected_service.map(|e| e == service).unwrap_or(true)
&& expected_service_name
.map(|e| e == service_name)
.unwrap_or(true)
&& expected_method.map(|e| e == method).unwrap_or(true)
&& expected_initial_metadata
.map(|e| e == initial_metadata)
.unwrap_or(true)
&& expected_request.map(|e| e == request).unwrap_or(true)
&& expected_duration
.map(|e| e.as_millis() as i32 == timeout)
.unwrap_or(true);
set_expect_status(expected);
return result;
}
}
}
}
4 changes: 4 additions & 0 deletions src/host_settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,5 +294,9 @@ pub fn default_buffer_bytes() -> HashMap<i32, Bytes> {
BufferType::HttpCallResponseBody as i32,
"default_call_response_body".as_bytes().to_vec(),
);
default_bytes.insert(
BufferType::GrpcReceiveBuffer as i32,
"default_grpc_receive_buffer".as_bytes().to_vec(),
);
default_bytes
}
90 changes: 75 additions & 15 deletions src/hostcalls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1390,20 +1390,67 @@ fn get_hostfunc(
"proxy_grpc_call" => {
Some(Func::wrap(
store,
|_caller: Caller<'_, ()>,
_service_ptr: i32,
_service_size: i32,
_service_name_ptr: i32,
_service_name_size: i32,
_method_name_ptr: i32,
_method_name_size: i32,
_initial_metadata_ptr: i32,
_initial_metadata_size: i32,
_request_ptr: i32,
_request_size: i32,
_timeout_milliseconds: i32,
_token_ptr: i32|
|mut caller: Caller<'_, ()>,
service_ptr: i32,
service_size: i32,
service_name_ptr: i32,
service_name_size: i32,
method_name_ptr: i32,
method_name_size: i32,
initial_metadata_ptr: i32,
initial_metadata_size: i32,
request_ptr: i32,
request_size: i32,
timeout_milliseconds: i32,
token_ptr: i32|
-> i32 {
print!("[vm->host] proxy_grpc_call({initial_metadata_ptr}, {initial_metadata_size})");

// Default Function: receives and displays http call from proxy-wasm module
// Expectation: asserts equal the receieved http call with the expected one
let mem = match caller.get_export("memory") {
Some(Extern::Memory(mem)) => mem,
_ => {
println!("Error: proxy_http_call cannot get export \"memory\"");
println!(
"[vm<-host] proxy_http_call(...) -> (return_token) return: {:?}",
Status::InternalFailure
);
return Status::InternalFailure as i32;
}
};

let service = read_string(&caller, mem, service_ptr, service_size);
let service_name =
read_string(&caller, mem, service_name_ptr, service_name_size);
let method_name = read_string(&caller, mem, method_name_ptr, method_name_size);
let initial_metadata =
read_bytes(&caller, mem, initial_metadata_ptr, initial_metadata_size)
.unwrap();
let request = read_bytes(&caller, mem, request_ptr, request_size).unwrap();

println!(
"[vm->host] proxy_grpc_call(service={service}, service_name={service_name}, method_name={method_name}, initial_metadata={initial_metadata:?}, request={request:?}, timeout={timeout_milliseconds}");

let token_id = match EXPECT.lock().unwrap().staged.get_expect_grpc_call(
service,
service_name,
method_name,
initial_metadata,
request,
timeout_milliseconds,
) {
Some(expect_token) => expect_token,
None => 0,
};

unsafe {
let return_token_add = mem.data_mut(&mut caller).get_unchecked_mut(
token_ptr as u32 as usize..token_ptr as u32 as usize + 4,
);
return_token_add.copy_from_slice(&token_id.to_le_bytes());
}

// Default Function:
// Expectation:
println!(
Expand All @@ -1412,9 +1459,10 @@ fn get_hostfunc(
);
println!(
"[vm<-host] proxy_grpc_call() -> (..) return: {:?}",
Status::InternalFailure
Status::Ok
);
return Status::InternalFailure as i32;
assert_ne!(get_status(), ExpectStatus::Failed);
return Status::Ok as i32;
},
))
}
Expand Down Expand Up @@ -1641,6 +1689,18 @@ fn get_hostfunc(
}
}

fn read_string(caller: &Caller<()>, mem: Memory, ptr: i32, size: i32) -> String {
read_bytes(caller, mem, ptr, size)
.map(String::from_utf8_lossy)
.unwrap()
.to_string()
}

fn read_bytes<'a>(caller: &'a Caller<()>, mem: Memory, ptr: i32, size: i32) -> Option<&'a [u8]> {
mem.data(caller)
.get(ptr as usize..ptr as usize + size as usize)
}

pub mod serial_utils {

type Bytes = Vec<u8>;
Expand Down
25 changes: 24 additions & 1 deletion src/tester.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,26 @@ impl Tester {
ExpectHttpCall::expecting(self, upstream, headers, body, trailers, timeout)
}

pub fn expect_grpc_call(
&mut self,
service: Option<&'static str>,
service_name: Option<&'static str>,
method_name: Option<&'static str>,
initial_metadata: Option<&'static [u8]>,
request: Option<&'static [u8]>,
timeout: Option<u64>,
) -> ExpectGrpcCall {
ExpectGrpcCall::expecting(
self,
service,
service_name,
method_name,
initial_metadata,
request,
timeout,
)
}

/* ------------------------------------- High-level Expectation Setting ------------------------------------- */

pub fn set_quiet(&mut self, quiet: bool) {
Expand Down Expand Up @@ -323,7 +343,10 @@ impl Tester {
}

fn assert_expect_stage(&mut self) {
self.expect.lock().unwrap().assert_stage();
let err = self.expect.lock().unwrap().assert_stage();
if let Some(msg) = err {
panic!("{}", msg)
}
}

pub fn get_settings_handle(&self) -> MutexGuard<HostHandle> {
Expand Down
3 changes: 3 additions & 0 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ pub enum BufferType {
DownstreamData = 2,
UpstreamData = 3,
HttpCallResponseBody = 4,
GrpcReceiveBuffer = 5,
VmConfiguration = 6,
PluginConfiguration = 7,
}

#[repr(u32)]
Expand Down

0 comments on commit 1797355

Please sign in to comment.