Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Assert gRPC calls #17

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/http_auth_random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ fn main() -> Result<()> {
http_auth_random
.call_proxy_on_http_call_response(http_context, 0, 0, buffer_data.len() as i32, 0)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(buffer_data))
.returning(Some(buffer_data.as_bytes()))
.expect_send_local_response(
Some(403),
Some("Access forbidden.\n"),
Expand Down
47 changes: 46 additions & 1 deletion src/expect_interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ impl<'a> ExpectGetBufferBytes<'a> {
}
}

pub fn returning(&mut self, buffer_data: Option<&str>) -> &mut Tester {
pub fn returning(&mut self, buffer_data: Option<&[u8]>) -> &mut Tester {
self.tester
.get_expect_handle()
.staged
Expand Down Expand Up @@ -150,3 +150,48 @@ impl<'a> ExpectHttpCall<'a> {
self.tester
}
}

pub struct ExpectGrpcCall<'a> {
tester: &'a mut Tester,
service: Option<&'a str>,
service_name: Option<&'a str>,
method_name: Option<&'a str>,
initial_metadata: Option<&'a [u8]>,
request: Option<&'a [u8]>,
timeout: Option<u64>,
}

impl<'a> ExpectGrpcCall<'a> {
pub fn expecting(
tester: &'a mut Tester,
service: Option<&'a str>,
service_name: Option<&'a str>,
method_name: Option<&'a str>,
initial_metadata: Option<&'a [u8]>,
request: Option<&'a [u8]>,
timeout: Option<u64>,
) -> Self {
Self {
tester,
service,
service_name,
method_name,
initial_metadata,
request,
timeout,
}
}

pub fn returning(&mut self, token_id: Option<u32>) -> &mut Tester {
self.tester.get_expect_handle().staged.set_expect_grpc_call(
self.service,
self.service_name,
self.method_name,
self.initial_metadata,
self.request,
self.timeout,
token_id,
);
self.tester
}
}
98 changes: 88 additions & 10 deletions src/expectations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,20 @@ impl ExpectHandle {
self.staged = Expect::new(allow_unexpected);
}

