Skip to content

Commit

Permalink
generalised wasm filter#1
Browse files Browse the repository at this point in the history
  • Loading branch information
pranav-bhatt committed Mar 19, 2021
1 parent fa36eee commit 7b4e81f
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 40 deletions.
Binary file modified rate-limit-filter/pkg/rate_limit_filter_bg.wasm
Binary file not shown.
54 changes: 47 additions & 7 deletions rate-limit-filter/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ use serde::Deserialize;
use std::collections::HashMap;
use std::time::SystemTime;

// We need to make sure a HTTP root context is created and initialized when the filter is initialized.
// The _start() function initialises this root context
#[no_mangle]
pub fn _start() {
proxy_wasm::set_log_level(LogLevel::Info);
Expand All @@ -20,6 +22,7 @@ pub fn _start() {
});
}

// Defining standard CORS headers
static CORS_HEADERS: [(&str, &str); 5] = [
("Powered-By", "proxy-wasm"),
("Access-Control-Allow-Origin", "*"),
Expand All @@ -28,35 +31,44 @@ static CORS_HEADERS: [(&str, &str); 5] = [
("Access-Control-Max-Age", "3600"),
];

// This struct is what the JWT token sent by the user will deserialize to
#[derive(Deserialize, Debug)]
struct Data {
username: String,
plan: String,
}

// This is the instance of a call made. It sorta derives from the root context
#[derive(Debug)]
struct UpstreamCall {
config_json: HashMap<String, Rule>,
}

impl UpstreamCall {
// Takes in the HashMap created in the root context mapping path name to rule type
fn new(json_hm: &HashMap<String, Rule>) -> Self {
Self {
//TODO this clone is super heavy, find a way to get rid of it
config_json: json_hm.clone(),
}
}

// Check if the path specified in the incoming request's path header has rule type None.
// Returns Option containing path name that was sent
fn rule_is_none(&self, path: String) -> Option<String> {
let rule_vec = self.config_json.get(&path).unwrap();
if std::mem::discriminant(rule_vec) == std::mem::discriminant(&Rule::None) {
let rule = self.config_json.get(&path).unwrap();
// checking based only on type
if std::mem::discriminant(rule) == std::mem::discriminant(&Rule::None) {
return Some(path);
}
return None;
}

// Check if the path specified in the incoming request's path header has rule type RateLimiter.
// Returns Option containing vector of RateLimiterJson objects (list of plan names with limits)
fn rule_is_rate_limiter(&self, path: String) -> Option<Vec<RateLimiterJson>> {
// only meant to check if rule type is rate limiter
let rule = self.config_json.get(&path).unwrap();
// checking based only on type
if std::mem::discriminant(rule) == std::mem::discriminant(&Rule::RateLimiter(Vec::new())) {
if let Rule::RateLimiter(plans_vec) = rule {
return Some(plans_vec.to_vec());
Expand All @@ -70,42 +82,64 @@ impl Context for UpstreamCall {}

impl HttpContext for UpstreamCall {
fn on_http_request_headers(&mut self, _num_headers: usize) -> Action {
// Options
if let Some(method) = self.get_http_request_header(":method") {
if method == "OPTIONS" {
self.send_http_response(204, CORS_HEADERS.to_vec(), None);
return Action::Pause;
}
}

// Action for rule type: None
if let Some(_) = self.rule_is_none(self.get_http_request_header(":path").unwrap()) {
return Action::Continue;
}

// Action for rule type: RateLimiter
if let Some(plans_vec) =
self.rule_is_rate_limiter(self.get_http_request_header(":path").unwrap())
{
if let Some(header) = self.get_http_request_header("Authorization") {
// Decoding JWT token
if let Ok(token) = base64::decode(header) {
//Deserializing token
let obj: Data = serde_json::from_slice(&token).unwrap();

proxy_wasm::hostcalls::log(LogLevel::Debug, format!("Obj {:?}", obj).as_str())
.ok();

// Since the rate limit works on a rate per minute based quota, we find current time
let curr = self.get_current_time();
let tm = curr.duration_since(SystemTime::UNIX_EPOCH).unwrap();
let mn = (tm.as_secs() / 60) % 60;
let _sc = tm.as_secs() % 60;

// Initialise RateLimiter object
let mut rl = RateLimiter::get(&obj.username, &obj.plan);

// Initialising headers to send back
let mut headers = CORS_HEADERS.to_vec();
let count: String;

// Extracting limits based on plan stated in JWT token from the corresponding RateLimiterJson
let limit = plans_vec
.into_iter()
.filter(|x| x.identifier == obj.plan)
.map(|x| x.limit)
.collect::<Vec<u32>>()[0];
.collect::<Vec<u32>>();

// Checking if the appropriate plan exists
if limit.len() != 1 {
self.send_http_response(
429,
headers,
Some(b"Invalid plan name or duplicate plan names defined.\n"),
);
return Action::Pause;
}

if rl.update(mn as i32) > limit {
//Update request count in RateLimiter object, and check if it exceeds limits
if rl.update(mn as i32) > limit[0] {
count = rl.count.to_string();
headers
.append(&mut vec![("x-rate-limit", &count), ("x-app-user", &rl.key)]);
Expand All @@ -115,6 +149,7 @@ impl HttpContext for UpstreamCall {
}
proxy_wasm::hostcalls::log(LogLevel::Debug, format!("Obj {:?}", &rl).as_str())
.ok();
// set the new count in headers, and proxy_wasm storage
count = rl.count.to_string();
rl.set();
headers.append(&mut vec![("x-rate-limit", &count), ("x-app-user", &rl.key)]);
Expand All @@ -141,14 +176,18 @@ struct UpstreamCallRoot {
impl Context for UpstreamCallRoot {}
impl<'a> RootContext for UpstreamCallRoot {
//TODO: Revisit this once the read only feature is released in Istio 1.10
// Get Base64 encoded JSON from envoy config file when WASM VM starts
fn on_vm_start(&mut self, _: usize) -> bool {
if let Some(config_bytes) = self.get_configuration() {
// bytestring passed by VM -> String of base64 encoded JSON
let config_str = String::from_utf8(config_bytes).unwrap();
// String of base64 encoded JSON -> bytestring of decoded JSON
let config_b64 = base64::decode(config_str).unwrap();
// bytestring of decoded JSON -> String of decoded JSON
let json_str = String::from_utf8(config_b64).unwrap();

// Deserializing JSON String into vector of JsonPath objects
let json_vec: Vec<JsonPath> = serde_json::from_str(&json_str).unwrap();

// Creating HashMap of pattern ("path name", "rule type") and saving into UpstreamCallRoot object
for i in json_vec {
self.config_json.insert(i.name, i.rule);
}
Expand All @@ -157,6 +196,7 @@ impl<'a> RootContext for UpstreamCallRoot {
}

fn create_http_context(&self, _: u32) -> Option<Box<dyn HttpContext>> {
// creating UpstreamCall object for each new call
Some(Box::new(UpstreamCall::new(&self.config_json)))
}

Expand Down
40 changes: 7 additions & 33 deletions rate-limit-filter/src/rate_limiter/rate_limiter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,74 +5,48 @@ use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct RateLimiter {
//pub rpm: Option<u32>,
// Tracks time
pub min: i32,
// Tracks number of calls made
pub count: u32,
// stores a key(username according to example)
pub key: String,
}

impl RateLimiter {
fn new(key: &String, _plan: &String) -> Self {
// make these dynamic as well
/*
let limit = match plan.as_str() {
"Enterprise" => Some(100),
"Team" => Some(50),
"Personal" => Some(10),
_ => None,
};
*/
Self {
//rpm: limit,
min: -1,
count: 0,
key: key.clone(),
}
}
// Get key and plan from proxy_wasm shared data store (username+plan name)
pub fn get(key: &String, plan: &String) -> Self {
if let Ok(data) = proxy_wasm::hostcalls::get_shared_data(&key.clone()) {
if let Some(data) = data.0 {
let data: Option<Self> = bincode::deserialize(&data).unwrap_or(None);
if let Some(obj) = data {
/*
let limit = match plan.as_str() {
"Enterprise" => Some(100),
"Team" => Some(50),
"Personal" => Some(10),
_ => None,
};
obj.rpm = limit;
*/
return obj;
}
}
}
return Self::new(&key, &plan);
}
// Set key and plan in proxy_wasm shared data store (username+plan name)
pub fn set(&self) {
let target: Option<Self> = Some(self.clone());
let encoded: Vec<u8> = bincode::serialize(&target).unwrap();
proxy_wasm::hostcalls::set_shared_data(&self.key.clone(), Some(&encoded), None).ok();
}
// Update time (minute by minute) and increment count
pub fn update(&mut self, time: i32) -> u32 {
if self.min != time {
self.min = time;
self.count = 0;
}
self.count += 1;
proxy_wasm::hostcalls::log(
LogLevel::Debug,
format!("Obj {:?} ", self.count).as_str(), //{:?}", self.rpm
)
.ok();
proxy_wasm::hostcalls::log(LogLevel::Debug, format!("Obj {:?} ", self.count).as_str()).ok();
self.count
/*
if let Some(sm) = self.rpm {
if self.count > sm {
return false;
}
}
return true;
*/
}
}

0 comments on commit 7b4e81f

Please sign in to comment.