diff --git a/brilirs/src/basic_block.rs b/brilirs/src/basic_block.rs index 4845d8cd0..dbe47267c 100644 --- a/brilirs/src/basic_block.rs +++ b/brilirs/src/basic_block.rs @@ -1,40 +1,43 @@ use std::collections::HashMap; -// A program composed of basic blocks. -// (BB index of main program, list of BBs, mapping of label -> BB index) -pub type BBProgram = (Option, Vec, HashMap); +// A program represented as basic blocks. +pub struct BBProgram { + pub blocks: Vec, -#[derive(Debug)] -pub struct BasicBlock { - pub instrs: Vec, - pub exit: Vec, + // Map from label name to index in `blocks` of the block named by that label. + pub label_index: HashMap, + + // Map from function name to the index into `blocks` of the starting block of + // the function. + pub func_index: HashMap, } -impl BasicBlock { - fn new() -> BasicBlock { - BasicBlock { - instrs: Vec::new(), - exit: Vec::new(), +impl BBProgram { + pub fn new(prog: bril_rs::Program) -> BBProgram { + let mut bbprog = BBProgram { + blocks: vec![], + label_index: HashMap::new(), + func_index: HashMap::new(), + }; + for func in prog.functions { + bbprog.add_func_bbs(func); } + bbprog } -} - -pub fn find_basic_blocks(prog: bril_rs::Program) -> BBProgram { - let mut main_fn = None; - let mut blocks = Vec::new(); - let mut labels = HashMap::new(); - let mut bb_helper = |func: bril_rs::Function| -> usize { + fn add_func_bbs(&mut self, func: bril_rs::Function) -> usize { + self.func_index.insert(func.name.clone(), self.blocks.len()); let mut curr_block = BasicBlock::new(); - let root_block = blocks.len(); + let root_block = self.blocks.len(); let mut curr_label = None; + for instr in func.instrs.into_iter() { match instr { bril_rs::Code::Label { ref label } => { if !curr_block.instrs.is_empty() { - blocks.push(curr_block); + self.blocks.push(curr_block); if let Some(old_label) = curr_label { - labels.insert(old_label, blocks.len() - 1); + self.label_index.insert(old_label, self.blocks.len() - 1); } curr_block = BasicBlock::new(); } @@ -46,9 +49,9 @@ pub fn find_basic_blocks(prog: bril_rs::Program) -> BBProgram { || op == bril_rs::EffectOps::Return => { curr_block.instrs.push(instr); - blocks.push(curr_block); + self.blocks.push(curr_block); if let Some(l) = curr_label { - labels.insert(l, blocks.len() - 1); + self.label_index.insert(l, self.blocks.len() - 1); curr_label = None; } curr_block = BasicBlock::new(); @@ -60,22 +63,36 @@ pub fn find_basic_blocks(prog: bril_rs::Program) -> BBProgram { } if !curr_block.instrs.is_empty() { - blocks.push(curr_block); + // If we are here, the function ends without an explicit ret. To make + // processing easier, push a Return op onto the last block. + curr_block.instrs.push(RET.clone()); + self.blocks.push(curr_block); if let Some(l) = curr_label { - labels.insert(l, blocks.len() - 1); + self.label_index.insert(l, self.blocks.len() - 1); } } - root_block - }; + } +} + +#[derive(Debug)] +pub struct BasicBlock { + pub instrs: Vec, + pub exit: Vec, +} - for func in prog.functions.into_iter() { - let func_name = func.name.clone(); - let func_block = bb_helper(func); - if func_name == "main" { - main_fn = Some(func_block); +impl BasicBlock { + fn new() -> BasicBlock { + BasicBlock { + instrs: Vec::new(), + exit: Vec::new(), } } - - (main_fn, blocks, labels) } + +const RET: bril_rs::Code = bril_rs::Code::Instruction(bril_rs::Instruction::Effect { + op: bril_rs::EffectOps::Return, + args: vec![], + funcs: vec![], + labels: vec![], +}); diff --git a/brilirs/src/cfg.rs b/brilirs/src/cfg.rs index 9abc36591..205d4be32 100644 --- a/brilirs/src/cfg.rs +++ b/brilirs/src/cfg.rs @@ -1,12 +1,8 @@ -use crate::basic_block::BasicBlock; +use crate::basic_block::BBProgram; -use std::collections::HashMap; - -type CFG = Vec; - -pub fn build_cfg(mut blocks: Vec, label_to_block_idx: &HashMap) -> CFG { - let last_idx = blocks.len() - 1; - for (i, block) in blocks.iter_mut().enumerate() { +pub fn build_cfg(prog: &mut BBProgram) { + let last_idx = prog.blocks.len() - 1; + for (i, block) in prog.blocks.iter_mut().enumerate() { // If we're before the last block if i < last_idx { // Get the last instruction @@ -14,21 +10,14 @@ pub fn build_cfg(mut blocks: Vec, label_to_block_idx: &HashMap { - for l in labels { - block.exit.push(label_to_block_idx[l]); - } + if let bril_rs::EffectOps::Jump | bril_rs::EffectOps::Branch = op { + for l in labels { + block.exit.push(prog.label_index[l]); } - bril_rs::EffectOps::Return => {} - // TODO(yati): Do all effect ops end a BB? - _ => {} } } else { block.exit.push(i + 1); } } } - - blocks } diff --git a/brilirs/src/interp.rs b/brilirs/src/interp.rs index d2db88e9f..14d3a9366 100644 --- a/brilirs/src/interp.rs +++ b/brilirs/src/interp.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use std::convert::TryFrom; use std::fmt; +use std::iter::FromIterator; use crate::basic_block::{BBProgram, BasicBlock}; @@ -9,6 +10,8 @@ pub enum InterpError { BadJsonInt, BadJsonBool, NoMainFunction, + FuncNotFound(String), + NoRetValForfunc(String), BadNumArgs(usize, usize), // (expected, actual) BadNumLabels(usize, usize), // (expected, actual) VarNotFound(String), @@ -151,12 +154,15 @@ impl TryFrom<&Value> for f64 { } #[allow(clippy::float_cmp)] -fn execute_value_op( +fn execute_value_op( + prog: &BBProgram, op: &bril_rs::ValueOps, dest: &str, op_type: &bril_rs::Type, args: &[String], + funcs: &[String], value_store: &mut HashMap, + out: &mut W, ) -> Result<(), InterpError> { use bril_rs::ValueOps::*; match *op { @@ -271,7 +277,19 @@ fn execute_value_op( let args = get_args::(value_store, 2, args)?; value_store.insert(String::from(dest), Value::Bool(args[0] >= args[1])); } - Call => unreachable!(), // TODO(yati): Why is Call a ValueOp as well? + Call => { + assert!(funcs.len() == 1); + let vals = get_values(value_store, args.len(), args)?; + let vars: HashMap = + HashMap::from_iter(args.iter().cloned().zip(vals.into_iter().cloned())); + if let Some(val) = execute_func(&prog, &funcs[0], vars, out)? { + check_asmt_type(&val.get_type(), op_type)?; + value_store.insert(String::from(dest), val); + } else { + // This is a value-op call, so the target func must return a result. + return Err(InterpError::NoRetValForfunc(funcs[0].clone())); + } + } Phi | Alloc | Load | PtrAdd => unimplemented!(), } Ok(()) @@ -293,17 +311,29 @@ fn check_num_labels(expected: usize, labels: &[String]) -> Result<(), InterpErro } } -// Returns whether the program should continue running (i.e., if a Return was -// *not* executed). +// Result of executing an effect operation. +enum EffectResult { + // Return from the current function without any value. + Return, + + // Return a given value from the current function. + ReturnWithVal(Value), + + // Continue execution of the current function. + Continue, +} + fn execute_effect_op( + prog: &BBProgram, op: &bril_rs::EffectOps, args: &[String], labels: &[String], + funcs: &[String], curr_block: &BasicBlock, value_store: &HashMap, - mut out: T, + out: &mut T, next_block_idx: &mut Option, -) -> Result { +) -> Result { use bril_rs::EffectOps::*; match op { Jump => { @@ -320,36 +350,53 @@ fn execute_effect_op( Return => { out.flush().map_err(|e| InterpError::IoError(Box::new(e)))?; // NOTE: This only works so long as `main` is the only function - return Ok(false); + if args.is_empty() { + return Ok(EffectResult::Return); + } + let retval = value_store + .get(&args[0]) + .ok_or(InterpError::VarNotFound(args[0].clone()))?; + return Ok(EffectResult::ReturnWithVal(retval.clone())); } Print => { + let vals = get_values(value_store, args.len(), args)?; writeln!( out, "{}", - args + vals .iter() - .map(|a| format!("{}", value_store[a])) + .map(|v| format!("{}", v)) .collect::>() .join(", ") ) .map_err(|e| InterpError::IoError(Box::new(e)))?; } Nop => {} - Call => unreachable!(), + Call => { + assert!(funcs.len() == 1); + let vals = get_values(value_store, args.len(), args)?; + let vars: HashMap = + HashMap::from_iter(args.iter().cloned().zip(vals.into_iter().cloned())); + execute_func(&prog, &funcs[0], vars, out)?; + } Store | Free | Speculate | Commit | Guard => unimplemented!(), } - Ok(true) + Ok(EffectResult::Continue) } -pub fn execute(prog: BBProgram, mut out: T) -> Result<(), InterpError> { - let (main_fn, blocks, _labels) = prog; - let mut curr_block_idx: usize = main_fn.ok_or(InterpError::NoMainFunction)?; - - // Map from variable name to value. - let mut value_store: HashMap = HashMap::new(); +fn execute_func( + prog: &BBProgram, + func: &str, + mut vars: HashMap, + out: &mut T, +) -> Result, InterpError> { + let mut curr_block_idx = *prog + .func_index + .get(func) + .ok_or(InterpError::FuncNotFound(String::from(func)))?; loop { - let curr_block = &blocks[curr_block_idx]; + let curr_block = &prog.blocks[curr_block_idx]; let curr_instrs = &curr_block.instrs; let mut next_block_idx = if curr_block.exit.len() == 1 { Some(curr_block.exit[0]) @@ -367,44 +414,58 @@ pub fn execute(prog: BBProgram, mut out: T) -> Result<(), Int value, } => { check_asmt_type(const_type, &value.get_type())?; - value_store.insert(dest.clone(), Value::from(value)); + vars.insert(dest.clone(), Value::from(value)); } bril_rs::Instruction::Value { op, dest, op_type, args, + funcs, .. } => { - execute_value_op(op, dest, op_type, args, &mut value_store)?; + execute_value_op(&prog, op, dest, op_type, args, funcs, &mut vars, out)?; } bril_rs::Instruction::Effect { - op, args, labels, .. + op, + args, + labels, + funcs, + .. } => { - let should_continue = execute_effect_op( + match execute_effect_op( + prog, op, args, labels, + funcs, &curr_block, - &value_store, - &mut out, + &vars, + out, &mut next_block_idx, - )?; - - // TODO(yati): Correct only when main is the only function. - if !should_continue { - return Ok(()); - } + )? { + EffectResult::Continue => {} + EffectResult::Return => { + return Ok(None); + } + EffectResult::ReturnWithVal(val) => { + return Ok(Some(val)); + } + }; } } } } - if let Some(idx) = next_block_idx { curr_block_idx = idx; } else { out.flush().map_err(|e| InterpError::IoError(Box::new(e)))?; - return Ok(()); + return Ok(None); } } } + +pub fn execute(prog: BBProgram, out: &mut T) -> Result<(), InterpError> { + // Ignore return value of @main. + execute_func(&prog, "main", HashMap::new(), out).map(|_| ()) +} diff --git a/brilirs/src/lib.rs b/brilirs/src/lib.rs index a2f5aefbd..e1516a773 100644 --- a/brilirs/src/lib.rs +++ b/brilirs/src/lib.rs @@ -11,11 +11,11 @@ extern crate serde; extern crate serde_derive; extern crate serde_json; -pub fn run_input(input: Box, out: T) { +pub fn run_input(input: Box, mut out: T) { let prog = bril_rs::load_program_from_read(input); - let (main_idx, blocks, label_index) = basic_block::find_basic_blocks(prog); - let blocks = cfg::build_cfg(blocks, &label_index); - if let Err(e) = interp::execute((main_idx, blocks, label_index), out) { + let mut bbprog = basic_block::BBProgram::new(prog); + cfg::build_cfg(&mut bbprog); + if let Err(e) = interp::execute(bbprog, &mut out) { error!("{:?}", e); } } diff --git a/brilirs/testdata/call-with-args.json b/brilirs/testdata/call-with-args.json new file mode 100644 index 000000000..fc4f96e57 --- /dev/null +++ b/brilirs/testdata/call-with-args.json @@ -0,0 +1,88 @@ +{ + "functions": [ + { + "instrs": [ + { + "dest": "x", + "op": "const", + "type": "int", + "value": 2 + }, + { + "dest": "y", + "op": "const", + "type": "int", + "value": 2 + }, + { + "args": [ + "x", + "y" + ], + "dest": "z", + "funcs": [ + "add2" + ], + "op": "call", + "type": "int" + }, + { + "args": [ + "y" + ], + "op": "print" + }, + { + "args": [ + "z" + ], + "op": "print" + } + ], + "name": "main" + }, + { + "args": [ + { + "name": "x", + "type": "int" + }, + { + "name": "y", + "type": "int" + } + ], + "instrs": [ + { + "args": [ + "x", + "y" + ], + "dest": "w", + "op": "add", + "type": "int" + }, + { + "dest": "y", + "op": "const", + "type": "int", + "value": 5 + }, + { + "args": [ + "w" + ], + "op": "print" + }, + { + "args": [ + "w" + ], + "op": "ret" + } + ], + "name": "add2", + "type": "int" + } + ] +} diff --git a/brilirs/testdata/call.json b/brilirs/testdata/call.json new file mode 100644 index 000000000..35970122e --- /dev/null +++ b/brilirs/testdata/call.json @@ -0,0 +1,47 @@ +{ + "functions": [ + { + "instrs": [ + { + "dest": "v", + "op": "const", + "type": "int", + "value": 2 + }, + { + "funcs": [ + "print4" + ], + "op": "call" + }, + { + "args": [ + "v" + ], + "op": "print" + } + ], + "name": "main" + }, + { + "instrs": [ + { + "dest": "v", + "op": "const", + "type": "int", + "value": 4 + }, + { + "args": [ + "v" + ], + "op": "print" + }, + { + "op": "ret" + } + ], + "name": "print4" + } + ] +} diff --git a/brilirs/tests/interp_test.rs b/brilirs/tests/interp_test.rs index 93299bfbb..df9b5fa3b 100644 --- a/brilirs/tests/interp_test.rs +++ b/brilirs/tests/interp_test.rs @@ -56,4 +56,6 @@ interp_tests! { or: "./testdata/or.json", id: "./testdata/id.json", br: "./testdata/br.json", + call: "./testdata/call.json", + call_with_args: "./testdata/call-with-args.json", }