pub fn assert_stage(&self) {
pub fn assert_stage(&self) -> Option<String> {
if self.staged.expect_count > 0 {
panic!(
return Some(format!(
"Error: failed to consume all expectations - total remaining: {}",
self.staged.expect_count
);
));
} else if self.staged.expect_count < 0 {
panic!(
return Some(format!(
"Error: expectations failed to account for all host calls by {} \n\
if this is intended, please use --allow-unexpected (-a) mode",
-1 * self.staged.expect_count
);
));
}
None
}

pub fn print_staged(&self) {
Expand Down Expand Up @@ -86,6 +87,15 @@ pub struct Expect {
Option<Duration>,
Option<u32>,
)>,
grpc_call: Vec<(
Option<String>,
Option<String>,
Option<String>,
Option<Bytes>,
Option<Bytes>,
Option<Duration>,
Option<u32>,
)>,
}

impl Expect {
Expand All @@ -106,6 +116,7 @@ impl Expect {
add_header_map_value: vec![],
send_local_response: vec![],
http_call: vec![],
grpc_call: vec![],
}
}

Expand Down Expand Up @@ -190,13 +201,11 @@ impl Expect {
pub fn set_expect_get_buffer_bytes(
&mut self,
buffer_type: Option<i32>,
buffer_data: Option<&str>,
buffer_data: Option<&[u8]>,
) {
self.expect_count += 1;
self.get_buffer_bytes.push((
buffer_type,
buffer_data.map(|data| data.as_bytes().to_vec()),
));
self.get_buffer_bytes
.push((buffer_type, buffer_data.map(|data| data.to_vec())));
}

pub fn get_expect_get_buffer_bytes(&mut self, buffer_type: i32) -> Option<Bytes> {
Expand Down Expand Up @@ -571,4 +580,73 @@ impl Expect {
}
}
}

pub fn set_expect_grpc_call(
&mut self,
service: Option<&str>,
service_name: Option<&str>,
method_name: Option<&str>,
initial_metadata: Option<&[u8]>,
request: Option<&[u8]>,
timeout: Option<u64>,
token_id: Option<u32>,
) {
self.expect_count += 1;
self.grpc_call.push((
service.map(ToString::to_string),
service_name.map(ToString::to_string),
method_name.map(ToString::to_string),
initial_metadata.map(|s| s.to_vec()),
request.map(|s| s.to_vec()),
timeout.map(Duration::from_millis),
token_id,
));
}

pub fn get_expect_grpc_call(
&mut self,
service: String,
service_name: String,
method: String,
initial_metadata: &[u8],
request: &[u8],
timeout: i32,
) -> Option<u32> {
match self.grpc_call.len() {
0 => {
if !self.allow_unexpected {
self.expect_count -= 1;
}
set_status(ExpectStatus::Unexpected);
None
}
_ => {
self.expect_count -= 1;
let (
expected_service,
expected_service_name,
expected_method,
expected_initial_metadata,
expected_request,
expected_duration,
result,
) = self.grpc_call.remove(0);

let expected = expected_service.map(|e| e == service).unwrap_or(true)
&& expected_service_name
.map(|e| e == service_name)
.unwrap_or(true)
&& expected_method.map(|e| e == method).unwrap_or(true)
&& expected_initial_metadata
.map(|e| e == initial_metadata)
.unwrap_or(true)
&& expected_request.map(|e| e == request).unwrap_or(true)
&& expected_duration
.map(|e| e.as_millis() as i32 == timeout)
.unwrap_or(true);
set_expect_status(expected);
return result;
}
}
}
}
4 changes: 4 additions & 0 deletions src/host_settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,5 +294,9 @@ pub fn default_buffer_bytes() -> HashMap<i32, Bytes> {
BufferType::HttpCallResponseBody as i32,
"default_call_response_body".as_bytes().to_vec(),
);
default_bytes.insert(
BufferType::GrpcReceiveBuffer as i32,
"default_grpc_receive_buffer".as_bytes().to_vec(),
);
default_bytes
}
90 changes: 75 additions & 15 deletions src/hostcalls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1390,20 +1390,67 @@ fn get_hostfunc(
"proxy_grpc_call" => {
Some(Func::wrap(
store,
|_caller: Caller<'_, ()>,
_service_ptr: i32,
_service_size: i32,
_service_name_ptr: i32,
_service_name_size: i32,
_method_name_ptr: i32,
_method_name_size: i32,
_initial_metadata_ptr: i32,
_initial_metadata_size: i32,
_request_ptr: i32,
_request_size: i32,
_timeout_milliseconds: i32,
_token_ptr: i32|
|mut caller: Caller<'_, ()>,
service_ptr: i32,
service_size: i32,
service_name_ptr: i32,
service_name_size: i32,
method_name_ptr: i32,
method_name_size: i32,
initial_metadata_ptr: i32,
initial_metadata_size: i32,
request_ptr: i32,
request_size: i32,
timeout_milliseconds: i32,
token_ptr: i32|
-> i32 {
print!("[vm->host] proxy_grpc_call({initial_metadata_ptr}, {initial_metadata_size})");

// Default Function: receives and displays http call from proxy-wasm module
// Expectation: asserts equal the receieved http call with the expected one
let mem = match caller.get_export("memory") {
Some(Extern::Memory(mem)) => mem,
_ => {
println!("Error: proxy_http_call cannot get export \"memory\"");
println!(
"[vm<-host] proxy_http_call(...) -> (return_token) return: {:?}",
Status::InternalFailure
);
return Status::InternalFailure as i32;
}
};

let service = read_string(&caller, mem, service_ptr, service_size);
let service_name =
read_string(&caller, mem, service_name_ptr, service_name_size);
let method_name = read_string(&caller, mem, method_name_ptr, method_name_size);
let initial_metadata =
read_bytes(&caller, mem, initial_metadata_ptr, initial_metadata_size)
.unwrap();
let request = read_bytes(&caller, mem, request_ptr, request_size).unwrap();

println!(
"[vm->host] proxy_grpc_call(service={service}, service_name={service_name}, method_name={method_name}, initial_metadata={initial_metadata:?}, request={request:?}, timeout={timeout_milliseconds}");

let token_id = match EXPECT.lock().unwrap().staged.get_expect_grpc_call(
service,
service_name,
method_name,
initial_metadata,
request,
timeout_milliseconds,
) {
Some(expect_token) => expect_token,
None => 0,
};

unsafe {
let return_token_add = mem.data_mut(&mut caller).get_unchecked_mut(
token_ptr as u32 as usize..token_ptr as u32 as usize + 4,
);
return_token_add.copy_from_slice(&token_id.to_le_bytes());
}

// Default Function:
// Expectation:
println!(
Expand All @@ -1412,9 +1459,10 @@ fn get_hostfunc(
);
println!(
"[vm<-host] proxy_grpc_call() -> (..) return: {:?}",
Status::InternalFailure
Status::Ok
);
return Status::InternalFailure as i32;
assert_ne!(get_status(), ExpectStatus::Failed);
return Status::Ok as i32;
},
))
}
Expand Down Expand Up @@ -1641,6 +1689,18 @@ fn get_hostfunc(
}
}

fn read_string(caller: &Caller<()>, mem: Memory, ptr: i32, size: i32) -> String {
read_bytes(caller, mem, ptr, size)
.map(String::from_utf8_lossy)
.unwrap()
.to_string()
}

fn read_bytes<'a>(caller: &'a Caller<()>, mem: Memory, ptr: i32, size: i32) -> Option<&'a [u8]> {
mem.data(caller)
.get(ptr as usize..ptr as usize + size as usize)
}

pub mod serial_utils {

type Bytes = Vec<u8>;
Expand Down
25 changes: 24 additions & 1 deletion src/tester.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,26 @@ impl Tester {
ExpectHttpCall::expecting(self, upstream, headers, body, trailers, timeout)
}

pub fn expect_grpc_call(
&mut self,
service: Option<&'static str>,
service_name: Option<&'static str>,
method_name: Option<&'static str>,
initial_metadata: Option<&'static [u8]>,
request: Option<&'static [u8]>,
timeout: Option<u64>,
) -> ExpectGrpcCall {
ExpectGrpcCall::expecting(
self,
service,
service_name,
method_name,
initial_metadata,
request,
timeout,
)
}

/* ------------------------------------- High-level Expectation Setting ------------------------------------- */

pub fn set_quiet(&mut self, quiet: bool) {
Expand Down Expand Up @@ -323,7 +343,10 @@ impl Tester {
}

fn assert_expect_stage(&mut self) {
self.expect.lock().unwrap().assert_stage();
let err = self.expect.lock().unwrap().assert_stage();
if let Some(msg) = err {
panic!("{}", msg)
}
}

pub fn get_settings_handle(&self) -> MutexGuard<HostHandle> {
Expand Down
3 changes: 3 additions & 0 deletions src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ pub enum BufferType {
DownstreamData = 2,
UpstreamData = 3,
HttpCallResponseBody = 4,
GrpcReceiveBuffer = 5,
VmConfiguration = 6,
PluginConfiguration = 7,
}

#[repr(u32)]
Expand Down
Loading