From 97c24329b9f1b5b3cdb7cb1f5d1d76ee316472e0 Mon Sep 17 00:00:00 2001 From: Vara Prasad Bandaru Date: Thu, 1 Feb 2024 21:45:45 +0530 Subject: [PATCH] Fix arbitrary_cpi lint and add documentation to it --- lints/arbitrary_cpi/README.md | 18 ++++ lints/arbitrary_cpi/src/lib.rs | 172 +++++++++++++++++++++++---------- 2 files changed, 141 insertions(+), 49 deletions(-) diff --git a/lints/arbitrary_cpi/README.md b/lints/arbitrary_cpi/README.md index e5d13a5..98cf801 100644 --- a/lints/arbitrary_cpi/README.md +++ b/lints/arbitrary_cpi/README.md @@ -39,3 +39,21 @@ let ix = Instruction { }; invoke(&ix, accounts.clone()); ``` + +**How the lint is implemented:** + +- For every function containing calls to `solana_program::program::invoke` +- find the definition of `Instruction` argument passed to `invoke`; first argument +- If the `Instruction` argument is result of a function call + - If the function is whitelisted, do not report; only functions defined in `spl_token::instruction` are whitelisted. + - Else report the call to `invoke` as vulnerable +- Else if the `Instruction` is initialized in the function itself + - find the assign statement assigning to the `program_id` field, assigning to field at `0`th index + - find all the aliases of `program_id`. Use the rhs of the assignment as initial alias and look for + all assignments assigning to the locals recursively. + - Check if `program_id` is compared using any of aliases. + - look for calls to `core::cmp::PartialEq{ne, eq}` where one of arg is moved from an alias. + - If one of the arg accesses `program_id`, check if the basic block containing the comparison + dominates the basic block containing call to `invoke` ensuring the `program_id` is checked in all execution + paths. + - If basic block does not dominate or there is no such comparison report the call to `invoke` diff --git a/lints/arbitrary_cpi/src/lib.rs b/lints/arbitrary_cpi/src/lib.rs index e5157a3..cc4859c 100644 --- a/lints/arbitrary_cpi/src/lib.rs +++ b/lints/arbitrary_cpi/src/lib.rs @@ -60,6 +60,24 @@ dylint_linting::declare_late_lint! { /// }; /// invoke(&ix, accounts.clone()); /// ``` + /// + /// **How the lint is implemented:** + /// + /// - For every function containing calls to `solana_program::program::invoke` + /// - find the definition of `Instruction` argument passed to `invoke`; first argument + /// - If the `Instruction` argument is result of a function call + /// - If the function is whitelisted, do not report; only functions defined in `spl_token::instruction` are whitelisted. + /// - Else report the call to `invoke` as vulnerable + /// - Else if the `Instruction` is initialized in the function itself + /// - find the assign statement assigning to the `program_id` field, assigning to field at `0`th index + /// - find all the aliases of `program_id`. Use the rhs of the assignment as initial alias and look for + /// all assignments assigning to the locals recursively. + /// - Check if `program_id` is compared using any of aliases. + /// - look for calls to `core::cmp::PartialEq{ne, eq}` where one of arg is moved from an alias. + /// - If one of the arg accesses `program_id`, check if the basic block containing the comparison + /// dominates the basic block containing call to `invoke` ensuring the `program_id` is checked in all execution + /// paths. + /// - If basic block does not dominate or there is no such comparison report the call to `invoke` pub ARBITRARY_CPI, Warn, "Finds unconstrained inter-contract calls" @@ -69,10 +87,13 @@ impl<'tcx> LateLintPass<'tcx> for ArbitraryCpi { fn check_body(&mut self, cx: &LateContext<'tcx>, body: &'tcx Body<'tcx>) { let hir_map = cx.tcx.hir(); let body_did = hir_map.body_owner_def_id(body.id()).to_def_id(); + // The body is the body of function whose mir is available + // fn_like includes fn, const fn, async fn but not closures. if !cx.tcx.def_kind(body_did).is_fn_like() || !cx.tcx.is_mir_available(body_did) { return; } let body_mir = cx.tcx.optimized_mir(body_did); + // list of block id and the terminator of the basic blocks in the CFG let terminators = body_mir .basic_blocks .iter_enumerated() @@ -80,6 +101,8 @@ impl<'tcx> LateLintPass<'tcx> for ArbitraryCpi { for (_idx, (block_id, terminator)) in terminators.enumerate() { if_chain! { if let t = terminator.as_ref().unwrap(); + // The terminator is a call to a function; the function is defined and is not function pointer or function object + // i.e The function is not copied or moved. Generic functions, trait methods are not Constant. if let TerminatorKind::Call { func: func_operand, args, @@ -90,13 +113,18 @@ impl<'tcx> LateLintPass<'tcx> for ArbitraryCpi { then { // Static call let callee_did = *def_id; + // Calls `invoke` if match_def_path(cx, callee_did, &paths::SOLANA_PROGRAM_INVOKE) { + // Get the `Instruction`, instruction is the first argument of `invoke` function. let inst_arg = &args[0]; if let Operand::Move(p) = inst_arg { + // Check if the Instruction is returned from a whitelisted function (is_whitelist = true) + // if `Instruction` is defined in this function, find all the locals/places the program_id is defined let (is_whitelist, programid_places) = Self::find_program_id_for_instru(cx, body_mir, block_id, p); let likely_programid_locals: Vec = programid_places.iter().map(|pl| pl.local).collect(); + // if not whitelisted, check if the program_id is compared using one of the locals. if !is_whitelist && !Self::is_programid_checked( cx, @@ -145,36 +173,45 @@ impl ArbitraryCpi { ) -> (bool, Vec>) { let preds = body.basic_blocks.predecessors(); let mut cur_block = block; - let mut found_program_id = false; - let mut likely_program_id_aliases = Vec::::new(); loop { // Walk the bb in reverse, starting with the terminator if let Some(t) = &body.basic_blocks[cur_block].terminator { + // the terminator is a call; the return value of the call is assigned to `inst_arg.local` match &t.kind { TerminatorKind::Call { func: mir::Operand::Constant(box func), destination: dest, args, .. - } if dest.local_or_deref_local() == Some(inst_arg.local) - && !found_program_id => + } if dest.local_or_deref_local() == Some(inst_arg.local) => { if_chain! { + // function definition if let TyKind::FnDef(def_id, _callee_substs) = func.literal.ty().kind(); + // non-zero args are passed in the call if !args.is_empty(); if let Operand::Copy(arg0_pl) | Operand::Move(arg0_pl) = &args[0]; then { // in order to trace back to the call which creates the // instruction, we have to trace through a call to Try::branch + // Expressions such as `call()?` with try operator will have Try::branch + // call and the first argument is the return value of actual call `call()`. + // If the call is Try::branch, look for the first arg which will have the return + // value of `call()`. if match_def_path(cx, *def_id, &paths::CORE_BRANCH) { inst_arg = arg0_pl; } else { + // If this is not Try::branch, check if its a call to a function in `spl_token::instruction` module let path = cx.get_def_path(*def_id); let token_path = paths::SPL_TOKEN_INSTRUCTION.map(Symbol::intern); + // if the instruction is constructed by a function in `spl_token::instruction`, assume program_id is checked if path.iter().take(2).eq(&token_path) { - return (true, likely_program_id_aliases); - } + return (true, Vec::new()); + } else { + // if the called function is not the whitelisted one, then we assume it to be vulnerable + return (false, Vec::new()); + } } } } @@ -184,60 +221,94 @@ impl ArbitraryCpi { } // check every statement for stmt in body.basic_blocks[cur_block].statements.iter().rev() { + // println!("5. {:?}, {:?}", stmt, stmt.kind); match &stmt.kind { - StatementKind::Assign(box (assign_place, rvalue)) - if assign_place.local_or_deref_local() - == inst_arg.local_or_deref_local() => - { - match rvalue { - Rvalue::Use( - Operand::Copy(rvalue_place) | Operand::Move(rvalue_place), - ) => { - // println!("Found assignment {:?}", stmt); - inst_arg = rvalue_place; - if found_program_id { - likely_program_id_aliases.push(*rvalue_place); - } - } - Rvalue::Ref(_, _, pl) => { - // println!("Found assignment (ref) {:?}", pl); - inst_arg = pl; - if found_program_id { - likely_program_id_aliases.push(*inst_arg); - } - } - _ => {} - } - } + // if the statement assigns to `inst_arg`, update `inst_arg` to the rhs StatementKind::Assign(box (assign_place, rvalue)) if assign_place.local == inst_arg.local => { + // println!("2. {:?}, {:?}, {:?}", assign_place, inst_arg, rvalue); + // Check if assign_place is assignment to a field. if not then this is not the initialization of the struct + // have to check further if_chain! { - // If we've found the Instruction that was passed to invoke, then - // field at index 0 will be the program_id if assign_place.projection.len() == 1; if let proj = assign_place.projection[0]; + // the projection could be deref etc if let ProjectionElem::Field(f, ty) = proj; - if f.index() == 0; - if let Some(adtdef) = ty.ty_adt_def(); - if match_def_path( - cx, - adtdef.did(), - &["solana_program", "pubkey", "Pubkey"], - ); then { - // We found the field + // stmt is an assignment to a field. + // there will be 3 statements(for 3 fields), ensure this statement is assignment + // to the first field `program_id` + // Also, do not update inst_arg, as this is just field assignment. + if_chain!{ + // program_id is the first field; index = 0 + if f.index() == 0; + if let Some(adtdef) = ty.ty_adt_def(); + if match_def_path( + cx, + adtdef.did(), + &["solana_program", "pubkey", "Pubkey"], + ); + then { + if let Rvalue::Use(Operand::Copy(pl) | Operand::Move(pl)) + | Rvalue::Ref(_, _, pl) = rvalue + { + // found the program_id. now look for all assignments/aliases to program_id. + let likely_program_id_aliases = Self::find_program_id_aliases(body, cur_block, pl); + return (false, likely_program_id_aliases); + } + } + } + } else { + // inst_arg is defined using this statement. rvalue could store the actual value. if let Rvalue::Use(Operand::Copy(pl) | Operand::Move(pl)) - | Rvalue::Ref(_, _, pl) = rvalue - { + | Rvalue::Ref(_, _, pl) = rvalue { inst_arg = pl; - likely_program_id_aliases.push(*pl); - // println!("Found program ID: {:?}", rvalue); - found_program_id = true; - break; + // println!("4. {:?}", inst_arg); } } - }; + } + } + _ => {} + } + } + match preds.get(cur_block) { + // traverse the CFG. Only predecessor is being considered. + Some(cur_preds) if !cur_preds.is_empty() => cur_block = cur_preds[0], + _ => { + break; + } + } + } + // we did not find the statement assigning to the program_id of `Instruction`. report as vulnerable + (false, Vec::new()) + } + + fn find_program_id_aliases<'tcx>( + body: &'tcx mir::Body<'tcx>, + block: BasicBlock, + mut id_arg: &Place<'tcx>, + ) -> Vec> { + let preds = body.basic_blocks.predecessors(); + let mut cur_block = block; + let mut likely_program_id_aliases = Vec::::new(); + likely_program_id_aliases.push(*id_arg); + loop { + // check every stmt + for stmt in body.basic_blocks[cur_block].statements.iter().rev() { + match &stmt.kind { + // if the statement assigns to `inst_arg`, update `inst_arg` to the rhs + StatementKind::Assign(box (assign_place, rvalue)) + if assign_place.local_or_deref_local() + == id_arg.local_or_deref_local() => + { + if let Rvalue::Use(Operand::Copy(pl) | Operand::Move(pl)) + | Rvalue::Ref(_, _, pl) = rvalue + { + id_arg = pl; + // println!("x. {:?}", pl); + likely_program_id_aliases.push(*pl); + } } _ => {} } @@ -249,8 +320,7 @@ impl ArbitraryCpi { } } } - // println!("Likely aliases: {:?}", likely_program_id_aliases); - (false, likely_program_id_aliases) + likely_program_id_aliases } // helper function @@ -269,6 +339,8 @@ impl ArbitraryCpi { return true; } } + // look for chain of assign statements whose value is eventually assigned to the `search_place` and + // see if any of the intermediate local is in the search_list. loop { for stmt in body.basic_blocks[cur_block].statements.iter().rev() { match &stmt.kind { @@ -318,6 +390,7 @@ impl ArbitraryCpi { loop { // check every statement if_chain! { + // is terminator a call `core::cmp::PartialEq{ne, eq}`? if let Some(t) = &body.basic_blocks[cur_block].terminator; if let TerminatorKind::Call { func: func_operand, @@ -328,6 +401,7 @@ impl ArbitraryCpi { if let TyKind::FnDef(def_id, _callee_substs) = func.literal.ty().kind(); if match_def_path(cx, *def_id, &["core", "cmp", "PartialEq", "ne"]) || match_def_path(cx, *def_id, &["core", "cmp", "PartialEq", "eq"]); + // check if any of the args accesses program_id if let Operand::Copy(arg0_pl) | Operand::Move(arg0_pl) = args[0]; if let Operand::Copy(arg1_pl) | Operand::Move(arg1_pl) = args[1]; then {