Skip to content

Commit

Permalink
Support external variables in canonic blocks.
Browse files Browse the repository at this point in the history
  • Loading branch information
ilyalesokhin-starkware committed Mar 9, 2025
1 parent a41643f commit b57538e
Show file tree
Hide file tree
Showing 60 changed files with 38,356 additions and 41,451 deletions.
25 changes: 14 additions & 11 deletions crates/cairo-lang-lowering/src/inline/test_data/inline
Original file line number Diff line number Diff line change
Expand Up @@ -875,32 +875,35 @@ End:

blk6:
Statements:
(v11: (core::integer::u128,)) <- struct_construct(v0)
(v12: core::panics::PanicResult::<(core::integer::u128,)>) <- PanicResult::Ok(v11)
End:
Goto(blk4, {v12 -> v3})
Goto(blk9, {v0 -> v11})

blk7:
Statements:
End:
Match(match core::integer::u128_overflowing_add(v1, v0) {
Result::Ok(v13) => blk8,
Result::Err(v14) => blk9,
Result::Ok(v12) => blk8,
Result::Err(v13) => blk10,
})

blk8:
Statements:
(v15: (core::integer::u128,)) <- struct_construct(v13)
(v16: core::panics::PanicResult::<(core::integer::u128,)>) <- PanicResult::Ok(v15)
End:
Goto(blk4, {v16 -> v3})
Goto(blk9, {v12 -> v11})

blk9:
Statements:
(v17: (core::panics::Panic, core::array::Array::<core::felt252>)) <- core::panic_with_const_felt252::<39878429859757942499084499860145094553463>()
(v18: core::panics::PanicResult::<(core::integer::u128,)>) <- PanicResult::Err(v17)
(v14: (core::integer::u128,)) <- struct_construct(v11)
(v15: core::panics::PanicResult::<(core::integer::u128,)>) <- PanicResult::Ok(v14)
End:
Goto(blk4, {v18 -> v3})
Goto(blk4, {v15 -> v3})

blk10:
Statements:
(v16: (core::panics::Panic, core::array::Array::<core::felt252>)) <- core::panic_with_const_felt252::<39878429859757942499084499860145094553463>()
(v17: core::panics::PanicResult::<(core::integer::u128,)>) <- PanicResult::Err(v16)
End:
Goto(blk4, {v17 -> v3})

//! > lowering_diagnostics

Expand Down
193 changes: 152 additions & 41 deletions crates/cairo-lang-lowering/src/optimizations/dedup_blocks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ use std::collections::HashMap;

use cairo_lang_semantic::items::constant::ConstValue;
use cairo_lang_semantic::{ConcreteVariant, TypeId};
use cairo_lang_utils::ordered_hash_map::{self, OrderedHashMap};
use cairo_lang_utils::unordered_hash_map::UnorderedHashMap;
use id_arena::Arena;
use itertools::Itertools;
use itertools::{Itertools, zip_eq};

use crate::ids::FunctionId;
use crate::utils::{Rebuilder, RebuilderEx};
use crate::{
BlockId, FlatBlock, FlatBlockEnd, FlatLowered, Statement, StatementCall, StatementConst,
StatementDesnap, StatementEnumConstruct, StatementSnapshot, StatementStructConstruct,
Expand All @@ -31,12 +33,7 @@ struct CanonicBlock {

/// A canonic representation of a variable in a canonic block.
#[derive(Hash, PartialEq, Eq)]
enum CanonicVar {
/// A variable that was defined outside of the block.
Global(usize),
/// A variable that was defined inside the block.
Local(usize),
}
struct CanonicVar(usize);

/// A canonic representation of a statement in a canonic block.
#[derive(Hash, PartialEq, Eq)]
Expand Down Expand Up @@ -79,24 +76,37 @@ struct CanonicBlockBuilder<'a> {
variable: &'a Arena<Variable>,
vars: UnorderedHashMap<VariableId, usize>,
types: Vec<TypeId>,
inputs: Vec<VarUsage>,
}

impl CanonicBlockBuilder<'_> {
fn new(variable: &Arena<Variable>) -> CanonicBlockBuilder<'_> {
CanonicBlockBuilder { variable, vars: Default::default(), types: vec![] }
CanonicBlockBuilder {
variable,
vars: Default::default(),
types: vec![],
inputs: Default::default(),
}
}

/// Converts an input var to a CanonicVar.
fn handle_input(&mut self, var_usage: &VarUsage) -> CanonicVar {
match self.vars.get(&var_usage.var_id) {
Some(local_idx) => CanonicVar::Local(*local_idx),
None => CanonicVar::Global(var_usage.var_id.index()),
}
let v = var_usage.var_id;

CanonicVar(match self.vars.entry(v) {
std::collections::hash_map::Entry::Occupied(e) => *e.get(),
std::collections::hash_map::Entry::Vacant(e) => {
self.types.push(self.variable[v].ty);
let new_id = *e.insert(self.types.len() - 1);
self.inputs.push(*var_usage);
new_id
}
})
}

/// Converts an output var to a CanonicVar.
fn handle_output(&mut self, v: &VariableId) -> CanonicVar {
CanonicVar::Local(match self.vars.entry(*v) {
CanonicVar(match self.vars.entry(*v) {
std::collections::hash_map::Entry::Occupied(e) => *e.get(),
std::collections::hash_map::Entry::Vacant(e) => {
self.types.push(self.variable[*v].ty);
Expand Down Expand Up @@ -158,8 +168,12 @@ impl CanonicBlockBuilder<'_> {

impl CanonicBlock {
/// Tries to create a canonic block from a flat block.
/// Return the canonic representation of the block and the external inputs used in the block.
/// Blocks that do not end in return do not have a canonic representation.
fn try_from_block(variable: &Arena<Variable>, block: &FlatBlock) -> Option<CanonicBlock> {
fn try_from_block(
variable: &Arena<Variable>,
block: &FlatBlock,
) -> Option<(CanonicBlock, Vec<VarUsage>)> {
let FlatBlockEnd::Return(returned_vars, _) = &block.end else {
return None;
};
Expand All @@ -179,65 +193,162 @@ impl CanonicBlock {

let returns = returned_vars.iter().map(|input| builder.handle_input(input)).collect();

Some(CanonicBlock { stmts, types: builder.types, returns })
Some((CanonicBlock { stmts, types: builder.types, returns }, builder.inputs))
}
}

struct VarRenamer<'a> {
variables: &'a mut Arena<Variable>,
pub renamed_vars: UnorderedHashMap<VariableId, VariableId>,
}

impl Rebuilder for VarRenamer<'_> {
fn map_var_id(&mut self, var: VariableId) -> VariableId {
match self.renamed_vars.entry(var) {
std::collections::hash_map::Entry::Occupied(e) => *e.get(),
std::collections::hash_map::Entry::Vacant(e) => {
*e.insert(self.variables.alloc(self.variables[var].clone()))
}
}
}

fn map_block_id(&mut self, block: BlockId) -> BlockId {
block
}
}

#[derive(Default)]
struct DedupContext {
/// Maps a CanonicBlock to a reference block that matches it.
canonic_blocks: OrderedHashMap<CanonicBlock, BlockId>,

/// Maps a block to the inputs that are needed for it to be shared,
block_id_to_inputs: HashMap<BlockId, Vec<VarUsage>>,
}

fn rebuild_block_and_inputs(
variables: &mut Arena<Variable>,
block: &FlatBlock,
inputs: &[VarUsage],
) -> (FlatBlock, Vec<VarUsage>) {
let new_inputs: Vec<VarUsage> = inputs
.iter()
.map(|var_usage| VarUsage {
var_id: variables.alloc(variables[var_usage.var_id].clone()),
location: var_usage.location,
})
.collect();

let mut renamer = VarRenamer {
variables,
renamed_vars: UnorderedHashMap::from_iter(zip_eq(
inputs.iter().map(|var_usage| var_usage.var_id),
new_inputs.iter().map(|var_usage| var_usage.var_id),
)),
};

(renamer.rebuild_block(block), new_inputs)
}

/// Deduplicates blocks by redirecting goto's and match arms to one of the duplicates.
/// The duplicate blocks will be remove later by `reorganize_blocks`.
pub fn dedup_blocks(lowered: &mut FlatLowered) {
if lowered.blocks.has_root().is_err() {
return;
}

let mut blocks: HashMap<CanonicBlock, BlockId> = Default::default();
let mut duplicates: HashMap<BlockId, BlockId> = Default::default();
let mut ctx = DedupContext::default();
// Maps duplicated blocks to the new shared block and the inputs that need to be remapped for
// the block.
let mut duplicates: UnorderedHashMap<BlockId, (BlockId, Vec<VarUsage>)> = Default::default();

let mut new_blocks = vec![];
let mut next_block_id = BlockId(lowered.blocks.len());

for (block_id, block) in lowered.blocks.iter() {
let Some(canonical_block) = CanonicBlock::try_from_block(&lowered.variables, block) else {
let Some((canonical_block, inputs)) =
CanonicBlock::try_from_block(&lowered.variables, block)
else {
continue;
};

let opt_mark_dup = match blocks.entry(canonical_block) {
std::collections::hash_map::Entry::Occupied(e) => {
duplicates.insert(block_id, *e.get());
Some(*e.get())
match ctx.canonic_blocks.entry(canonical_block) {
ordered_hash_map::Entry::Occupied(e) => {
let block_and_inputs = duplicates
.entry(*e.get())
.or_insert_with(|| {
let (block, new_inputs) =
rebuild_block_and_inputs(&mut lowered.variables, block, &inputs);
new_blocks.push(block);
let new_block_id = next_block_id;
next_block_id = next_block_id.next_block_id();

(new_block_id, new_inputs)
})
.clone();

duplicates.insert(block_id, block_and_inputs);
}
std::collections::hash_map::Entry::Vacant(e) => {
ordered_hash_map::Entry::Vacant(e) => {
e.insert(block_id);
None
}
};

if let Some(dup_block) = opt_mark_dup {
duplicates.entry(dup_block).or_insert_with(|| dup_block);
}
ctx.block_id_to_inputs.insert(block_id, inputs);
}

let mut new_blocks = vec![];
let mut next_block_id = BlockId(lowered.blocks.len());
let mut new_goto_block = |block_id, inputs: &Vec<VarUsage>, target_inputs: &Vec<VarUsage>| {
new_blocks.push(FlatBlock {
statements: vec![],
end: FlatBlockEnd::Goto(
block_id,
VarRemapping {
remapping: OrderedHashMap::from_iter(zip_eq(
target_inputs.iter().map(|var_usage| var_usage.var_id),
inputs.iter().cloned(),
)),
},
),
});

let new_block_id = next_block_id;
next_block_id = next_block_id.next_block_id();
new_block_id
};

// Note that the loop below can't be merged with the loop above as a block might be marked as
// dup after we already visiting an arm that goes to it.
// Note that the loop below cant be merged with the loop above as a block might be marked as dup
// after we already visiting an arm that goes to it.
for block in lowered.blocks.iter_mut() {
match &mut block.end {
FlatBlockEnd::Goto(target_block, remappings) if remappings.is_empty() => {
if let Some(block_id) = duplicates.get(target_block) {
*target_block = *block_id;
FlatBlockEnd::Goto(target_block, remappings) => {
let Some((block_id, target_inputs)) = duplicates.get(target_block) else {
continue;
};

let inputs = ctx.block_id_to_inputs.get(target_block).unwrap();
let mut inputs_remapping = VarRemapping {
remapping: OrderedHashMap::from_iter(zip_eq(
target_inputs.iter().map(|var_usage| var_usage.var_id),
inputs.iter().cloned(),
)),
};
for (_, src) in inputs_remapping.iter_mut() {
if let Some(src_before_remapping) = remappings.get(&src.var_id) {
*src = *src_before_remapping;
}
}

*target_block = *block_id;
*remappings = inputs_remapping;
}
FlatBlockEnd::Match { info } => {
for arm in info.arms_mut() {
let Some(block_id) = duplicates.get(&arm.block_id) else {
let Some((block_id, target_inputs)) = duplicates.get(&arm.block_id) else {
continue;
};
new_blocks.push(FlatBlock {
statements: vec![],
end: FlatBlockEnd::Goto(*block_id, VarRemapping::default()),
});

arm.block_id = next_block_id;
next_block_id = next_block_id.next_block_id();
let inputs = ctx.block_id_to_inputs.get(&arm.block_id).unwrap();
arm.block_id = new_goto_block(*block_id, inputs, target_inputs);
}
}
_ => {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,6 @@ impl<'a> Analyzer<'a> for MatchOptimizerContext {

if candidate.future_merge || candidate.additional_remappings.is_some() {
// TODO(ilya): Support multiple remappings with future merges.

// Revoke the candidate.
info.candidate = None;
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ fn test_match_optimizer(
OptimizationPhase::ApplyInlining.apply(db, function_id, &mut before).unwrap();
OptimizationPhase::ReorganizeBlocks.apply(db, function_id, &mut before).unwrap();
OptimizationPhase::ReorderStatements.apply(db, function_id, &mut before).unwrap();
OptimizationPhase::ReorganizeBlocks.apply(db, function_id, &mut before).unwrap();

let mut after = before.clone();
OptimizationPhase::OptimizeMatches.apply(db, function_id, &mut after).unwrap();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ fn get_var_split(lowered: &mut FlatLowered) -> SplitMapping {
vars: stmt.inputs.iter().map(|input| input.var_id).collect_vec(),
},
)
.is_none()
.is_none(),
"{} appers twice",

Check warning on line 75 in crates/cairo-lang-lowering/src/optimizations/split_structs.rs

View workflow job for this annotation

GitHub Actions / typos

"appers" should be "appears".
stmt.output.index()
);
}
}
Expand Down
13 changes: 11 additions & 2 deletions crates/cairo-lang-lowering/src/optimizations/strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,18 @@ impl OptimizationPhase {
OptimizationPhase::OptimizeMatches => optimize_matches(lowered),
OptimizationPhase::OptimizeRemappings => optimize_remappings(lowered),
OptimizationPhase::ReorderStatements => reorder_statements(db, lowered),
OptimizationPhase::ReorganizeBlocks => reorganize_blocks(lowered),
OptimizationPhase::ReorganizeBlocks => {
// let lowered_formatter = LoweredFormatter::new(db.upcast(), &lowered.variables);
// println!("{:?}", lowered.debug(&lowered_formatter));
reorganize_blocks(lowered)
}
OptimizationPhase::ReturnOptimization => return_optimization(db, lowered),
OptimizationPhase::SplitStructs => split_structs(lowered),
OptimizationPhase::SplitStructs => {
// let lowered_formatter = LoweredFormatter::new(db.upcast(),
// &lowered_function.variables); println!("{:?}",
// lowered_function.debug(&lowered_formatter));
split_structs(lowered)
}
OptimizationPhase::LowerImplicits => lower_implicits(db, function, lowered),
OptimizationPhase::GasRedeposit => gas_redeposit(db, function, lowered),
OptimizationPhase::Validate => validate(lowered)
Expand Down
Loading

0 comments on commit b57538e

Please sign in to comment.