diff --git a/bedrock2/src/bedrock2/MetricCosts.v b/bedrock2/src/bedrock2/MetricCosts.v new file mode 100644 index 000000000..83d0e9e36 --- /dev/null +++ b/bedrock2/src/bedrock2/MetricCosts.v @@ -0,0 +1,153 @@ +Require Import BinIntDef. +Require Import Coq.Strings.String. +Require Import bedrock2.MetricLogging. +From coqutil.Tactics Require Import destr. + +Local Open Scope MetricH_scope. + +Inductive compphase: Type := +| PreSpill +| PostSpill. + +Section FlatImpExec. + + Context {varname: Type}. + Variable (phase: compphase). + Variable (isReg: varname -> bool). + + Definition cost_interact mc := + match phase with + | PreSpill => mkMetricLog 100 100 100 100 + | PostSpill => mkMetricLog 50 50 50 50 + end + mc. + + Definition cost_call mc := + match phase with + | PreSpill => mkMetricLog 200 200 200 200 + | PostSpill => mkMetricLog 100 100 100 100 + end + mc. + + (* TODO think about a non-fixed bound on the cost of function preamble and postamble *) + + Definition cost_load x a mc := + match (isReg x, isReg a) with + | (false, false) => mkMetricLog 3 1 5 0 + | (false, true) => mkMetricLog 2 1 3 0 + | ( true, false) => mkMetricLog 2 0 4 0 + | ( true, true) => mkMetricLog 1 0 2 0 + end + mc. + + Definition cost_store a v mc := + match (isReg a, isReg v) with + | (false, false) => mkMetricLog 3 1 5 0 + | (false, true) => mkMetricLog 2 1 3 0 + | ( true, false) => mkMetricLog 2 1 3 0 + | ( true, true) => mkMetricLog 1 1 1 0 + end + mc. + + Definition cost_inlinetable x i mc := + match (isReg x, isReg i) with + | (false, false) => mkMetricLog 5 1 7 1 + | (false, true) => mkMetricLog 4 1 5 1 + | ( true, false) => mkMetricLog 4 0 6 1 + | ( true, true) => mkMetricLog 3 0 4 1 + end + mc. + + Definition cost_stackalloc x mc := + match isReg x with + | false => mkMetricLog 2 1 2 0 + | true => mkMetricLog 1 0 1 0 + end + mc. + + Definition cost_lit x mc := + match isReg x with + | false => mkMetricLog 9 1 9 0 + | true => mkMetricLog 8 0 8 0 + end + mc. + + Definition cost_op x y z mc := + match (isReg x, isReg y, isReg z) with + | (false, false, false) => mkMetricLog 5 1 7 0 + | (false, false, true) | (false, true, false) => mkMetricLog 4 1 5 0 + | (false, true, true) => mkMetricLog 3 1 3 0 + | ( true, false, false) => mkMetricLog 4 0 6 0 + | ( true, false, true) | ( true, true, false) => mkMetricLog 3 0 4 0 + | ( true, true, true) => mkMetricLog 2 0 2 0 + end + mc. + + Definition cost_set x y mc := + match (isReg x, isReg y) with + | (false, false) => mkMetricLog 3 1 4 0 + | (false, true) => mkMetricLog 2 1 2 0 + | ( true, false) => mkMetricLog 2 0 3 0 + | ( true, true) => mkMetricLog 1 0 1 0 + end + mc. + + Definition cost_if x y mc := + match (isReg x, match y with | Some y' => isReg y' | None => true end) with + | (false, false) => mkMetricLog 4 0 6 1 + | (false, true) | ( true, false) => mkMetricLog 3 0 4 1 + | ( true, true) => mkMetricLog 2 0 2 1 + end + mc. + + Definition cost_loop_true x y mc := + match (isReg x, match y with | Some y' => isReg y' | None => true end) with + | (false, false) => mkMetricLog 4 0 6 1 + | (false, true) | ( true, false) => mkMetricLog 3 0 4 1 + | ( true, true) => mkMetricLog 2 0 2 1 + end + mc. + + Definition cost_loop_false x y mc := + match (isReg x, match y with | Some y' => isReg y' | None => true end) with + | (false, false) => mkMetricLog 3 0 5 1 + | (false, true) | ( true, false) => mkMetricLog 2 0 3 1 + | ( true, true) => mkMetricLog 1 0 1 1 + end + mc. + +End FlatImpExec. + +Definition isRegZ (var : Z) : bool := + Z.leb var 31. + +Definition isRegStr (var : String.string) : bool := + String.prefix "reg_" var. + +(* awkward tactic use to avoid Qed slowness *) +(* this is slow with (eq_refl t) and fast with (eq_refl t') due to black box heuristics *) +Ltac cost_unfold := + repeat ( + let H := match goal with + | H : context[cost_interact] |- _ => H + | H : context[cost_call] |- _ => H + | H : context[cost_load] |- _ => H + | H : context[cost_store] |- _ => H + | H : context[cost_inlinetable] |- _ => H + | H : context[cost_stackalloc] |- _ => H + | H : context[cost_lit] |- _ => H + | H : context[cost_op] |- _ => H + | H : context[cost_set] |- _ => H + | H : context[cost_if] |- _ => H + | H : context[cost_loop_true] |- _ => H + | H : context[cost_loop_false] |- _ => H + end in + let t := type of H in + let t' := eval cbv [cost_interact cost_call cost_load cost_store + cost_inlinetable cost_stackalloc cost_lit cost_op cost_set + cost_if cost_loop_true cost_loop_false] in t in + replace t with t' in H by (exact (eq_refl t')) + ); + cbv [cost_interact cost_call cost_load cost_store cost_inlinetable + cost_stackalloc cost_lit cost_op cost_set cost_if cost_loop_true + cost_loop_false]; + unfold EmptyMetricLog in *. + +Ltac cost_destr := + repeat match goal with + | x : compphase |- _ => destr x + | _ : context[if ?x then _ else _] |- _ => destr x; try discriminate + | |- context[if ?x then _ else _] => destr x; try discriminate + end. + +Ltac cost_solve := cost_unfold; cost_destr; try solve_MetricLog. +Ltac cost_solve_piecewise := cost_unfold; cost_destr; try solve_MetricLog_piecewise. +Ltac cost_hammer := try solve [eauto 3 with metric_arith | cost_solve]. diff --git a/bedrock2/src/bedrock2/MetricLogging.v b/bedrock2/src/bedrock2/MetricLogging.v index 451dacc70..b8083f747 100644 --- a/bedrock2/src/bedrock2/MetricLogging.v +++ b/bedrock2/src/bedrock2/MetricLogging.v @@ -29,19 +29,29 @@ Section Riscv. Definition subMetricLoads n log := withLoads (loads log - n) log. Definition subMetricJumps n log := withJumps (jumps log - n) log. - Definition metricSub(metric: MetricLog -> Z) finalM initialM : Z := - Z.sub (metric finalM) (metric initialM). + Definition metricAdd(metric: MetricLog -> Z) m1 m2 : Z := + Z.add (metric m1) (metric m2). + Definition metricSub(metric: MetricLog -> Z) m1 m2 : Z := + Z.sub (metric m1) (metric m2). Definition metricsOp op : MetricLog -> MetricLog -> MetricLog := - fun initialM finalM => + fun m1 m2 => mkMetricLog - (op instructions initialM finalM) - (op stores initialM finalM) - (op loads initialM finalM) - (op jumps initialM finalM). + (op instructions m1 m2) + (op stores m1 m2) + (op loads m1 m2) + (op jumps m1 m2). + Definition metricsAdd := metricsOp metricAdd. Definition metricsSub := metricsOp metricSub. + Definition metricsMul (n : Z) (m : MetricLog) := + mkMetricLog + (n * instructions m) + (n * stores m) + (n * loads m) + (n * jumps m). + Definition metricLeq(metric: MetricLog -> Z) m1 m2: Prop := (metric m1) <= (metric m2). @@ -51,18 +61,15 @@ Section Riscv. metricLeq loads m1 m2 /\ metricLeq jumps m1 m2. - Definition metricMax(metric: MetricLog -> Z) m1 m2: Z := - Z.max (metric m1) (metric m2). - - Definition metricsMax := metricsOp metricMax. End Riscv. -Declare Scope MetricH_scope. Bind Scope MetricH_scope with MetricLog. Delimit Scope MetricH_scope with metricsH. Infix "<=" := metricsLeq : MetricH_scope. +Infix "+" := metricsAdd : MetricH_scope. Infix "-" := metricsSub : MetricH_scope. +Infix "*" := metricsMul : MetricH_scope. #[export] Hint Unfold withInstructions @@ -78,8 +85,11 @@ Infix "-" := metricsSub : MetricH_scope. subMetricStores subMetricJumps metricsOp + metricAdd + metricsAdd metricSub metricsSub + metricsMul metricLeq metricsLeq : unf_metric_log. @@ -103,7 +113,66 @@ Ltac fold_MetricLog := Ltac simpl_MetricLog := cbn [instructions loads stores jumps] in *. +(* need this to define solve_MetricLog, but need solve_MetricLog inside of MetricArith, oops *) +Lemma add_assoc' : forall n m p, (n + (m + p) = n + m + p)%metricsH. +Proof. intros. unfold_MetricLog. f_equal; apply Z.add_assoc. Qed. + +Lemma metriclit : forall a b c d a' b' c' d' mc, + metricsAdd (mkMetricLog a b c d) (metricsAdd (mkMetricLog a' b' c' d') mc) = + metricsAdd (mkMetricLog (a+a') (b+b') (c+c') (d+d')) mc. +Proof. intros. rewrite add_assoc'. reflexivity. Qed. + +Ltac flatten_MetricLog := repeat rewrite metriclit in *. + Ltac solve_MetricLog := + flatten_MetricLog; repeat unfold_MetricLog; repeat simpl_MetricLog; blia. + +Ltac solve_MetricLog_piecewise := + flatten_MetricLog; + repeat unfold_MetricLog; + repeat simpl_MetricLog; + f_equal; blia. + +Module MetricArith. + + Open Scope MetricH_scope. + + Lemma mul_sub_distr_r : forall n m p, (n - m) * p = n * p - m * p. + Proof. intros. unfold_MetricLog. f_equal; apply Z.mul_sub_distr_r. Qed. + + Lemma add_sub_swap : forall n m p, n + m - p = n - p + m. + Proof. intros. unfold_MetricLog. f_equal; apply Z.add_sub_swap. Qed. + + Lemma le_add_le_sub_r : forall n m p, n + p <= m <-> n <= m - p. + Proof. solve_MetricLog. Qed. + + Lemma le_trans : forall n m p, n <= m -> m <= p -> n <= p. + Proof. solve_MetricLog. Qed. + + Lemma le_refl : forall m, m <= m. + Proof. solve_MetricLog. Qed. + + Lemma le_sub_mono : forall n m p, n - p <= m - p <-> n <= m. + Proof. solve_MetricLog. Qed. + + Lemma add_0_r : forall mc, (mc + EmptyMetricLog)%metricsH = mc. + Proof. destruct mc. unfold EmptyMetricLog. solve_MetricLog_piecewise. Qed. + + Lemma sub_0_r : forall mc, (mc - EmptyMetricLog)%metricsH = mc. + Proof. destruct mc. unfold EmptyMetricLog. solve_MetricLog_piecewise. Qed. + + Lemma add_comm : forall n m, (n + m = m + n)%metricsH. + Proof. intros. unfold_MetricLog. f_equal; apply Z.add_comm. Qed. + + Lemma add_assoc : forall n m p, (n + (m + p) = n + m + p)%metricsH. + Proof. intros. unfold_MetricLog. f_equal; apply Z.add_assoc. Qed. + +End MetricArith. + +Create HintDb metric_arith. +#[export] Hint Resolve MetricArith.le_trans MetricArith.le_refl MetricArith.add_0_r MetricArith.sub_0_r MetricArith.add_comm MetricArith.add_assoc : metric_arith. +#[export] Hint Resolve <- MetricArith.le_sub_mono : metric_arith. +#[export] Hint Resolve -> MetricArith.le_sub_mono : metric_arith. diff --git a/bedrock2/src/bedrock2/MetricLoops.v b/bedrock2/src/bedrock2/MetricLoops.v new file mode 100644 index 000000000..8517cd9eb --- /dev/null +++ b/bedrock2/src/bedrock2/MetricLoops.v @@ -0,0 +1,610 @@ +Require Import coqutil.Datatypes.PrimitivePair coqutil.Datatypes.HList coqutil.dlet. +Require Import Coq.Classes.Morphisms BinIntDef. +Require Import coqutil.Macros.unique coqutil.Map.Interface coqutil.Word.Interface. Import map. +Require Import coqutil.Word.Bitwidth. +Require Import coqutil.Map.Properties. +Require Import coqutil.Tactics.destr. +From bedrock2 Require Import Map.Separation Map.SeparationLogic. +From bedrock2 Require Import Syntax MetricSemantics Markers. +From bedrock2 Require Semantics. +From bedrock2 Require Import MetricWeakestPrecondition MetricWeakestPreconditionProperties. + +Require Import bedrock2.MetricLogging. +Require Import bedrock2.MetricCosts. + +Section Loops. + Context {width: Z} {BW: Bitwidth width} {word: word.word width} {mem: map.map word Byte.byte}. + Context {locals: map.map String.string word}. + Context {ext_spec: Semantics.ExtSpec}. + Context {word_ok : word.ok word} {mem_ok : map.ok mem}. + Context {locals_ok : map.ok locals}. + Context {ext_spec_ok : Semantics.ext_spec.ok ext_spec}. + + Context {fs : Semantics.env}. + Let call := fs. + + Local Notation UNK := String.EmptyString. + + Lemma wp_while: forall e c t m l mc (post: _ -> _ -> _ -> _ -> Prop), + (exists measure (lt:measure->measure->Prop) (inv:measure->Semantics.trace->mem->locals->MetricLog->Prop), + Coq.Init.Wf.well_founded lt /\ + (exists v, inv v t m l mc) /\ + (forall v t m l mc, inv v t m l mc -> + exists bv bmc, dexpr m l e mc (bv, bmc) /\ + (word.unsigned bv <> 0%Z -> cmd call c t m l bmc (fun t' m' l' mc' => + exists v', inv v' t' m' l' (cost_loop_true isRegStr UNK (Some UNK) mc') + /\ lt v' v)) /\ + (word.unsigned bv = 0%Z -> post t m l + (cost_loop_false isRegStr UNK (Some UNK) bmc)))) -> + cmd call (cmd.while e c) t m l mc post. + Proof. + intros. destruct H as (measure & lt & inv & Hwf & HInit & Hbody). + destruct HInit as (v0 & HInit). + revert t m l mc HInit. pattern v0. revert v0. + eapply (well_founded_ind Hwf). intros. + specialize Hbody with (1 := HInit) (mc := mc). destruct Hbody as (bv & bmc & Hb & Ht & Hf). + eapply expr_sound in Hb. destruct Hb as (bv' & bmc' & Hb & Heq). inversion Heq. subst bv' bmc'. + destr.destr (Z.eqb (word.unsigned bv) 0). + - specialize Hf with (1 := E). eapply exec.while_false; try eassumption. + - specialize Ht with (1 := E). eapply sound_cmd in Ht. + eapply exec.while_true; eauto. + cbv beta. intros * (v' & HInv & HLt). eapply sound_cmd. eauto. + Qed. + + Lemma tailrec_localsmap_1ghost + {e c t} {m: mem} {l} {mc} {post : Semantics.trace -> mem -> locals -> MetricLog -> Prop} + {measure: Type} {Ghost: Type} + (P Q: measure -> Ghost -> Semantics.trace -> mem -> locals -> MetricLog -> Prop) + (lt: measure -> measure -> Prop) + (Hwf: well_founded lt) + (v0: measure) (g0: Ghost) + (Hpre: P v0 g0 t m l mc) + (Hbody: forall v g t m l mc, + P v g t m l mc -> + exists brv brmc, expr m l e mc (eq (brv, brmc)) /\ + (word.unsigned brv <> 0%Z -> cmd call c t m l brmc + (fun t' m' l' mc' => exists v' g', + P v' g' t' m' l' (cost_loop_true isRegStr UNK (Some UNK) mc') /\ + lt v' v /\ + (forall t'' m'' l'' mc'', Q v' g' t'' m'' l'' mc'' -> Q v g t'' m'' l'' mc''))) /\ + (word.unsigned brv = 0%Z -> Q v g t m l + (cost_loop_false isRegStr UNK (Some UNK) brmc))) + (Hpost: forall t m l mc, Q v0 g0 t m l mc -> post t m l mc) + : cmd call (cmd.while e c) t m l mc post. + Proof. + eapply wp_while. + eexists measure, lt, (fun v t m l mc => + exists g, P v g t m l mc /\ forall t'' m'' l'' mc'', Q v g t'' m'' l'' mc'' -> Q v0 g0 t'' m'' l'' mc''). + split; [assumption|]. + split; [solve[eauto]|]. + intros vi ti mi li mci (gi & HPi & HQimpl). + specialize (Hbody vi gi ti mi li mci HPi). + destruct Hbody as (brv & brmc & ? & Hbody). exists brv, brmc; split; [assumption|]. + destruct Hbody as (Htrue & Hfalse). split; intros Hbr; + [pose proof(Htrue Hbr)as Hpc|pose proof(Hfalse Hbr)as Hpc]; clear Hbr Htrue Hfalse. + { eapply Proper_cmd; [ |eapply Hpc]. + intros tj mj lj mcj (vj& gj & HPj & Hlt & Qji); eauto 9. } + { eauto. } + Qed. + + Lemma tailrec_localsmap_1ghost_parameterized_finalpost + {e c rest t} {m: mem} {l mc} + {measure: Type} {Ghost: Type} + (P Q: measure -> Ghost -> Semantics.trace -> mem -> locals -> MetricLog -> Prop) + (lt: measure -> measure -> Prop) + (Hwf: well_founded lt) + (v0: measure) (g0: Ghost) + (Hpre: P v0 g0 t m l mc) + (Hbody: forall v g t m l mc, + P v g t m l mc -> + exists brv brmc, expr m l e mc (eq (brv, brmc)) /\ + (word.unsigned brv <> 0%Z -> cmd call c t m l brmc + (fun t' m' l' mc' => exists v' g', + P v' g' t' m' l' (cost_loop_true isRegStr UNK (Some UNK) mc') /\ + lt v' v /\ + (forall t'' m'' l'' mc'', Q v' g' t'' m'' l'' mc'' -> Q v g t'' m'' l'' mc''))) /\ + (word.unsigned brv = 0%Z -> cmd call rest t m l + (cost_loop_false isRegStr UNK (Some UNK) brmc) (Q v g))) + : cmd call (cmd.seq (cmd.while e c) rest) t m l mc (Q v0 g0). + Proof. + cbn. eapply tailrec_localsmap_1ghost with + (Q := fun v g t m l mc => cmd call rest t m l mc (Q v g)). + 1: eassumption. + 1: exact Hpre. + 2: intros *; exact id. + intros vi gi ti mi li mci HPi. + specialize (Hbody vi gi ti mi li mci HPi). + destruct Hbody as (brv & brmc & ? & Hbody). exists brv, brmc; split; [assumption|]. + destruct Hbody as (Htrue & Hfalse). split; intros Hbr; + [pose proof(Htrue Hbr)as Hpc|pose proof(Hfalse Hbr)as Hpc]; clear Hbr Htrue Hfalse. + { eapply Proper_cmd; [ |eapply Hpc]. + intros tj mj lj mcj (vj& gj & HPj & Hlt & Qji). do 2 eexists. + split. 1: eassumption. split. 1: assumption. + intros. + eapply Proper_cmd. + 2: eassumption. + intros tk mk lk mck HH. eapply Qji. assumption. } + eapply Hpc. + Qed. + + (* marking logical connectives with the source file they were used in for limiting unfolding *) + Local Notation "A /\ B" := (Markers.split (A /\ B)). + Local Notation "A /\ B" := (Markers.split (A /\ B)) : type_scope. + + (* shallow reflection for resolving map accesses during symbolic execution *) + (* each lemma below is duplicated for various levels of use of this trick *) + Definition reconstruct (variables:list String.string) (values:tuple word (length variables)) : locals := + map.putmany_of_tuple (tuple.of_list variables) values map.empty. + Fixpoint gather (variables : list String.string) (l : locals) : option (locals * tuple word (length variables)) := + match variables with + | nil => Some (l, tt) + | cons x xs' => + match map.get l x with + | None => None + | Some v => + match gather xs' (map.remove l x) with + | None => None + | Some (l, vs') => Some (l, (pair.mk v vs')) + end + end + end. + + Lemma putmany_gather ks vs m me (H : gather ks m = Some (me, vs)) : + map.putmany_of_tuple (tuple.of_list ks) vs me = m. + Proof. + revert H; revert me; revert m; revert vs; induction ks; cbn [gather map.putmany_of_list]; intros. + { inversion H; subst. exact eq_refl. } + repeat match type of H with context[match ?x with _ => _ end] => destruct x eqn:? end; + repeat (match goal with H : _ |- _ => eapply IHks in H end); inversion H; subst; clear H. + cbn [map.putmany_of_tuple tuple.of_list length]. + match goal with H : _ |- _ => rewrite H; clear H end. + assert (map.get m a = Some r -> put (remove m a) a r = m). { + intro A. + apply map_ext. + intro k. + erewrite map.get_put_dec. + destr (String.eqb a k); try congruence. + rewrite map.get_remove_diff; congruence. + } + auto. + Qed. + + Definition enforce (variables : list String.string) (values:tuple word (length variables)) (l:locals) : Prop := + match gather variables l with + | None => False + | Some (remaining, r) => values = r /\ remaining = map.empty + end. + Lemma reconstruct_enforce variables ll lm (H : enforce variables ll lm) : lm = reconstruct variables ll. + progress cbv [enforce] in H. + repeat match type of H with context[match ?x with _ => _ end] => destruct x eqn:? end; + destruct H; subst. + symmetry. eapply putmany_gather. assumption. + Qed. + + Lemma hlist_forall_foralls: forall (argts : polymorphic_list.list Type) (P : hlist argts -> Prop), (forall x : hlist argts, P x) -> hlist.foralls P. + Proof. induction argts; cbn; auto. Qed. + + Import pair. + + Lemma while_localsmap + {e c t l mc} {m : mem} + {measure : Type} (invariant:_->_->_->_->_->Prop) + {lt} (Hwf : well_founded lt) (v0 : measure) + {post : _->_->_->_-> Prop} + (Hpre : invariant v0 t m l mc) + (Hbody : forall v t m l mc, + invariant v t m l mc -> + exists brv brmc, expr m l e mc (eq (Datatypes.pair brv brmc)) /\ + (word.unsigned brv <> 0 -> + cmd fs c t m l brmc (fun t m l mc => exists v', + invariant v' t m l (cost_loop_true isRegStr UNK (Some UNK) mc) + /\ lt v' v)) /\ + (word.unsigned brv = 0 -> post t m l + (cost_loop_false isRegStr UNK (Some UNK) brmc))) + : cmd fs (cmd.while e c) t m l mc post. + Proof. + eapply wp_while. + eexists measure, lt, invariant. + split. 1: exact Hwf. + split. 1: eauto. + exact Hbody. + Qed. + + Lemma while + {e c t l mc} {m : mem} + (variables : list String.string) + {localstuple : tuple word (length variables)} + {measure : Type} (invariant:_->_->_->_-> ufunc word (length variables) Prop) + {lt} (Hwf : well_founded lt) (v0 : measure) + {post : _->_->_-> _ -> Prop} + (Pl : enforce variables localstuple l) + (Hpre : tuple.apply (invariant v0 t m mc) localstuple) + (Hbody : forall v t m mc, tuple.foralls (fun localstuple => + tuple.apply (invariant v t m mc) localstuple -> + let l := reconstruct variables localstuple in + exists brv brmc, expr m l e mc (eq (Datatypes.pair brv brmc)) /\ + (word.unsigned brv <> 0 -> + cmd call c t m l brmc (fun t m l mc => + Markers.unique (Markers.left (tuple.existss (fun localstuple => + enforce variables localstuple l /\ + Markers.right (Markers.unique (exists v', + tuple.apply (invariant v' t m (cost_loop_true isRegStr UNK (Some UNK) mc)) localstuple /\ lt v' v))))))) /\ + (word.unsigned brv = 0 -> post t m l (cost_loop_false isRegStr UNK (Some UNK) brmc)))) + : cmd call (cmd.while e c) t m l mc post. + Proof. + eapply (while_localsmap (fun v t m l mc => + exists localstuple, enforce variables localstuple l /\ + tuple.apply (invariant v t m mc) localstuple)); + unfold Markers.split; eauto. + intros vi ti mi li mci (?&X&Y). + specialize (Hbody vi ti mi mci). + eapply hlist.foralls_forall in Hbody. + specialize (Hbody Y). + rewrite <-(reconstruct_enforce _ _ _ X) in Hbody. + destruct Hbody as (brv & brmc & Cond & Again & Done). + exists brv. exists brmc. split; [exact Cond|]. + + split; [|exact Done]. + intro NE. specialize (Again NE). + eapply Proper_cmd; [ |eapply Again]. + cbv [Morphisms.pointwise_relation Basics.impl Markers.right Markers.unique Markers.left] in *. + intros t' m' l' mc' Ex. + eapply hlist.existss_exists in Ex. cbv beta in Ex. destruct Ex as (ls & E & v' & Inv' & LT). + eauto. + Qed. + + Lemma tailrec + {e c t localsmap mc} {m : mem} + (ghosttypes : polymorphic_list.list Type) + (variables : list String.string) + {l0 : tuple word (length variables)} + {Pl : enforce variables l0 localsmap} + {post : _->_->_->_-> Prop} + {measure : Type} (spec:_->HList.arrows ghosttypes (_->_->ufunc word (length variables) (MetricLog -> Prop*(_->_->ufunc word (length variables) (MetricLog -> Prop))))) lt + (Hwf : well_founded lt) + (v0 : measure) + : hlist.foralls (fun (g0 : hlist ghosttypes) => forall + (Hpre : (tuple.apply (hlist.apply (spec v0) g0 t m) l0 mc).(1)) + + (Hbody : forall v, hlist.foralls (fun g => forall t m mc, tuple.foralls (fun l => + @dlet _ (fun _ => Prop) (reconstruct variables l) (fun localsmap : locals => + match tuple.apply (hlist.apply (spec v) g t m) l mc with S_ => + S_.(1) -> + Markers.unique (Markers.left (exists brv brmc, expr m localsmap e mc (eq (Datatypes.pair brv brmc)) /\ Markers.right ( + (word.unsigned brv <> 0%Z -> cmd call c t m localsmap brmc + (fun t' m' localsmap' mc' => + Markers.unique (Markers.left (hlist.existss (fun l' => enforce variables l' localsmap' /\ Markers.right ( + Markers.unique (Markers.left (hlist.existss (fun g' => exists v', + match tuple.apply (hlist.apply (spec v') g' t' m') l' (cost_loop_true isRegStr UNK (Some UNK) mc') with S' => + S'.(1) /\ Markers.right ( + lt v' v /\ + forall T M, hlist.foralls (fun L => forall MC, tuple.apply (S'.(2) T M) L MC -> tuple.apply (S_.(2) T M) L MC)) end))))))))) /\ + (word.unsigned brv = 0%Z -> tuple.apply (S_.(2) t m) l (cost_loop_false isRegStr UNK (Some UNK) brmc)))))end)))) + (Hpost : match (tuple.apply (hlist.apply (spec v0) g0 t m) l0 mc).(2) with Q0 => forall t m mc, hlist.foralls (fun l => tuple.apply (Q0 t m) l mc -> post t m (reconstruct variables l) mc)end) + , cmd call (cmd.while e c) t m localsmap mc post ). + Proof. + eapply hlist_forall_foralls; intros g0 **. + eapply wp_while. + eexists measure, lt, (fun vi ti mi localsmapi mci => + exists gi li, localsmapi = reconstruct variables li /\ + match tuple.apply (hlist.apply (spec vi) gi ti mi) li mci with S_ => + S_.(1) /\ forall T M L MC, tuple.apply (S_.(2) T M) L MC-> + tuple.apply ((tuple.apply (hlist.apply (spec v0) g0 t m) l0 mc).(2) T M) L MC end). + cbv [Markers.split Markers.left Markers.right] in *. + split; [assumption|]. + split. { exists v0, g0, l0. split. 1: eapply reconstruct_enforce; eassumption. split; eauto. } + intros vi ti mi lmapi mci (gi&?&?&?&Qi); subst. + destruct (hlist.foralls_forall (hlist.foralls_forall (Hbody vi) gi ti mi mci) _ ltac:(eassumption)) as (brv&brmc&?&X). + exists brv; exists brmc; split; [assumption|]. destruct X as (Htrue&Hfalse). split; intros Hbr; + [pose proof(Htrue Hbr)as Hpc|pose proof(Hfalse Hbr)as Hpc]; clear Hbr Htrue Hfalse. + { eapply Proper_cmd; [ |eapply Hpc]. + intros tj mj lmapj mcj Hlj; eapply hlist.existss_exists in Hlj. + destruct Hlj as (lj&Elj&HE); eapply reconstruct_enforce in Elj; subst lmapj. + eapply hlist.existss_exists in HE. destruct HE as (l&?&?&?&HR). + pose proof fun T M => hlist.foralls_forall (HR T M); clear HR. + eauto 9. } + { pose proof fun t m mc => hlist.foralls_forall (Hpost t m mc); clear Hpost; eauto. } + Qed. + + Lemma tailrec_localsmap + {e c t} {m : mem} {l mc} {post : _ -> _->_->_-> Prop} + {measure : Type} (spec:_->_->_->_->_->(Prop*(_->_->_->_-> Prop))) lt + (Hwf : well_founded lt) + (v0 : measure) (P0 := (spec v0 t m l mc).(1)) (Hpre : P0) + (Q0 := (spec v0 t m l mc).(2)) + (Hbody : forall v t m l mc, + let S := spec v t m l mc in let (P, Q) := S in + P -> + exists br mc', expr m l e mc (eq (Datatypes.pair br mc')) /\ + (word.unsigned br <> 0%Z -> cmd call c t m l mc' + (fun t' m' l' mc''=> exists v', + let S' := spec v' t' m' l' (cost_loop_true isRegStr UNK (Some UNK) mc'') in let '(P', Q') := S' in + P' /\ + lt v' v /\ + forall T M L MC, Q' T M L MC -> Q T M L MC)) /\ + (word.unsigned br = 0%Z -> Q t m l (cost_loop_false isRegStr UNK (Some UNK) mc'))) + (Hpost : forall t m l mc, Q0 t m l mc -> post t m l mc) + : cmd call (cmd.while e c) t m l mc post. + Proof. + eapply wp_while. + eexists measure, lt, (fun v t m l mc => + let S := spec v t m l mc in let '(P, Q) := S in + P /\ forall T M L MC, Q T M L MC -> Q0 T M L MC). + split; [assumption|]. + cbv [Markers.split] in *. + split; [solve[eauto]|]. + intros vi ti mi li mci (?&Qi). + destruct (Hbody _ _ _ _ _ ltac:(eassumption)) as (br&?&?&X); exists br; eexists; split; [eassumption|]. + destruct X as (Htrue&Hfalse). split; intros Hbr; + [pose proof(Htrue Hbr)as Hpc|pose proof(Hfalse Hbr)as Hpc]; clear Hbr Htrue Hfalse. + { eapply Proper_cmd; [ |eapply Hpc]. + intros tj mj lj mcj (vj&dP&?&dQ); eauto 9. } + { eauto. } + Qed. + + Definition with_bottom {T} R (x y : option T) := + match x, y with + | None, Some _ => True + | Some x, Some y => R x y + | _, _ => False + end. + Lemma well_founded_with_bottom {T} R (H : @well_founded T R) : well_founded (with_bottom R). + Proof. + intros [x|]; cycle 1. + { constructor; intros [] HX; cbv [with_bottom] in HX; contradiction. } + pattern x. revert x. eapply (@well_founded_ind _ _ H). intros. + constructor. intros [y|] pf; eauto. + constructor. intros [] []. + Qed. + + Lemma atleastonce_localsmap + {e c t} {m : mem} {l mc} {post : _->_->_->_-> Prop} + {measure : Type} (invariant:measure->_->_->_->_->Prop) lt + (Hwf : well_founded lt) + (Henter : exists br brmc, expr m l e mc (eq (Datatypes.pair br brmc)) + /\ (word.unsigned br <> 0 -> exists v', invariant v' t m l brmc) + /\ (word.unsigned br = 0%Z -> post t m l (cost_loop_false isRegStr UNK (Some UNK) brmc))) + (Hbody : forall v t m l mc, invariant v t m l mc -> + cmd call c t m l mc (fun t m l mc => + exists br brmc, expr m l e (cost_loop_true isRegStr UNK (Some UNK) mc) (eq (Datatypes.pair br brmc)) + /\ (word.unsigned br <> 0 -> exists v', invariant v' t m l brmc /\ lt v' v) + /\ (word.unsigned br = 0 -> post t m l (cost_loop_false isRegStr UNK (Some UNK) brmc)))) + : cmd call (cmd.while e c) t m l mc post. + Proof. + eapply wp_while. + eexists (option measure), (with_bottom lt), (fun ov t m l mc => + exists br brmc, expr m l e mc (eq (Datatypes.pair br brmc)) + /\ ((word.unsigned br <> 0 -> exists v, ov = Some v /\ invariant v t m l brmc) + /\ (word.unsigned br = 0 -> ov = None /\ post t m l (cost_loop_false isRegStr UNK (Some UNK) brmc)))). + split; auto using well_founded_with_bottom; []. split. + { destruct Henter as [br [brmc [He [Henterm Henter0]]]]. + destruct (BinInt.Z.eq_dec (word.unsigned br) 0). + { exists None, br, brmc; split; trivial. + split; intros; try contradiction; split; eauto. } + { destruct (Henterm n) as [v Hinv]. + exists (Some v), br, brmc. + split; trivial; []; split; try contradiction. + exists v; split; trivial. } } + intros vi ti mi li mci (br&brmc&Ebr&Hcontinue&Hexit). + eexists; eexists; split; [eassumption|]; split. + { intros Hc; destruct (Hcontinue Hc) as (v&?&Hinv); subst. + eapply Proper_cmd; [ |eapply Hbody; eassumption]. + intros t' m' l' mc' (br'&brmc'&Ebr'&Hinv'&Hpost'). + destruct (BinInt.Z.eq_dec (word.unsigned br') 0). + { exists None; split; try constructor. + exists br', brmc'; split; trivial; []. + split; intros; try contradiction. + split; eauto. } + { destruct (Hinv' ltac:(trivial)) as (v'&inv'<v'v). + exists (Some v'); split; trivial. (* NOTE: this [trivial] simpl-reduces [with_bottom] *) + exists br', brmc'; split; trivial. + split; intros; try contradiction. + eexists; split; eauto. } } + eapply Hexit. + Qed. + + Lemma atleastonce + {e c t l mc} {m : mem} + (variables : list String.string) + {localstuple : tuple word (length variables)} + {Pl : enforce variables localstuple l} + {measure : Type} (invariant:measure->_->_->_->ufunc word (length variables) Prop) + lt (Hwf : well_founded lt) + {post : _->_->_->_-> Prop} + (Henter : exists br brmc, expr m l e mc (eq (Datatypes.pair br brmc)) + /\ (word.unsigned br <> 0 -> exists v', tuple.apply (invariant v' t m brmc) localstuple) + /\ (word.unsigned br = 0%Z -> post t m l (cost_loop_false isRegStr UNK (Some UNK) brmc))) + (Hbody : forall v t m mc, tuple.foralls (fun localstuple => + tuple.apply (invariant v t m mc) localstuple -> + cmd call c t m (reconstruct variables localstuple) mc (fun t m l mc => + exists br brmc, expr m l e (cost_loop_true isRegStr UNK (Some UNK) mc) (eq (Datatypes.pair br brmc)) + /\ (word.unsigned br <> 0 -> Markers.unique (Markers.left (tuple.existss (fun localstuple => enforce variables localstuple l /\ Markers.right (Markers.unique (exists v', tuple.apply (invariant v' t m brmc) localstuple /\ lt v' v)))))) + /\ (word.unsigned br = 0 -> post t m l (cost_loop_false isRegStr UNK (Some UNK) brmc))))) + : cmd call (cmd.while e c) t m l mc post. + Proof. + eapply (atleastonce_localsmap (fun v t m l mc => exists localstuple, Logic.and (enforce variables localstuple l) (tuple.apply (invariant v t m mc) localstuple))); eauto. + 1: { + destruct Henter as [br [brmc [He [Henterm Henter0]]]]. + do 3 eexists; eauto; split; eauto. + intro Hbr; destruct (Henterm Hbr); eauto. + } + intros vi ti mi li mci (?&X&Y). + specialize (Hbody vi ti mi mci). + eapply hlist.foralls_forall in Hbody. + specialize (Hbody Y). + rewrite <-(reconstruct_enforce _ _ _ X) in Hbody. + eapply Proper_cmd; [ |eapply Hbody]. + intros t' m' l' mc' (?&?&?&HH&?). + eexists; eexists; split; eauto. + split; intros; eauto. + specialize (HH ltac:(eauto)). + eapply hlist.existss_exists in HH; destruct HH as (?&?&?&?&?). + eexists; split; eauto. + Qed. + + Lemma tailrec_earlyout_localsmap + {e c t} {m : mem} {l mc} {post : _-> _->_->_-> Prop} + {measure : Type} (spec:_->_->_->_->_->(Prop*(_->_->_->_-> Prop))) lt + (Hwf : well_founded lt) + (v0 : measure) (P0 := (spec v0 t m l mc).(1)) (Hpre : P0) + (Q0 := (spec v0 t m l mc).(2)) + (Hbody : forall v t m l mc, + let S := spec v t m l mc in let (P, Q) := S in + P -> + exists br brmc, expr m l e mc (eq (Datatypes.pair br brmc)) /\ + (word.unsigned br <> 0%Z -> cmd call c t m l brmc + (fun t' m' l' mc' => + (exists br brmc, + expr m' l' e (cost_loop_true isRegStr UNK (Some UNK) mc') (eq (Datatypes.pair br brmc)) + /\ word.unsigned br = 0 /\ Q t' m' l' (cost_loop_false isRegStr UNK (Some UNK) brmc)) \/ + exists v', let S' := spec v' t' m' l' (cost_loop_true isRegStr UNK (Some UNK) mc') in let '(P', Q') := S' in + P' /\ + lt v' v /\ + forall T M L MC, Q' T M L MC -> Q T M L MC)) /\ + (word.unsigned br = 0%Z -> Q t m l (cost_loop_false isRegStr UNK (Some UNK) brmc))) + (Hpost : forall t m l mc, Q0 t m l mc -> post t m l mc) + : cmd call (cmd.while e c) t m l mc post. + Proof. + eapply wp_while. + eexists (option measure), (with_bottom lt), (fun v t m l mc => + match v with + | None => exists br brmc, expr m l e mc (eq (Datatypes.pair br brmc)) /\ word.unsigned br = 0 /\ Q0 t m l (cost_loop_false isRegStr UNK (Some UNK) brmc) + | Some v => + let S := spec v t m l mc in let '(P, Q) := S in + P /\ forall T M L MC, Q T M L MC -> Q0 T M L MC + end). + split; auto using well_founded_with_bottom; []; cbv [Markers.split] in *. + split. + { exists (Some v0); eauto. } + intros [vi|] ti mi li mci inv_i; [destruct inv_i as (?&Qi)|destruct inv_i as (br&Hebr&Hbr0&HQ)]. + { destruct (Hbody _ _ _ _ _ ltac:(eassumption)) as (br&brmc&?&X); exists br, brmc; split; [assumption|]. + destruct X as (Htrue&Hfalse). split; intros Hbr; + [pose proof(Htrue Hbr)as Hpc|pose proof(Hfalse Hbr)as Hpc]; eauto. + eapply Proper_cmd; [ |eapply Hpc]. + intros tj mj lj mcj [(br'&brmc'&Hbr'&Hz&HQ)|(vj&dP&?&dQ)]; + [exists None | exists (Some vj)]; cbn [with_bottom]; eauto 9. } + repeat esplit; destruct HQ; eauto; contradiction. + Qed. + + Lemma tailrec_earlyout + {e c t localsmap} {m : mem} + (ghosttypes : polymorphic_list.list Type) + (variables : list String.string) + {l0 : tuple word (length variables)} {mc} + {Pl : enforce variables l0 localsmap} + {post : _->_->_->_-> Prop} + {measure : Type} (spec:_->HList.arrows ghosttypes (_->_->ufunc word (length variables) (MetricLog->Prop*(_->_->ufunc word (length variables) (MetricLog->Prop))))) lt + (Hwf : well_founded lt) + (v0 : measure) + : hlist.foralls (fun (g0 : hlist ghosttypes) => forall + (Hpre : (tuple.apply (hlist.apply (spec v0) g0 t m) l0 mc).(1)) + (Hbody : forall v, hlist.foralls (fun g => forall t m mc, tuple.foralls (fun l => + @dlet _ (fun _ => Prop) (reconstruct variables l) (fun localsmap : locals => + match tuple.apply (hlist.apply (spec v) g t m) l mc with S_ => + S_.(1) -> + Markers.unique (Markers.left (exists br brmc, expr m localsmap e mc (eq (Datatypes.pair br brmc)) /\ Markers.right ( + (word.unsigned br <> 0%Z -> cmd call c t m localsmap brmc + (fun t' m' localsmap' mc' => + Markers.unique (Markers.left (hlist.existss (fun l' => enforce variables l' localsmap' /\ Markers.right ( + Markers.unique (Markers.left (exists br brmc, expr m' localsmap' e (cost_loop_true isRegStr UNK (Some UNK) mc') (eq (Datatypes.pair br brmc)) /\ Markers.right ( word.unsigned br = 0 /\ tuple.apply (S_.(2) t' m') l' (cost_loop_false isRegStr UNK (Some UNK) brmc)) ) ) \/ + Markers.unique (Markers.left (hlist.existss (fun g' => exists v', + match tuple.apply (hlist.apply (spec v') g' t' m') l' (cost_loop_true isRegStr UNK (Some UNK) mc') with S' => + S'.(1) /\ Markers.right ( + lt v' v /\ + forall T M, hlist.foralls (fun L => forall MC, tuple.apply (S'.(2) T M) L MC -> tuple.apply (S_.(2) T M) L MC)) end))))))))) /\ + (word.unsigned br = 0%Z -> tuple.apply (S_.(2) t m) l (cost_loop_false isRegStr UNK (Some UNK) brmc)))))end)))) + (Hpost : match (tuple.apply (hlist.apply (spec v0) g0 t m) l0 mc).(2) with Q0 => forall t m, hlist.foralls (fun l => forall mc, tuple.apply (Q0 t m) l mc -> post t m (reconstruct variables l) mc)end) + , cmd call (cmd.while e c) t m localsmap mc post). + Proof. + eapply hlist_forall_foralls; intros g0 **. + eapply wp_while. + eexists (option measure), (with_bottom lt), (fun vi ti mi localsmapi mci => + exists li, localsmapi = reconstruct variables li /\ + match vi with + | None => exists br brmc, expr mi localsmapi e mci (eq (Datatypes.pair br brmc)) + /\ word.unsigned br = 0 /\ tuple.apply ((tuple.apply (hlist.apply (spec v0) g0 t m) l0 mc).(2) ti mi) li (cost_loop_false isRegStr UNK (Some UNK) brmc) + | Some vi => exists gi, match tuple.apply (hlist.apply (spec vi) gi ti mi) li mci with S_ => + S_.(1) /\ forall T M L MC, tuple.apply (S_.(2) T M) L MC -> + tuple.apply ((tuple.apply (hlist.apply (spec v0) g0 t m) l0 mc).(2) T M) L MC end end). + cbv [Markers.unique Markers.split Markers.left Markers.right] in *. + split; eauto using well_founded_with_bottom. + split. { exists (Some v0), l0. split. 1: eapply reconstruct_enforce; eassumption. exists g0; split; eauto. } + intros [vi|] ti mi lmapi mci. + 2: { intros (ld&Hd&br&brmc&Hbr&Hz&Hdone). + eexists; eexists; split; eauto. + split; intros; try contradiction. + subst; eapply (hlist.foralls_forall (Hpost ti mi) _ _ Hdone). } + intros (?&?&gi&?&Qi); subst. + destruct (hlist.foralls_forall (hlist.foralls_forall (Hbody vi) gi ti mi mci) _ ltac:(eassumption)) as (br&brmc&?&X). + exists br, brmc; split; [assumption|]. destruct X as (Htrue&Hfalse). split; intros Hbr; + [pose proof(Htrue Hbr)as Hpc|pose proof(Hfalse Hbr)as Hpc]; clear Hbr Htrue Hfalse. + { eapply Proper_cmd; [ |eapply Hpc]. + intros tj mj lmapj mcj Hlj; eapply hlist.existss_exists in Hlj. + destruct Hlj as (lj&Elj&HE); eapply reconstruct_enforce in Elj; subst lmapj. + destruct HE as [(br'&brmc'&Hevalr'&Hz'&Hdone)|HE]. + { exists None; cbn. eauto 9. } + { eapply hlist.existss_exists in HE. destruct HE as (l&?&?&?&HR). + pose proof fun T M => hlist.foralls_forall (HR T M); clear HR. + eexists (Some _); eauto 9. } } + { pose proof fun t m => hlist.foralls_forall (Hpost t m); clear Hpost; eauto. } + Qed. + + Lemma while_zero_iterations {e c t l mc brmc} {m : mem} {post : _->_->_->_-> Prop} + (HCond: expr m l e mc (eq (Datatypes.pair (word.of_Z 0) brmc))) + (HPost: post t m l (cost_loop_false isRegStr UNK (Some UNK) brmc)) + : cmd call (cmd.while e c) t m l mc post. + Proof. + eapply (while_localsmap (fun n t' m' l' mc' => t' = t /\ m' = m /\ l' = l /\ mc' = mc) (PeanoNat.Nat.lt_wf 0) 0%nat). + 1: unfold split; auto. intros *. intros (? & ? & ? & ?). subst. + eexists. eexists. split. 1: exact HCond. + rewrite Properties.word.unsigned_of_Z_0. + split; intros; congruence. + Qed. + + + (* Bedrock-style loop rule *) + Local Open Scope sep_scope. + Local Infix "*" := Separation.sep : type_scope. + Local Infix "==>" := Lift1Prop.impl1. + + Lemma tailrec_sep + e c t (m : mem) l mc (post : _->_->_->_-> Prop) + {measure : Type} P Q lt (Hwf : well_founded lt) (v0 : measure) R0 + (Hpre : (P v0 t l mc * R0) m) + (Hbody : forall v t m l mc R, (P v t l mc * R) m -> + exists br brmc, expr m l e mc (eq (Datatypes.pair br brmc)) /\ + (word.unsigned br <> 0%Z -> cmd call c t m l brmc + (fun t' m' l' mc' => exists v' dR, (P v' t' l' (cost_loop_true isRegStr UNK (Some UNK) mc') * (R * dR)) m' /\ + lt v' v /\ + forall T L MC, Q v' T L MC * dR ==> Q v T L MC)) /\ + (word.unsigned br = 0%Z -> (Q v t l (cost_loop_false isRegStr UNK (Some UNK) brmc) * R) m)) + (Hpost : forall t m l mc, (Q v0 t l mc * R0) m -> post t m l mc) + : cmd call (cmd.while e c) t m l mc post. + Proof. + eapply wp_while. + eexists measure, lt, (fun v t m l mc => exists R, (P v t l mc * R) m /\ + forall T L MC, Q v T L MC * R ==> Q v0 T L MC * R0). + split; [assumption|]. + split. { exists v0, R0. split; [assumption|]. intros. reflexivity. } + intros vi ti mi li mci (Ri&?&Qi). + destruct (Hbody _ _ _ _ _ _ ltac:(eassumption)) as (br&brmc&?&X); exists br, brmc; split; [assumption|]. + destruct X as (Htrue&Hfalse). split; intros Hbr; + [pose proof(Htrue Hbr)as Hpc|pose proof(Hfalse Hbr)as Hpc]; clear Hbr Htrue Hfalse. + { eapply Proper_cmd; [ |eapply Hpc]. + intros tj mj lj mcj (vj&dR&dP&?&dQ). + exists vj; split; [|assumption]. + exists (Ri * dR); split; [assumption|]. + intros. rewrite (sep_comm _ dR), <-(sep_assoc _ dR), dQ; trivial. } + { eapply Hpost, Qi, Hpc. } + Qed. + +End Loops. + +Ltac loop_simpl := + cbn [reconstruct map.putmany_of_list HList.tuple.to_list + HList.hlist.foralls HList.tuple.foralls + HList.hlist.existss HList.tuple.existss + HList.hlist.apply HList.tuple.apply HList.hlist + List.repeat Datatypes.length + HList.polymorphic_list.repeat HList.polymorphic_list.length + PrimitivePair.pair._1 PrimitivePair.pair._2] in *. diff --git a/bedrock2/src/bedrock2/MetricProgramLogic.v b/bedrock2/src/bedrock2/MetricProgramLogic.v new file mode 100644 index 000000000..813ad255f --- /dev/null +++ b/bedrock2/src/bedrock2/MetricProgramLogic.v @@ -0,0 +1,388 @@ +From coqutil.Tactics Require Import Tactics letexists eabstract rdelta reference_to_string ident_of_string. +Require Import coqutil.Map.Interface. +Require Import bedrock2.Syntax. +Require Import bedrock2.MetricWeakestPrecondition. +Require Import bedrock2.MetricWeakestPreconditionProperties. +Require Import bedrock2.MetricLoops. +Require Import bedrock2.Map.SeparationLogic bedrock2.Scalars. + +Definition spec_of (procname:String.string) := Semantics.env -> Prop. +Existing Class spec_of. + +Module Import Coercions. + Import Map.Interface Word.Interface BinInt. + Coercion Z.of_nat : nat >-> Z. + Coercion word.unsigned : word.rep >-> Z. + + Definition sepclause_of_map {key value map} (m : @map.rep key value map) + : map.rep -> Prop := Logic.eq m. + Coercion sepclause_of_map : Interface.map.rep >-> Funclass. +End Coercions. + +Goal True. + assert_succeeds epose (fun k v M (m : @Interface.map.rep k v M) => m _). +Abort. + +Section bindcmd. + Context {T : Type}. + Fixpoint bindcmd (c : Syntax.cmd) (k : Syntax.cmd -> T) {struct c} : T := + match c with + | cmd.cond e c1 c2 => bindcmd c1 (fun c1 => bindcmd c2 (fun c2 => let c := cmd.cond e c1 c2 in k c)) + | cmd.seq c1 c2 => bindcmd c1 (fun c1 => bindcmd c2 (fun c2 => let c := cmd.seq c1 c2 in k c)) + | cmd.while e c => bindcmd c (fun c => let c := cmd.while e c in k c) + | c => k c + end. +End bindcmd. + +(* TODO: use a deduplicating set instead of a list *) +Fixpoint callees (c : Syntax.cmd) : list String.string := + match c with + | cmd.cond _ c1 c2 | cmd.seq c1 c2 => callees c1 ++ callees c2 + | cmd.while _ c | cmd.stackalloc _ _ c => callees c + | cmd.call _ f _ => cons f nil + | _ => nil + end. + +Ltac assuming_correctness_of_in callees functions P := + lazymatch callees with + | nil => P + | cons ?f ?callees => + let f_spec := lazymatch constr:(_:spec_of f) with ?x => x end in + constr:(f_spec functions -> ltac:(let t := assuming_correctness_of_in callees functions P in exact t)) + end. +Require Import String List coqutil.Macros.ident_to_string. + +Ltac program_logic_goal_for_function proc := + let __ := constr:(proc : Syntax.func) in + constr_string_basename_of_constr_reference_cps ltac:(Tactics.head proc) ltac:(fun fname => + let spec := lazymatch constr:(_:spec_of fname) with ?s => s end in + exact (forall (functions : @map.rep _ _ Semantics.env) (EnvContains : map.get functions fname = Some proc), ltac:( + let callees := eval cbv in (callees (snd proc)) in + let s := assuming_correctness_of_in callees functions (spec functions) in + exact s))). +Definition program_logic_goal_for (_ : Syntax.func) (P : Prop) := P. + +Notation "program_logic_goal_for_function! proc" := (program_logic_goal_for proc ltac:( + program_logic_goal_for_function proc)) + (at level 10, only parsing). + +(* Users might want to override this with + Ltac normalize_body_of_function f ::= Tactics.rdelta.rdelta f. + in case cbv does more simplification than desired. *) +Ltac normalize_body_of_function f := eval cbv in f. + +Ltac bind_body_of_function f_ := + let f := normalize_body_of_function f_ in + let fargs := open_constr:(_) in + let frets := open_constr:(_) in + let fbody := open_constr:(_) in + let funif := open_constr:((fargs, frets, fbody)) in + unify f funif; + let G := lazymatch goal with |- ?G => G end in + let P := lazymatch eval pattern f_ in G with ?P _ => P end in + change (bindcmd fbody (fun c : Syntax.cmd => P (fargs, frets, c))); + cbv beta iota delta [bindcmd]; intros. + +(* note: f might have some implicit parameters (eg a record of constants) *) +Ltac enter f := + cbv beta delta [program_logic_goal_for]; + bind_body_of_function f; + lazymatch goal with |- ?s ?p => let s := rdelta s in change (s p); cbv beta end. + +Require coqutil.Map.SortedList. (* special-case eq_refl *) + +Ltac straightline_cleanup := + match goal with + (* TODO remove superfluous _ after .rep, but that will break some proofs that rely on + x not being cleared to instantiate evars with terms depending on x *) + | x : Word.Interface.word.rep _ |- _ => clear x + | x : Init.Byte.byte |- _ => clear x + | x : Semantics.trace |- _ => clear x + | x : Syntax.cmd |- _ => clear x + | x : Syntax.expr |- _ => clear x + | x : coqutil.Map.Interface.map.rep |- _ => clear x + | x : BinNums.Z |- _ => clear x + | x : unit |- _ => clear x + | x : bool |- _ => clear x + | x : list _ |- _ => clear x + | x : nat |- _ => clear x + (* same TODO as above *) + | x := _ : Word.Interface.word.rep _ |- _ => clear x + | x := _ : Init.Byte.byte |- _ => clear x + | x := _ : Semantics.trace |- _ => clear x + | x := _ : Syntax.cmd |- _ => clear x + | x := _ : Syntax.expr |- _ => clear x + | x := _ : coqutil.Map.Interface.map.rep |- _ => clear x + | x := _ : BinNums.Z |- _ => clear x + | x := _ : unit |- _ => clear x + | x := _ : bool |- _ => clear x + | x := _ : list _ |- _ => clear x + | x := _ : nat |- _ => clear x + | |- forall _, _ => intros + | |- let _ := _ in _ => intros + | |- dlet.dlet ?v (fun x => ?P) => change (let x := v in P); intros + | _ => progress (cbn [Semantics.interp_binop] in * ) + | H: exists _, _ |- _ => assert_succeeds progress destruct H as (_&_); destruct H + | H: _ /\ _ |- _ => destruct H + | x := ?y |- ?G => is_var y; subst x + | H: ?x = ?y |- _ => constr_eq x y; clear H + | H: ?x = ?y |- _ => is_var x; is_var y; assert_fails (idtac; let __ := eval cbv [x] in x in idtac); subst x + | H: ?x = ?y |- _ => is_var x; is_var y; assert_fails (idtac; let __ := eval cbv [y] in y in idtac); subst y + | H: ?x = ?v |- _ => + is_var x; + assert_fails (idtac; let __ := eval cbv delta [x] in x in idtac); + lazymatch v with context[x] => fail | _ => idtac end; + let x' := fresh x in + rename x into x'; + simple refine (let x := v in _); + change (x' = x) in H; + symmetry in H; + destruct H + end. + +Import MetricWeakestPrecondition. +Import coqutil.Map.Interface. + +Ltac straightline_stackalloc := + match goal with Hanybytes: Memory.anybytes ?a ?n ?mStack |- _ => + let m := match goal with H : map.split ?mCobined ?m mStack |- _ => m end in + let mCombined := match goal with H : map.split ?mCobined ?m mStack |- _ => mCobined end in + let Hsplit := match goal with H : map.split ?mCobined ?m mStack |- _ => H end in + let Hm := multimatch goal with H : _ m |- _ => H end in + let Hm' := fresh Hm in + let Htmp := fresh in + let Pm := match type of Hm with ?P m => P end in + assert_fails (assert (Separation.sep Pm (Array.array Separation.ptsto (Interface.word.of_Z (BinNums.Zpos BinNums.xH)) a _) mCombined) as _ by ecancel_assumption); + rename Hm into Hm'; + let stack := fresh "stack" in + let stack_length := fresh "length_" stack in (* MUST remain in context for deallocation *) + destruct (Array.anybytes_to_array_1 mStack a n Hanybytes) as (stack&Htmp&stack_length); + epose proof (ex_intro _ m (ex_intro _ mStack (conj Hsplit (conj Hm' Htmp))) + : Separation.sep _ (Array.array Separation.ptsto (Interface.word.of_Z (BinNums.Zpos BinNums.xH)) a _) mCombined) as Hm; + clear Htmp; (* note: we could clear more here if we assumed only one separation-logic description of each memory is present *) + try (let m' := fresh m in rename m into m'); rename mCombined into m; + ( assert (BinInt.Z.of_nat (Datatypes.length stack) = n) + by (rewrite stack_length; apply (ZifyInst.of_nat_to_nat_eq n)) + || fail 2 "negative stackalloc of size" n ) + end. + +Ltac straightline_stackdealloc := + lazymatch goal with |- exists _ _, Memory.anybytes ?a ?n _ /\ map.split ?m _ _ /\ _ => + let Hm := multimatch goal with Hm : _ m |- _ => Hm end in + let stack := match type of Hm with context [Array.array Separation.ptsto _ a ?stack] => stack end in + let length_stack := match goal with H : Datatypes.length stack = _ |- _ => H end in + let Hm' := fresh Hm in + pose proof Hm as Hm'; + let Psep := match type of Hm with ?P _ => P end in + let Htmp := fresh "Htmp" in + eassert (Lift1Prop.iff1 Psep (Separation.sep _ (Array.array Separation.ptsto (Interface.word.of_Z (BinNums.Zpos BinNums.xH)) a stack))) as Htmp + by ecancel || fail "failed to find stack frame in" Psep "using ecancel"; + eapply (fun m => proj1 (Htmp m)) in Hm; + let m' := fresh m in + rename m into m'; + let mStack := fresh in + destruct Hm as (m&mStack&Hsplit&Hm&Harray1); move Hm at bottom; + pose proof Array.array_1_to_anybytes _ _ _ Harray1 as Hanybytes; + rewrite length_stack in Hanybytes; + refine (ex_intro _ m (ex_intro _ mStack (conj Hanybytes (conj Hsplit _)))); + clear Htmp Hsplit mStack Harray1 Hanybytes + end. + +Ltac rename_to_different H := + idtac; + let G := fresh H "'0" in + rename H into G. +Ltac ensure_free H := + try rename_to_different H. + +Ltac eq_uniq_step := + match goal with + | |- ?x = ?y => + let x := rdelta x in + let y := rdelta y in + first [ is_evar x | is_evar y | constr_eq x y ]; exact eq_refl + | |- ?lhs = ?rhs => + let lh := head lhs in + is_constructor lh; + let rh := head rhs in + constr_eq lh rh; + f_equal (* NOTE: this is not sound, we really want just one f_equal application not a heuristic tactic *) + end. +Ltac eq_uniq := repeat eq_uniq_step. + +Ltac fwd_uniq_step := + match goal with + | |- exists x : ?T, _ => + let ev := open_constr:(match _ return T with x => x end) in + eexists ev; + let rec f := + tryif has_evar ev + then fwd_uniq_step + else idtac + in f + | |- _ /\ _ => split; [ solve [repeat fwd_uniq_step; eq_uniq] | ] + | _ => solve [ eq_uniq ] + end. +Ltac fwd_uniq := repeat fwd_uniq_step. + +Ltac straightline := + match goal with + | _ => straightline_cleanup + | |- program_logic_goal_for ?f _ => + enter f; intros; + match goal with + | H: map.get ?functions ?fname = Some _ |- _ => + eapply start_func; [exact H | clear H] + end; + cbv match beta delta [MetricWeakestPrecondition.func] + | |- MetricWeakestPrecondition.cmd _ (cmd.set ?s ?e) _ _ _ _ ?post => + unfold1_cmd_goal; cbv beta match delta [cmd_body]; + let __ := match s with String.String _ _ => idtac | String.EmptyString => idtac end in + ident_of_constr_string_cps s ltac:(fun x => + ensure_free x; + (* NOTE: keep this consistent with the [exists _, _ /\ _] case far below *) + letexists _ as x; split; [solve [repeat straightline]|]) + | |- cmd _ ?c _ _ _ _ ?post => + let c := eval hnf in c in + lazymatch c with + | cmd.while _ _ => fail + | cmd.cond _ _ _ => fail + | cmd.interact _ _ _ => fail + | _ => unfold1_cmd_goal; cbv beta match delta [cmd_body] + end + | |- @list_map _ _ (get _ _) _ _ _ => unfold1_list_map_goal; cbv beta match delta [list_map_body] + | |- @list_map _ _ (expr _ _) _ _ _ => unfold1_list_map_goal; cbv beta match delta [list_map_body] + | |- @list_map _ _ _ nil _ _ => cbv beta match fix delta [list_map list_map_body] + | |- expr _ _ _ _ _ => unfold1_expr_goal; cbv beta match delta [expr_body] + | |- dexpr _ _ _ _ _ => cbv beta delta [dexpr] + | |- dexprs _ _ _ _ _ => cbv beta delta [dexprs] + | |- literal _ _ _ => cbv beta delta [literal] + | |- @get ?w ?W ?L ?l ?x ?mc ?P => + let get' := eval cbv [get] in @get in + change (get' w W L l x mc P); cbv beta + | |- load _ _ _ _ _ => cbv beta delta [load] + | |- @MetricLoops.enforce ?width ?word ?locals ?names ?values ?map => + let values := eval cbv in values in + change (@MetricLoops.enforce width word locals names values map); + exact (conj (eq_refl values) eq_refl) + | |- @eq (@coqutil.Map.Interface.map.rep String.string Interface.word.rep _) _ _ => + eapply SortedList.eq_value; exact eq_refl + | |- @map.get String.string Interface.word.rep ?M ?m ?k = Some ?e' => + let e := rdelta e' in + is_evar e; + once (let v := multimatch goal with x := context[@map.put _ _ M _ k ?v] |- _ => v end in + (* cbv is slower than this, cbv with whitelist would have an enormous whitelist, cbv delta for map is slower than this, generalize unrelated then cbv is slower than this, generalize then vm_compute is slower than this, lazy is as slow as this: *) + unify e v; exact (eq_refl (Some v))) + | |- @coqutil.Map.Interface.map.get String.string Interface.word.rep _ _ _ = Some ?v => + let v' := rdelta v in is_evar v'; (change v with v'); exact eq_refl + | |- (_, _) = (_, _) => f_equal (* NOTE: metrics-only case *) + | |- ?x = ?y => + let y := rdelta y in is_evar y; change (x=y); exact eq_refl + | |- ?x = ?y => + let x := rdelta x in is_evar x; change (x=y); exact eq_refl + | |- ?x = ?y => + let x := rdelta x in let y := rdelta y in constr_eq x y; exact eq_refl + | |- store Syntax.access_size.one _ _ _ _ => + eapply Scalars.store_one_of_sep; [solve[ecancel_assumption]|] + | |- store Syntax.access_size.two _ _ _ _ => + eapply Scalars.store_two_of_sep; [solve[ecancel_assumption]|] + | |- store Syntax.access_size.four _ _ _ _ => + eapply Scalars.store_four_of_sep; [solve[ecancel_assumption]|] + | |- store Syntax.access_size.word _ _ _ _ => + eapply Scalars.store_word_of_sep; [solve[ecancel_assumption]|] + | |- bedrock2.Memory.load Syntax.access_size.one ?m ?a = Some ?ev => + try subst ev; refine (@Scalars.load_one_of_sep _ _ _ _ _ _ _ _ _ _); ecancel_assumption + | |- @bedrock2.Memory.load _ ?word ?mem Syntax.access_size.two ?m ?a = Some ?ev => + try subst ev; refine (@Scalars.load_two_of_sep _ word _ mem _ a _ _ m _); ecancel_assumption + | |- @bedrock2.Memory.load _ ?word ?mem Syntax.access_size.four ?m ?a = Some ?ev => + try subst ev; refine (@Scalars.load_four_of_sep_32bit _ word _ mem _ eq_refl a _ _ m _); ecancel_assumption + | |- @bedrock2.Memory.load _ ?word ?mem Syntax.access_size.four ?m ?a = Some ?ev => + try subst ev; refine (@Scalars.load_four_of_sep _ word _ mem _ a _ _ m _); ecancel_assumption + | |- @bedrock2.Memory.load _ ?word ?mem Syntax.access_size.word ?m ?a = Some ?ev => + try subst ev; refine (@Scalars.load_word_of_sep _ word _ mem _ a _ _ m _); ecancel_assumption + | |- exists l', Interface.map.of_list_zip ?ks ?vs = Some l' /\ _ => + letexists; split; [exact eq_refl|] (* TODO: less unification here? *) + | |- exists l', Interface.map.putmany_of_list_zip ?ks ?vs ?l = Some l' /\ _ => + letexists; split; [exact eq_refl|] (* TODO: less unification here? *) + | _ => fwd_uniq_step + | |- exists x, ?P /\ ?Q => + let x := fresh x in refine (let x := _ in ex_intro (fun x => P /\ Q) x _); + split; [solve [repeat straightline]|] + (* NOTE: metrics only case; maybe try to unify with non-metrics? *) + | |- exists x y, ?P /\ ?Q => + eexists; eexists; split; [solve [repeat straightline]|] + (* eexists instead of letexists ensures unification of (?a,?b) = (const,const) + does not unfold the const aggressively (e.g. word to Naive) *) + | |- exists x, Markers.split (?P /\ ?Q) => + let x := fresh x in refine (let x := _ in ex_intro (fun x => P /\ Q) x _); + split; [solve [repeat straightline]|] + | |- Markers.unique (exists x, Markers.split (?P /\ ?Q)) => + let x := fresh x in refine (let x := _ in ex_intro (fun x => P /\ Q) x _); + split; [solve [repeat straightline]|] + | |- Markers.unique (Markers.left ?G) => + change G; + unshelve (idtac; repeat match goal with + | |- Markers.split (?P /\ Markers.right ?Q) => + split; [eabstract (repeat straightline) | change Q] + | |- exists _, _ => letexists + end); [] + | |- Markers.split ?G => change G; split + | |- True => exact I + | |- False \/ _ => right + | |- _ \/ False => left + | |- BinInt.Z.modulo ?z (Memory.bytes_per_word _) = BinInt.Z0 /\ _ => + lazymatch Coq.setoid_ring.InitialRing.isZcst z with + | true => split; [exact eq_refl|] + end + | |- _ => straightline_stackalloc + | |- _ => straightline_stackdealloc + | |- context[sep (sep _ _) _] => progress (flatten_seps_in_goal; cbn [seps]) + | H : context[sep (sep _ _) _] |- _ => progress (flatten_seps_in H; cbn [seps] in H) + end. + +(* TODO: once we can automatically prove some calls, include the success-only version of this in [straightline] *) +Ltac straightline_call := + lazymatch goal with + | |- MetricWeakestPrecondition.call ?functions ?callee _ _ _ _ => + let callee_spec := lazymatch constr:(_:spec_of callee) with ?s => s end in + let Hcall := lazymatch goal with H: callee_spec functions |- _ => H end in + eapply MetricWeakestPreconditionProperties.Proper_call; cycle -1; + [ eapply Hcall | try eabstract (solve [Morphisms.solve_proper]) .. ]; + [ .. | intros ? ? ? ?] + end. + +Ltac current_trace_mem_locals := + lazymatch goal with + | |- MetricWeakestPrecondition.cmd _ _ ?t ?m ?l _ _ => constr:((t, m, l)) + end. + +Ltac seprewrite Hrw := + let tml := current_trace_mem_locals in + let m := lazymatch tml with (_, ?m, _) => m end in + let H := multimatch goal with H: _ m |- _ => H end in + seprewrite_in Hrw H. +Ltac seprewrite_by Hrw tac := + let tml := current_trace_mem_locals in + let m := lazymatch tml with (_, ?m, _) => m end in + let H := multimatch goal with H: _ m |- _ => H end in + seprewrite_in_by Hrw H tac. + +Ltac show_program := + lazymatch goal with + | |- @cmd ?width ?BW ?word ?mem ?locals ?ext_spec ?E ?c ?F ?G ?H ?I => + let c' := eval cbv in c in + change (@cmd width BW word mem locals ext_spec E (fst (c, c')) F G H I) + end. + +Ltac subst_words := + repeat match goal with x := _ : coqutil.Word.Interface.word.rep |- _ => subst x end. + +Require Import coqutil.Tactics.eplace Coq.setoid_ring.Ring_tac. +Ltac ring_simplify_words := + subst_words; + repeat match goal with H : context [?w] |- _ => + let __ := constr:(w : Interface.word.rep) in + progress eplace w with _ in H by (ring_simplify; reflexivity) end; + repeat match goal with |- context [?w] => + let __ := constr:(w : Interface.word.rep) in + progress eplace w with _ by (ring_simplify; reflexivity) end. diff --git a/bedrock2/src/bedrock2/MetricSemantics.v b/bedrock2/src/bedrock2/MetricSemantics.v index 41936c903..059ae32ff 100644 --- a/bedrock2/src/bedrock2/MetricSemantics.v +++ b/bedrock2/src/bedrock2/MetricSemantics.v @@ -7,9 +7,12 @@ Require Import BinIntDef coqutil.Word.Interface coqutil.Word.Bitwidth. Require Export bedrock2.Memory. Require Import Coq.Lists.List. Require Import bedrock2.MetricLogging. +Require Import bedrock2.MetricCosts. Require Import bedrock2.Semantics. Require Import Coq.Lists.List. +Local Notation UNK := String.EmptyString. + Section semantics. Context {width: Z} {BW: Bitwidth width} {word: word.word width} {mem: map.map word byte}. Context {locals: map.map String.string word}. @@ -23,37 +26,31 @@ Section semantics. Local Notation "' x <- a | y ; f" := (match a with x => f | _ => y end) (right associativity, at level 70, x pattern). + (* TODO XXX possibly be a bit smarter about whether things are registers, + for tighter metrics bounds at bedrock2 level *) Fixpoint eval_expr (e : expr) (mc : metrics) : option (word * metrics) := match e with - | expr.literal v => Some (word.of_Z v, addMetricInstructions 8 - (addMetricLoads 8 mc)) + | expr.literal v => Some (word.of_Z v, cost_lit isRegStr UNK mc) | expr.var x => match map.get l x with - | Some v => Some (v, addMetricInstructions 1 - (addMetricLoads 2 mc)) + | Some v => Some (v, cost_set isRegStr UNK x mc) | None => None end | expr.inlinetable aSize t index => 'Some (index', mc') <- eval_expr index mc | None; 'Some v <- load aSize (map.of_list_word t) index' | None; - Some (v, (addMetricInstructions 3 - (addMetricLoads 4 - (addMetricJumps 1 mc')))) + Some (v, cost_inlinetable isRegStr UNK UNK mc') | expr.load aSize a => 'Some (a', mc') <- eval_expr a mc | None; 'Some v <- load aSize m a' | None; - Some (v, addMetricInstructions 1 - (addMetricLoads 2 mc')) + Some (v, cost_load isRegStr UNK UNK mc') | expr.op op e1 e2 => 'Some (v1, mc') <- eval_expr e1 mc | None; 'Some (v2, mc'') <- eval_expr e2 mc' | None; - Some (interp_binop op v1 v2, addMetricInstructions 2 - (addMetricLoads 2 mc'')) + Some (interp_binop op v1 v2, cost_op isRegStr UNK UNK UNK mc'') | expr.ite c e1 e2 => 'Some (vc, mc') <- eval_expr c mc | None; eval_expr (if word.eqb vc (word.of_Z 0) then e2 else e1) - (addMetricInstructions 2 - (addMetricLoads 2 - (addMetricJumps 1 mc'))) + (cost_if isRegStr UNK (Some UNK) mc') end. Fixpoint eval_call_args (arges : list expr) (mc : metrics) := @@ -88,8 +85,7 @@ Module exec. Section WithParams. | set x e t m l mc post v mc' (_ : eval_expr m l e mc = Some (v, mc')) - (_ : post t m (map.put l x v) (addMetricInstructions 1 - (addMetricLoads 1 mc'))) + (_ : post t m (map.put l x v) (cost_set isRegStr x UNK mc')) : exec (cmd.set x e) t m l mc post | unset x t m l mc post @@ -100,9 +96,7 @@ Module exec. Section WithParams. a mc' (_ : eval_expr m l ea mc = Some (a, mc')) v mc'' (_ : eval_expr m l ev mc' = Some (v, mc'')) m' (_ : store sz m a v = Some m') - (_ : post t m' l (addMetricInstructions 1 - (addMetricLoads 1 - (addMetricStores 1 mc'')))) + (_ : post t m' l (cost_store isRegStr UNK UNK mc'')) : exec (cmd.store sz ea ev) t m l mc post | stackalloc x n body t mSmall l mc post @@ -110,7 +104,7 @@ Module exec. Section WithParams. (_ : forall a mStack mCombined, anybytes a n mStack -> map.split mCombined mSmall mStack -> - exec body t mCombined (map.put l x a) (addMetricInstructions 1 (addMetricLoads 1 mc)) + exec body t mCombined (map.put l x a) (cost_stackalloc isRegStr x mc) (fun t' mCombined' l' mc' => exists mSmall' mStack', anybytes a n mStack' /\ @@ -120,17 +114,13 @@ Module exec. Section WithParams. | if_true t m l mc e c1 c2 post v mc' (_ : eval_expr m l e mc = Some (v, mc')) (_ : word.unsigned v <> 0) - (_ : exec c1 t m l (addMetricInstructions 2 - (addMetricLoads 2 - (addMetricJumps 1 mc'))) post) + (_ : exec c1 t m l (cost_if isRegStr UNK (Some UNK) mc') post) : exec (cmd.cond e c1 c2) t m l mc post | if_false e c1 c2 t m l mc post v mc' (_ : eval_expr m l e mc = Some (v, mc')) (_ : word.unsigned v = 0) - (_ : exec c2 t m l (addMetricInstructions 2 - (addMetricLoads 2 - (addMetricJumps 1 mc'))) post) + (_ : exec c2 t m l (cost_if isRegStr UNK (Some UNK) mc') post) : exec (cmd.cond e c1 c2) t m l mc post | seq c1 c2 t m l mc post @@ -141,9 +131,7 @@ Module exec. Section WithParams. t m l mc post v mc' (_ : eval_expr m l e mc = Some (v, mc')) (_ : word.unsigned v = 0) - (_ : post t m l (addMetricInstructions 1 - (addMetricLoads 1 - (addMetricJumps 1 mc')))) + (_ : post t m l (cost_loop_false isRegStr UNK (Some UNK) mc')) : exec (cmd.while e c) t m l mc post | while_true e c t m l mc post @@ -151,20 +139,18 @@ Module exec. Section WithParams. (_ : word.unsigned v <> 0) mid (_ : exec c t m l mc' mid) (_ : forall t' m' l' mc'', mid t' m' l' mc'' -> - exec (cmd.while e c) t' m' l' (addMetricInstructions 2 - (addMetricLoads 2 - (addMetricJumps 1 mc''))) post) + exec (cmd.while e c) t' m' l' (cost_loop_true isRegStr UNK (Some UNK) mc'') post) : exec (cmd.while e c) t m l mc post | call binds fname arges t m l mc post params rets fbody (_ : map.get e fname = Some (params, rets, fbody)) args mc' (_ : eval_call_args m l arges mc = Some (args, mc')) lf (_ : map.of_list_zip params args = Some lf) - mid (_ : exec fbody t m lf (addMetricInstructions 100 (addMetricJumps 100 (addMetricLoads 100 (addMetricStores 100 mc')))) mid) + mid (_ : exec fbody t m lf mc' mid) (_ : forall t' m' st1 mc'', mid t' m' st1 mc'' -> exists retvs, map.getmany_of_list st1 rets = Some retvs /\ exists l', map.putmany_of_list_zip binds retvs l = Some l' /\ - post t' m' l' (addMetricInstructions 100 (addMetricJumps 100 (addMetricLoads 100 (addMetricStores 100 mc''))))) + post t' m' l' (cost_call PreSpill mc'')) : exec (cmd.call binds fname arges) t m l mc post | interact binds action arges t m l mc post @@ -175,9 +161,7 @@ Module exec. Section WithParams. exists l', map.putmany_of_list_zip binds resvals l = Some l' /\ forall m', map.split m' mKeep mReceive -> post (cons ((mGive, action, args), (mReceive, resvals)) t) m' l' - (addMetricInstructions 1 - (addMetricStores 1 - (addMetricLoads 2 mc')))) + (cost_interact PreSpill mc')) : exec (cmd.interact binds action arges) t m l mc post . diff --git a/bedrock2/src/bedrock2/MetricWeakestPrecondition.v b/bedrock2/src/bedrock2/MetricWeakestPrecondition.v new file mode 100644 index 000000000..d64c4ad6b --- /dev/null +++ b/bedrock2/src/bedrock2/MetricWeakestPrecondition.v @@ -0,0 +1,338 @@ +Require Import coqutil.Macros.subst coqutil.Macros.unique coqutil.Map.Interface coqutil.Map.OfListWord. +Require Import Coq.ZArith.BinIntDef coqutil.Word.Interface coqutil.Word.Bitwidth. +Require Import coqutil.dlet bedrock2.Syntax bedrock2.Semantics. +Require Import bedrock2.MetricLogging. +Require Import bedrock2.MetricCosts. +Require Import bedrock2.MetricSemantics. + +Section WeakestPrecondition. + Context {width: Z} {BW: Bitwidth width} {word: word.word width} {mem: map.map word Byte.byte}. + Context {locals: map.map String.string word}. + Context {ext_spec: ExtSpec}. + Implicit Types (t : trace) (m : mem) (l : locals). + + Local Notation metrics := MetricLog. + Local Notation UNK := String.EmptyString. + + (* TODO XXX address inconsistency in where metrics are added *) + Definition literal v mc (post : (word * metrics) -> Prop) : Prop := + dlet! v := word.of_Z v in post (v, cost_lit isRegStr UNK mc). + Definition get (l : locals) (x : String.string) mc (post : (word * metrics) -> Prop) : Prop := + exists v, map.get l x = Some v /\ post (v, cost_set isRegStr UNK x mc). + Definition load s m a mc (post: (word * metrics) -> Prop) : Prop := + exists v, load s m a = Some v /\ post (v, mc). + Definition store sz m a v post := + exists m', store sz m a v = Some m' /\ post m'. + + Section WithMemAndLocals. + Context (m : mem) (l : locals). + Definition expr_body (rec : _->_->(word*metrics->Prop)->Prop) (e : Syntax.expr) (mc : metrics) (post : word * metrics -> Prop) : Prop := + match e with + | expr.literal v => + literal v mc post + | expr.var x => + get l x mc post + | expr.op op e1 e2 => + rec e1 mc (fun '(v1, mc') => + rec e2 mc' (fun '(v2, mc'') => + post (interp_binop op v1 v2, cost_op isRegStr UNK UNK UNK mc''))) + | expr.load s e => + rec e mc (fun '(a, mc') => + load s m a (cost_load isRegStr UNK UNK mc') post) + | expr.inlinetable s t e => + rec e mc (fun '(a, mc') => + load s (map.of_list_word t) a (cost_inlinetable isRegStr UNK UNK mc') post) + | expr.ite c e1 e2 => + rec c mc (fun '(b, mc') => rec (if word.eqb b (word.of_Z 0) then e2 else e1) (cost_if isRegStr UNK (Some UNK) mc') post) + end. + Fixpoint expr e := expr_body expr e. + End WithMemAndLocals. + + Section WithF. + Context {A B} (f: A -> metrics -> (B * metrics -> Prop) -> Prop). + Definition list_map_body rec (xs : list A) (mc : metrics) (post : list B * metrics -> Prop) : Prop := + match xs with + | nil => post (nil, mc) + | cons x xs' => + f x mc (fun '(y, mc') => + rec xs' mc' (fun '(ys', mc'') => + post (cons y ys', mc''))) + end. + Fixpoint list_map xs := list_map_body list_map xs. + End WithF. + + Section WithFunctions. + + Context (e: env). + Context (call : String.string -> trace -> mem -> list word -> metrics -> (trace -> mem -> list word -> metrics -> Prop) -> Prop). + Definition dexpr m l e mc v := expr m l e mc (eq v). + Definition dexprs m l es mc vs := list_map (expr m l) es mc (eq vs). +(* All cases except cmd.while and cmd.call can be denoted by structural recursion + over the syntax. + For cmd.while and cmd.call, we fall back to the operational semantics *) + Definition cmd_body (rec:_->_->_->_->_->_->Prop) (c : cmd) (t : trace) (m : mem) (l : locals) (mc : metrics) + (post : trace -> mem -> locals -> metrics -> Prop) : Prop := + (* give value of each pure expression when stating its subproof *) + match c with + | cmd.skip => post t m l mc + | cmd.set x ev => + exists v mc', dexpr m l ev mc (v, mc') /\ + dlet! l := map.put l x v in + post t m l (cost_set isRegStr x UNK mc') + | cmd.unset x => + dlet! l := map.remove l x in + post t m l mc + | cmd.store sz ea ev => + exists a mc', dexpr m l ea mc (a, mc') /\ + exists v mc'', dexpr m l ev mc' (v, mc'') /\ + store sz m a v (fun m => + post t m l (cost_store isRegStr UNK UNK mc'')) + | cmd.stackalloc x n c => + Z.modulo n (bytes_per_word width) = 0 /\ + forall a mStack mCombined, + anybytes a n mStack -> map.split mCombined m mStack -> + dlet! l := map.put l x a in + rec c t mCombined l (cost_stackalloc isRegStr x mc) + (fun t' mCombined' l' mc' => + exists m' mStack', + anybytes a n mStack' /\ map.split mCombined' m' mStack' /\ + post t' m' l' mc') + | cmd.cond br ct cf => + exists v mc', dexpr m l br mc (v, mc') /\ + dlet! mc'' := cost_if isRegStr UNK (Some UNK) mc' in + (word.unsigned v <> 0%Z -> rec ct t m l mc'' post) /\ + (word.unsigned v = 0%Z -> rec cf t m l mc'' post) + | cmd.seq c1 c2 => + rec c1 t m l mc (fun t m l mc => rec c2 t m l mc post) + | cmd.while _ _ => MetricSemantics.exec e c t m l mc post + | cmd.call binds fname arges => + exists args mc', dexprs m l arges mc (args, mc') /\ + MetricSemantics.call e fname t m args mc' (fun t m rets mc'' => + exists l', map.putmany_of_list_zip binds rets l = Some l' /\ post t m l' (cost_call PreSpill mc'')) + | cmd.interact binds action arges => + exists args mc', dexprs m l arges mc (args, mc') /\ + exists mKeep mGive, map.split m mKeep mGive /\ + ext_spec t mGive action args (fun mReceive rets => + exists l', map.putmany_of_list_zip binds rets l = Some l' /\ + forall m', map.split m' mKeep mReceive -> + post (cons ((mGive, action, args), (mReceive, rets)) t) m' l' (cost_interact PreSpill mc')) + end. + + Fixpoint cmd c := cmd_body cmd c. + End WithFunctions. + + Definition func call '(innames, outnames, c) (t : trace) (m : mem) (args : list word) (mc : metrics) (post : trace -> mem -> list word -> metrics -> Prop) := +exists l, map.of_list_zip innames args = Some l /\ + cmd call c t m l mc (fun t m l mc => + list_map (get l) outnames mc (fun '(rets, _) => + post t m rets mc)). + + Definition program := cmd. + + (* + Definition call_body rec (functions : list (String.string * (list String.string * list String.string * cmd.cmd))) + (fname : String.string) (t : trace) (m : mem) (args : list word) (mc: metrics) + (post : trace -> mem -> list word -> metrics -> Prop) : Prop := + match functions with + | nil => False + | cons (f, decl) functions => + if String.eqb f fname + then func (rec functions) decl t m args mc post + else rec functions fname t m args mc post + end. + Fixpoint call functions := call_body call functions. + + Definition program funcs main t m l mc post : Prop := cmd (call funcs) main t m l mc post. *) + +End WeakestPrecondition. +Notation call := MetricSemantics.call (only parsing). + +Ltac unfold1_cmd e := + lazymatch e with + @cmd ?width ?BW ?word ?mem ?locals ?ext_spec ?CA ?c ?t ?m ?l ?mc ?post => + let c := eval hnf in c in + constr:(@cmd_body width BW word mem locals ext_spec CA + (@cmd width BW word mem locals ext_spec CA) c t m l mc post) + end. +Ltac unfold1_cmd_goal := + let G := lazymatch goal with |- ?G => G end in + let G := unfold1_cmd G in + change G. + +Ltac unfold1_expr e := + lazymatch e with + @expr ?width ?word ?mem ?locals ?m ?l ?arg ?mc ?post => + let arg := eval hnf in arg in + constr:(@expr_body width word mem locals m l (@expr width word mem locals m l) arg mc post) + end. +Ltac unfold1_expr_goal := + let G := lazymatch goal with |- ?G => G end in + let G := unfold1_expr G in + change G. + +Ltac unfold1_list_map e := + lazymatch e with + @list_map ?A ?B ?P ?arg ?mc ?post => + let arg := eval hnf in arg in + constr:(@list_map_body A B P (@list_map A B P) arg mc post) + end. +Ltac unfold1_list_map_goal := + let G := lazymatch goal with |- ?G => G end in + let G := unfold1_list_map G in + change G. + +(* +Ltac unfold1_call e := + lazymatch e with + @call ?width ?BW ?word ?mem ?locals ?ext_spec ?fs ?fname ?t ?m ?l ?mc ?post => + let fs := eval hnf in fs in + constr:(@call_body width BW word mem locals ext_spec + (@call width BW word mem locals ext_spec) fs fname t m l mc post) + end. +Ltac unfold1_call_goal := + let G := lazymatch goal with |- ?G => G end in + let G := unfold1_call G in + change G. + + *) + +Import Coq.ZArith.ZArith. + +Notation "'fnspec!' name a0 .. an '/' g0 .. gn '~>' r0 .. rn ',' '{' 'requires' tr mem mc := pre ';' 'ensures' tr' mem' mc' ':=' post '}'" := + (fun functions => + (forall a0, + .. (forall an, + (forall g0, + .. (forall gn, + (forall tr mem mc, + pre -> + MetricWeakestPrecondition.call + functions name tr mem (cons a0 .. (cons an nil) ..) mc + (fun tr' mem' rets mc' => + (exists r0, + .. (exists rn, + rets = (cons r0 .. (cons rn nil) ..) /\ + post) ..)))) ..)) ..)) + (at level 200, + name at level 0, + a0 binder, an binder, + g0 binder, gn binder, + r0 closed binder, rn closed binder, + tr name, tr' name, mem name, mem' name, mc name, mc' name, + pre at level 200, + post at level 200). + +Notation "'fnspec!' name a0 .. an '/' g0 .. gn ',' '{' 'requires' tr mem mc := pre ';' 'ensures' tr' mem' mc' ':=' post '}'" := + (fun functions => + (forall a0, + .. (forall an, + (forall g0, + .. (forall gn, + (forall tr mem mc, + pre -> + MetricWeakestPrecondition.call + functions name tr mem (cons a0 .. (cons an nil) ..) mc + (fun tr' mem' rets mc' => + rets = nil /\ post))) ..)) ..)) + (at level 200, + name at level 0, + a0 binder, an binder, + g0 binder, gn binder, + tr name, tr' name, mem name, mem' name, mc name, mc' name, + pre at level 200, + post at level 200). + +Notation "'fnspec!' name a0 .. an '~>' r0 .. rn ',' '{' 'requires' tr mem mc := pre ';' 'ensures' tr' mem' mc' ':=' post '}'" := + (fun functions => + (forall a0, + .. (forall an, + (forall tr mem mc, + pre -> + MetricWeakestPrecondition.call + functions name tr mem (cons a0 .. (cons an nil) ..) mc + (fun tr' mem' rets mc' => + (exists r0, + .. (exists rn, + rets = (cons r0 .. (cons rn nil) ..) /\ + post) ..)))) ..)) + (at level 200, + name at level 0, + a0 binder, an binder, + r0 closed binder, rn closed binder, + tr name, tr' name, mem name, mem' name, mc name, mc' name, + pre at level 200, + post at level 200). + +Notation "'fnspec!' name '/' g0 .. gn '~>' r0 .. rn ',' '{' 'requires' tr mem mc := pre ';' 'ensures' tr' mem' mc' ':=' post '}'" := + (fun functions => + (forall an, + (forall g0, + .. (forall gn, + (forall tr mem mc, + pre -> + MetricWeakestPrecondition.call + functions name tr mem nil mc + (fun tr' mem' rets mc' => + (exists r0, + .. (exists rn, + rets = (cons r0 .. (cons rn nil) ..) /\ + post) ..)))) ..))) + (at level 200, + name at level 0, + g0 binder, gn binder, + r0 closed binder, rn closed binder, + tr name, tr' name, mem name, mem' name, mc name, mc' name, + pre at level 200, + post at level 200). + +Notation "'fnspec!' name a0 .. an ',' '{' 'requires' tr mem mc := pre ';' 'ensures' tr' mem' mc' ':=' post '}'" := + (fun functions => + (forall a0, + .. (forall an, + (forall tr mem mc, + pre -> + MetricWeakestPrecondition.call + functions name tr mem (cons a0 .. (cons an nil) ..) mc + (fun tr' mem' rets mc' => + rets = nil /\ post))) ..)) + (at level 200, + name at level 0, + a0 binder, an binder, + tr name, tr' name, mem name, mem' name, mc name, mc' name, + pre at level 200, + post at level 200). + +Notation "'fnspec!' name '/' g0 .. gn ',' '{' 'requires' tr mem mc := pre ';' 'ensures' tr' mem' mc' ':=' post '}'" := + (fun functions => + (forall g0, + .. (forall gn, + (forall tr mem mc, + pre -> + MetricWeakestPrecondition.call + functions name tr mem nil mc + (fun tr' mem' rets mc' => + rets = nil /\ post))) ..)) + (at level 200, + name at level 0, + g0 binder, gn binder, + tr name, tr' name, mem name, mem' name, mc name, mc' name, + pre at level 200, + post at level 200). + +Notation "'fnspec!' name '~>' r0 .. rn ',' '{' 'requires' tr mem mc := pre ';' 'ensures' tr' mem' mc' ':=' post '}'" := + (fun functions => + (forall tr mem mc, + pre -> + MetricWeakestPrecondition.call + functions name tr mem nil mc + (fun tr' mem' rets mc' => + (exists r0, + .. (exists rn, + rets = (cons r0 .. (cons rn nil) ..) /\ + post) ..)))) + (at level 200, + name at level 0, + r0 closed binder, rn closed binder, + tr name, tr' name, mem name, mem' name, mc name, mc' name, + pre at level 200, + post at level 200). diff --git a/bedrock2/src/bedrock2/MetricWeakestPreconditionProperties.v b/bedrock2/src/bedrock2/MetricWeakestPreconditionProperties.v new file mode 100644 index 000000000..b0e16f67e --- /dev/null +++ b/bedrock2/src/bedrock2/MetricWeakestPreconditionProperties.v @@ -0,0 +1,378 @@ +Require Import coqutil.Macros.subst coqutil.Macros.unique coqutil.Map.Interface coqutil.Word.Properties. +Require Import coqutil.Word.Bitwidth. +Require Import bedrock2.MetricLogging. +Require bedrock2.MetricWeakestPrecondition. + +Require Import Coq.Classes.Morphisms. + +Section MetricWeakestPrecondition. + Context {width} {BW: Bitwidth width} {word: word.word width} {mem: map.map word Byte.byte}. + Context {locals: map.map String.string word}. + Context {ext_spec: Semantics.ExtSpec}. + + Ltac ind_on X := + intros; + (* Note: Comment below dates from when we were using a parameter record p *) + (* Note: "before p" means actually "after p" when reading from top to bottom, because, + as the manual points out, "before" and "after" are with respect to the direction of + the move, and we're moving hypotheses upwards here. + We need to make sure not to revert/clear p, because the other lemmas depend on it. + If we still reverted/cleared p, we'd get errors like + "Error: Proper_load depends on the variable p which is not declared in the context." + when trying to use Proper_load, or, due to COQBUG https://github.com/coq/coq/issues/11487, + we'd get a typechecking failure at Qed time. *) + repeat match goal with x : ?T |- _ => first + [ constr_eq T X; move x before ext_spec + | constr_eq T X; move x before locals + | constr_eq T X; move x at top + | revert x ] end; + match goal with x : X |- _ => induction x end; + intros. + + Local Hint Mode word.word - : typeclass_instances. + + (* we prove weakening lemmas for all WP definitions in a syntax-directed fashion, + * moving from postcondition towards precondition one logical connective at a time. *) + Global Instance Proper_literal : Proper (pointwise_relation _ (pointwise_relation _ ((pointwise_relation _ Basics.impl) ==> Basics.impl))) MetricWeakestPrecondition.literal. + Proof using. clear. cbv [MetricWeakestPrecondition.literal]; cbv [Proper respectful pointwise_relation Basics.impl dlet.dlet]. eauto. Qed. + + Global Instance Proper_get : Proper (pointwise_relation _ (pointwise_relation _ (pointwise_relation _ ((pointwise_relation _ Basics.impl) ==> Basics.impl)))) MetricWeakestPrecondition.get. + Proof using. clear. cbv [MetricWeakestPrecondition.get]; cbv [Proper respectful pointwise_relation Basics.impl]; intros * ? (?&?&?); eauto. Qed. + + Global Instance Proper_load : Proper (pointwise_relation _ (pointwise_relation _ (pointwise_relation _ (pointwise_relation _ ((pointwise_relation _ Basics.impl) ==> Basics.impl))))) MetricWeakestPrecondition.load. + Proof using. clear. cbv [MetricWeakestPrecondition.load]; cbv [Proper respectful pointwise_relation Basics.impl]; intros * ? (?&?&?); eauto. Qed. + + Global Instance Proper_store : Proper (pointwise_relation _ (pointwise_relation _ (pointwise_relation _ (pointwise_relation _ ((pointwise_relation _ Basics.impl) ==> Basics.impl))))) MetricWeakestPrecondition.store. + Proof using. clear. cbv [MetricWeakestPrecondition.store]; cbv [Proper respectful pointwise_relation Basics.impl]; intros * ? (?&?&?); eauto. Qed. + + Global Instance Proper_expr : Proper (pointwise_relation _ (pointwise_relation _ (pointwise_relation _ (pointwise_relation _ ((pointwise_relation _ Basics.impl) ==> Basics.impl))))) MetricWeakestPrecondition.expr. + Proof using. + clear. + cbv [Proper respectful pointwise_relation Basics.impl]; ind_on Syntax.expr.expr; + cbn in *; intuition (try typeclasses eauto with core). + { eapply Proper_literal; eauto. } + { eapply Proper_get; eauto. } + { eapply IHa1; eauto; intuition idtac. destruct a4. eapply Proper_load; eauto using Proper_load. } + { eapply IHa1; eauto; intuition idtac. destruct a4. eapply Proper_load; eauto using Proper_load. } + { eapply IHa1_1; eauto. destruct a1. eapply IHa1_2; eauto. destruct a1. eauto. } + {eapply IHa1_1; eauto; intuition idtac. destruct a1. Tactics.destruct_one_match; eauto using Proper_load . } + Qed. + + Global Instance Proper_list_map {A B} : + Proper ((pointwise_relation _ (pointwise_relation _ (pointwise_relation _ Basics.impl ==> Basics.impl))) ==> pointwise_relation _ (pointwise_relation _ (pointwise_relation _ Basics.impl ==> Basics.impl))) (MetricWeakestPrecondition.list_map (A:=A) (B:=B)). + Proof using. + clear. + cbv [Proper respectful pointwise_relation Basics.impl]; ind_on (list A); + cbn in *; intuition (try typeclasses eauto with core). + eapply H; eauto. destruct a2. eapply IHa; eauto. destruct a2; eauto. + Qed. + + Context {word_ok : word.ok word} {mem_ok : map.ok mem}. + Context {locals_ok : map.ok locals}. + Context {ext_spec_ok : Semantics.ext_spec.ok ext_spec}. + + Global Instance Proper_cmd : + Proper + (pointwise_relation _ ( + pointwise_relation _ ( + pointwise_relation _ ( + pointwise_relation _ ( + pointwise_relation _ ( + pointwise_relation _ ( + (pointwise_relation _ (pointwise_relation _ (pointwise_relation _ (pointwise_relation _ Basics.impl)))) ==> + Basics.impl))))))) MetricWeakestPrecondition.cmd. + Proof using ext_spec_ok locals_ok mem_ok word_ok. + pose proof I. (* to keep naming *) + cbv [Proper respectful pointwise_relation Basics.flip Basics.impl]; ind_on Syntax.cmd.cmd; + cbn in *; cbv [dlet.dlet] in *; intuition (try typeclasses eauto with core). + { destruct H1 as (?&?&?&?). eexists. eexists. split. + 1: eapply Proper_expr. + 1: cbv [pointwise_relation Basics.impl]; intuition eauto 2. + all: eauto. } + { destruct H1 as (?&?&?&?). eexists. eexists. split. + { eapply Proper_expr. + { cbv [pointwise_relation Basics.impl]; intuition eauto 2. } + { eauto. } } + { destruct H2 as (?&?&?&?). eexists. eexists. split. + { eapply Proper_expr. + { cbv [pointwise_relation Basics.impl]; intuition eauto 2. } + { eauto. } } + { eapply Proper_store; eauto; cbv [pointwise_relation Basics.impl]; eauto. } } } + + { eapply H1. 2: eapply H3; eassumption. intros ? ? ? ? (?&?&?&?&?). eauto 7. } + { destruct H1 as (?&?&?&?). eexists. eexists. split. + { eapply Proper_expr. + { cbv [pointwise_relation Basics.impl]; intuition eauto 2. } + { eauto. } } + { intuition eauto 6. } } + { eapply MetricSemantics.exec.weaken; eassumption. } + { destruct H1 as (?&?&?&?). eexists. eexists. split. + { eapply Proper_list_map; eauto; try exact H4; cbv [respectful pointwise_relation Basics.impl]; intuition eauto 2. + eapply Proper_expr; eauto. } + { eapply MetricSemantics.weaken_call. 1: eassumption. cbv beta. + (* COQBUG (performance), measured in Coq 8.9: + "firstorder eauto" works, but takes ~100s and increases memory usage by 1.8GB. + On the other hand, the line below takes just 5ms *) + cbv beta; intros ? ? ? ? (?&?&?); eauto. } } + { destruct H1 as (?&?&?&?). eexists. eexists. split. + { eapply Proper_list_map; eauto; try exact H4; cbv [respectful pointwise_relation Basics.impl]. + { eapply Proper_expr; eauto. } + { eauto. } } + { destruct H2 as (mKeep & mGive & ? & ?). + exists mKeep. exists mGive. + split; [assumption|]. + eapply Semantics.ext_spec.weaken; [|solve[eassumption]]. + intros ? ? (?&?&?); eauto 10. } } + Qed. + + Global Instance Proper_call : + Proper ( + (pointwise_relation _ ( + (pointwise_relation _ ( + (pointwise_relation _ ( + pointwise_relation _ ( + pointwise_relation _ ( + pointwise_relation _ ( + (pointwise_relation _ (pointwise_relation _ (pointwise_relation _ (pointwise_relation _ Basics.impl)))) ==> + Basics.impl)))))))))) MetricWeakestPrecondition.call. + Proof using word_ok mem_ok locals_ok ext_spec_ok. + cbv [Proper respectful pointwise_relation Basics.impl]. + intros. eapply MetricSemantics.weaken_call; eassumption. + Qed. + +Global Instance Proper_program : + Proper ( + pointwise_relation _ ( + pointwise_relation _ ( + pointwise_relation _ ( + pointwise_relation _ ( + pointwise_relation _ ( + pointwise_relation _ ( + (pointwise_relation _ (pointwise_relation _ (pointwise_relation _ (pointwise_relation _ Basics.impl)))) ==> + Basics.impl))))))) MetricWeakestPrecondition.program. + Proof using word_ok mem_ok locals_ok ext_spec_ok. + cbv [Proper respectful pointwise_relation Basics.impl MetricWeakestPrecondition.program]; intros. + eapply Proper_cmd; + cbv [Proper respectful pointwise_relation Basics.flip Basics.impl MetricWeakestPrecondition.func]; + try solve [typeclasses eauto with core]. + Qed. + + From coqutil Require Import Datatypes.Prod. + Ltac t := + repeat match goal with + | _ => progress inversion_prod + | |- forall _, _ => progress intros + | H: exists _, _ |- _ => destruct H + | H: and _ _ |- _ => destruct H + | H: eq _ ?y |- _ => subst y + | H: False |- _ => destruct H + | _ => progress cbn in * + | _ => progress cbv [dlet.dlet MetricWeakestPrecondition.dexpr MetricWeakestPrecondition.dexprs MetricWeakestPrecondition.store] in * + end; eauto. + +Lemma expr_sound m l e mc post (H : MetricWeakestPrecondition.expr m l e mc post) + : exists v mc', MetricSemantics.eval_expr m l e mc = Some (v, mc') /\ post (v, mc'). +Proof using BW ext_spec ext_spec_ok locals +locals_ok mem mem_ok width word word_ok. + ind_on Syntax.expr; t. { destruct H. destruct H. eexists. eexists. rewrite H. eauto. } + { eapply IHe in H; t. cbv [MetricWeakestPrecondition.load] in H0; t. rewrite H. rewrite H0. eauto. } + { eapply IHe in H; t. cbv [MetricWeakestPrecondition.load] in H0; t. rewrite H. rewrite H0. + eexists. eexists. split; eauto. } + { eapply IHe1 in H; t. eapply IHe2 in H0; t. rewrite H, H0; eauto. } + { eapply IHe1 in H; t. rewrite H. Tactics.destruct_one_match. + { apply IHe3; t. } + { eapply IHe2 in H0; t. } } + Qed. + + Import ZArith coqutil.Tactics.Tactics. + + Lemma expr_complete: forall m l e mc v mc', + MetricSemantics.eval_expr m l e mc = Some (v, mc') -> + MetricWeakestPrecondition.dexpr m l e mc (v, mc'). + Proof using word_ok. + induction e; cbn; intros. + - inversion_clear H. reflexivity. + - eexists; eexists; destruct (map.get l x); try inversion H; try reflexivity. + - repeat (destruct_one_match_hyp; try discriminate; []). + eapply Proper_expr. + 2: { eapply IHe. rewrite E. reflexivity. } + intros (addr, oldmc) ?. apply pair_equal_spec in H0; destruct H0. + subst r m0. unfold MetricWeakestPrecondition.load. eexists; split; eauto. + apply Option.eq_of_eq_Some in H. auto. + - repeat (destruct_one_match_hyp; try discriminate; []). + eapply Proper_expr. + 2: { eapply IHe. rewrite E. reflexivity. } + intros (addr, oldmc) ?. apply pair_equal_spec in H0; destruct H0. + subst r m0. unfold MetricWeakestPrecondition.load. eexists; split; eauto. + apply Option.eq_of_eq_Some in H; auto. + - repeat (destruct_one_match_hyp; try discriminate; []). + eapply Proper_expr. + 2: { eapply IHe1. rewrite E. reflexivity. } + intros (v1, oldmc1) ?. apply pair_equal_spec in H0; destruct H0. + subst r m0. + eapply Proper_expr. + 2: { eapply IHe2. rewrite E0. reflexivity. } + intros (v2, oldmc2) ?. apply pair_equal_spec in H0; destruct H0. + subst r0 m1. congruence. + - repeat (destruct_one_match_hyp; try discriminate; []). + eapply Proper_expr. + 2: { eapply IHe1. rewrite E. reflexivity. } + intros (vc, oldmc) ?. apply pair_equal_spec in H0; destruct H0. + subst r m0. + destr (word.eqb vc (word.of_Z 0)). + + eapply IHe3. eassumption. + + eapply IHe2. eassumption. + Qed. + + +Lemma sound_args : forall m l args mc P, + MetricWeakestPrecondition.list_map (MetricWeakestPrecondition.expr m l) args mc P -> + exists x mc', MetricSemantics.eval_call_args m l args mc = Some (x, mc') /\ P (x, mc'). +Proof using BW ext_spec ext_spec_ok locals locals_ok mem mem_ok +width word word_ok. + induction args; cbn; repeat (subst; t). + eapply expr_sound in H; t; rewrite H. + eapply IHargs in H0; t; rewrite H0. + eauto. + Qed. + + Lemma sound_getmany l a mc P : + MetricWeakestPrecondition.list_map (MetricWeakestPrecondition.get l) a mc P + -> exists vs mc', map.getmany_of_list l a = Some vs /\ P (vs, mc'). + Proof. + cbv [map.getmany_of_list] in *. + revert P l mc; induction a; cbn; repeat (subst; t). + cbv [MetricWeakestPrecondition.get] in H; t. + epose proof (IHa _ l _ _); clear IHa; t. + rewrite H. erewrite H1. eexists; eexists; split; eauto. + Unshelve. + 3: exact H0. + all: cbv [respectful pointwise_relation Basics.impl MetricWeakestPrecondition.get]; intros; cbv beta; t. + Qed. + + Local Hint Constructors MetricSemantics.exec : core. + Lemma sound_cmd e c t m l mc post (H: MetricWeakestPrecondition.cmd e c t m l mc post) : MetricSemantics.exec e c t m l mc post. + Proof. + ind_on Syntax.cmd; repeat (t; try match reverse goal with H: MetricWeakestPrecondition.expr _ _ _ _ _ |- _ => eapply expr_sound in H end). + { destruct (BinInt.Z.eq_dec (word.unsigned x) 0) as [|]; t. } + { inversion H0. t. eapply sound_args in H; t. } + { eapply sound_args in H; t. } + Qed. + + Lemma weaken_cmd: forall e c t m l mc (post1 post2: _->_->_->_->Prop), + MetricWeakestPrecondition.cmd e c t m l mc post1 -> + (forall t m l mc, post1 t m l mc -> post2 t m l mc) -> + MetricWeakestPrecondition.cmd e c t m l mc post2. + Proof. + intros. + eapply Proper_cmd. 2: eassumption. + cbv [RelationClasses.Reflexive Morphisms.pointwise_relation + Morphisms.respectful Basics.impl]. + assumption. + Qed. + + Lemma complete_args : forall m l args mc vs, + MetricSemantics.eval_call_args m l args mc = Some vs -> + MetricWeakestPrecondition.dexprs m l args mc vs. + Proof using word_ok. + induction args; cbn; repeat (subst; t). + 1: inversion H; reflexivity. + destruct_one_match_hyp. 2: discriminate. + destruct_one_match_hyp. + destruct_one_match_hyp. 2: discriminate. + case p in *; inversion_clear H. + eapply Proper_expr. 2: eapply expr_complete. 2: eassumption. + intros x ?. subst x. + eapply Proper_list_map. 3: { eapply IHargs. eassumption. } + { eapply Proper_expr. } + { intros ? ?. subst. reflexivity. } + Qed. + + Lemma complete_cmd: forall e c t m l mc post, + MetricSemantics.exec e c t m l mc post -> + MetricWeakestPrecondition.cmd e c t m l mc post. + Proof. + induction 1. + { eassumption. } + { eapply expr_complete in H. eexists _, _. split. 1: exact H. + eassumption. } + { eauto. } + { eapply expr_complete in H. + eapply expr_complete in H0. + eexists _, _. split. 1: eassumption. + eexists _, _. split. 1: eassumption. + eexists. eauto. } + { split. 1: assumption. + intros * HA HSp. specialize H1 with (1 := HA) (2 := HSp). + unfold dlet.dlet. eapply weaken_cmd. 1: eapply H1. cbv beta. + clear. intros * (? & ? & ? & ? & ?). eauto 8. } + { eexists _, _; cbv[dlet.dlet]; ssplit; intros; eauto using expr_complete; congruence. } + { eexists _, _; cbv[dlet.dlet]; ssplit; intros; eauto using expr_complete; congruence. } + { cbn. eapply weaken_cmd. + { eapply IHexec. } + cbv beta. intros. + eapply H1. eassumption. } + { cbn. eapply MetricSemantics.exec.while_false; eauto. } + { rename IHexec into IH1, H3 into IH2. + cbn. eapply MetricSemantics.exec.while_true; eassumption. } + { cbn. eexists _, _. split. + { eapply complete_args. eassumption. } + unfold MetricSemantics.call. do 4 eexists. 1: eassumption. do 2 eexists. 1: eassumption. + eapply MetricSemantics.exec.weaken. + { eassumption. } + cbv beta. intros. + specialize H3 with (1 := H4). destruct H3 as (retvs & G & ? & ? & ?). eauto 8. } + { cbn. eexists _, _. split. + { eapply complete_args. eassumption. } + eexists _, _. split. 1: eassumption. + eapply Semantics.ext_spec.weaken. 2: eassumption. + intros m0 args0 Hmid. specialize H2 with (1 := Hmid). destruct H2 as (? & ? & ?). + eauto 8. } + Qed. + + Lemma start_func: forall e fname fimpl t m args mc post, + map.get e fname = Some fimpl -> + MetricWeakestPrecondition.func e fimpl t m args mc post -> + MetricWeakestPrecondition.call e fname t m args mc post. + Proof. + intros * G. destruct fimpl as [[argnames retnames] body]. intros (? & ? & ?). + do 4 eexists. 1: eassumption. do 2 eexists. 1: eassumption. eapply sound_cmd. + eapply weaken_cmd. 1: eassumption. cbv beta. intros. + edestruct sound_getmany as (?&?&?&?); eauto. + Qed. + + (** Ad-hoc lemmas here? *) + + Import bedrock2.Syntax bedrock2.MetricSemantics bedrock2.MetricWeakestPrecondition. + + Lemma interact_nomem call action binds arges t m l post + mc args mc' (Hargs : dexprs m l arges mc (args, mc')) + (Hext : ext_spec t map.empty binds args (fun mReceive (rets : list word) => + mReceive = map.empty /\ + exists l0 : locals, map.putmany_of_list_zip action rets l = Some l0 /\ + post (cons (map.empty, binds, args, (map.empty, rets)) t) m l0 (MetricCosts.cost_interact MetricCosts.PreSpill mc'))) + : MetricWeakestPrecondition.cmd call (cmd.interact action binds arges) t m l mc post. + Proof using word_ok mem_ok ext_spec_ok. + exists args, mc'; split; [exact Hargs|]. + exists m. + exists map.empty. + split; [eapply Properties.map.split_empty_r; exact eq_refl|]. + eapply Semantics.ext_spec.weaken; [|eapply Hext]; intros ? ? [? [? []]]. subst a; subst. + eexists; split; [eassumption|]. + intros. eapply Properties.map.split_empty_r in H. subst. assumption. + Qed. + + (* + Lemma intersect_expr: forall m l e mc (post1 post2: word * MetricLog -> Prop), + MetricWeakestPrecondition.expr m l e mc post1 -> + MetricWeakestPrecondition.expr m l e mc post2 -> + MetricWeakestPrecondition.expr m l e mc (fun v => post1 v /\ post2 v). + Proof using word_ok. Admitted. + + Lemma dexpr_expr (m : mem) l e mc P + (H : MetricWeakestPrecondition.expr m l e mc P) + : exists v, MetricWeakestPrecondition.dexpr m l e mc v /\ P v. + Proof using word_ok. Admitted. + *) + +End MetricWeakestPrecondition. diff --git a/bedrock2/src/bedrock2/Notations.v b/bedrock2/src/bedrock2/Notations.v new file mode 100644 index 000000000..c8176b6e2 --- /dev/null +++ b/bedrock2/src/bedrock2/Notations.v @@ -0,0 +1,21 @@ +Require Import coqutil.Macros.subst. + +Notation "' x <- a | y ; f" := + (match a with + | x => f + | _ => y + end) + (right associativity, at level 70, x pattern). + +Notation "'bind_ex' x <- a ; f" := + (subst! a for a' in exists x, a' x /\ f) + (only parsing, right associativity, at level 60, f at level 200). +Notation "'bind_ex_Some' x <- a ; f" := + (subst! a for a' in exists x, a' = Some x /\ f) + (only parsing, right associativity, at level 60, f at level 200). +Notation "'bind_ex_pair' ( x , y ) <- a ; f" := + (subst! a for a' in exists x y, a' (x, y) /\ f) + (only parsing, right associativity, at level 60, f at level 200). +Notation "'bind_eq' x <- a ; f" := + (subst! a for a' in forall x, x = a' -> f) + (only parsing, right associativity, at level 60, f at level 200). diff --git a/bedrock2/src/bedrock2Examples/bsearch.v b/bedrock2/src/bedrock2Examples/bsearch.v index 2e452db74..32b5ffa52 100644 --- a/bedrock2/src/bedrock2Examples/bsearch.v +++ b/bedrock2/src/bedrock2Examples/bsearch.v @@ -112,7 +112,7 @@ Proof. repeat apply conj; auto; []. (* postcondition *) letexists. split. { exact eq_refl. } - { auto. } + { repeat straightline. } Unshelve. all: exact (word.of_Z 0). diff --git a/bedrock2/src/bedrock2Examples/metric_ipow.v b/bedrock2/src/bedrock2Examples/metric_ipow.v new file mode 100644 index 000000000..2c931529b --- /dev/null +++ b/bedrock2/src/bedrock2Examples/metric_ipow.v @@ -0,0 +1,204 @@ +Require Import Coq.ZArith.ZArith coqutil.Z.div_mod_to_equations. +Require Import bedrock2.NotationsCustomEntry. +Require Import bedrock2.MetricLogging. +Require Import bedrock2.MetricCosts. +Import Syntax BinInt String List.ListNotations ZArith. +Require Import coqutil.Z.Lia. +Local Open Scope string_scope. Local Open Scope Z_scope. Local Open Scope list_scope. + +Definition ipow := func! (x, e) ~> ret { + ret = $1; + while (e) { + if (e & $1) { ret = ret * x }; + e = e >> $1; + x = x * x + } +}. + +From bedrock2 Require Import BasicC64Semantics MetricWeakestPrecondition MetricProgramLogic. +From bedrock2 Require Import MetricLoops. +From coqutil Require Import Word.Properties Word.Interface Tactics.letexists. +Import Interface.word. + +Definition initCost := {| instructions := 12; stores := 2; loads := 13; jumps := 0 |}. +Definition iterCost := {| instructions := 76; stores := 16; loads := 98; jumps := 2 |}. +Definition endCost := {| instructions := 6; stores := 1; loads := 9; jumps := 1 |}. + +Definition msb z := match z with + | Zpos _ => Z.log2 z + 1 + | _ => 0 + end. + +#[export] Instance spec_of_ipow : spec_of "ipow" := + fnspec! "ipow" x e ~> v, + { requires t m mc := True; + ensures t' m' mc' := unsigned v = unsigned x ^ unsigned e mod 2^64 /\ + (mc' - mc <= initCost + (msb (word.unsigned e)) * iterCost + endCost)%metricsH + }. + +Module Z. + Lemma pow_mod x n m (Hnz: m <> 0) : (x mod m)^n mod m = x^n mod m. + Proof. + revert n. + eapply Z.order_induction_0; intros. + { intros ???; subst; split; auto. } + { rewrite 2Z.pow_0_r; trivial. } + { rewrite 2Z.pow_succ_r by trivial. + rewrite <-Z.mul_mod_idemp_r by trivial. + multimatch goal with H: _ |- _ => rewrite H end; + rewrite Z.mul_mod_idemp_l, Z.mul_mod_idemp_r; solve[trivial]. } + { rewrite 2Z.pow_neg_r; trivial. } + Qed. + + Lemma mod2_nonzero x : x mod 2 <> 0 -> x mod 2 = 1. + Proof. Z.div_mod_to_equations. blia. Qed. + + Lemma land_1_r x : Z.land x 1 = x mod 2. + Proof. + change 1 with (Z.ones 1) in *. + rewrite Z.land_ones in * by discriminate. + exact eq_refl. + Qed. +End Z. + +Require Import bedrock2.AbsintWordToZ coqutil.Z.Lia. + +Ltac t := + repeat match goal with x := _ |- _ => subst x end; + repeat match goal with |- context [word.unsigned ?e] => progress (idtac; let H := rbounded (word.unsigned e) in idtac) end; + repeat match goal with G: context [word.unsigned ?e] |- _ => progress (idtac; let H := rbounded (word.unsigned e) in idtac) end; + repeat match goal with |- context [word.unsigned ?e] => progress (idtac; let H := unsigned.zify_expr e in try rewrite H) end; + repeat match goal with G: context [word.unsigned ?e] |- _ => progress (idtac; let H := unsigned.zify_expr e in try rewrite H in G) end; + repeat match goal with H: absint_eq ?x ?x |- _ => clear H end; + cbv [absint_eq] in *. + +Lemma msb_shift z : 0 < z -> msb (z / 2) = msb z - 1. +Proof. + intro. + case (z / 2) eqn:Hdiv. + - enough (H1 : z = 1) by (rewrite H1; easy). + enough (z = z mod 2) by (Z.div_mod_to_equations; blia). + rewrite (Z.div_mod z 2) by blia. + rewrite Hdiv. + cbn. + rewrite Zmod_mod. + reflexivity. + - rewrite <- Z.div2_div in Hdiv. + rewrite (Zdiv2_odd_eqn z). + rewrite Hdiv. + rewrite <- Pos2Z.inj_mul. + case (Z.odd z); + [rewrite <- Pos2Z.inj_add | rewrite Z.add_0_r]; + unfold msb; + rewrite Z.add_simpl_r; + [rewrite Pos2Z.inj_add |]; rewrite Pos2Z.inj_mul; + [rewrite Z.log2_succ_double | rewrite Z.log2_double]; + blia. + - pose proof (Zlt_neg_0 p) as Hneg. + rewrite <- Hdiv in Hneg. + Z.div_mod_to_equations. + blia. +Qed. + +Ltac s := unfold initCost, iterCost, endCost in *; + cost_unfold; + cbn in *; + solve_MetricLog. + +Lemma ipow_ok : program_logic_goal_for_function! ipow. +Proof. + repeat straightline. + match goal with H : True |- _ => clear H end. + + refine ((MetricLoops.tailrec + (* types of ghost variables*) HList.polymorphic_list.nil + (* program variables *) (["e";"ret";"x"] : list String.string)) + (fun v t m e ret x mc => PrimitivePair.pair.mk (v = word.unsigned e) (* precondition *) + (fun T M E RET X MC => T = t /\ M = m /\ (* postcondition *) + word.unsigned RET = word.unsigned ret * word.unsigned x ^ word.unsigned e mod 2^64 /\ + (MC - mc <= msb (word.unsigned e) * iterCost + endCost)%metricsH)) + (fun n m => 0 <= n < m) (* well_founded relation *) + _ _ _ _ _); + (* TODO wrap this into a tactic with the previous refine *) + cbn [HList.hlist.foralls HList.tuple.foralls + HList.hlist.existss HList.tuple.existss + HList.hlist.apply HList.tuple.apply + HList.hlist + List.repeat Datatypes.length + HList.polymorphic_list.repeat HList.polymorphic_list.length + PrimitivePair.pair._1 PrimitivePair.pair._2] in *. + + { repeat straightline. } + { exact (Z.lt_wf _). } + { repeat straightline. } (* init precondition *) + { (* loop test *) + repeat straightline; try show_program. + { (* loop body *) + eexists; eexists; split; [repeat straightline|]. (* if condition evaluation *) + split. (* if cases, path-blasting *) + { + repeat (straightline || (split; trivial; [])). 2: split. all:t. + { (* measure decreases *) + set (word.unsigned x0) in *. (* WHY does blia need this? *) + Z.div_mod_to_equations. blia. } + { (* invariant preserved *) + rewrite H3; clear H3. rename H0 into Hbit. + change (1+1) with 2 in *. + eapply Z.mod2_nonzero in Hbit. + epose proof (Z.div_mod _ 2 ltac:(discriminate)) as Heq; rewrite Hbit in Heq. + rewrite Heq at 2; clear Hbit Heq. + (* rewriting with equivalence modulo ... *) + rewrite !word.unsigned_mul. + unfold word.wrap. + rewrite ?Z.mul_mod_idemp_l by discriminate. + rewrite <-(Z.mul_mod_idemp_r _ (_^_)), Z.pow_mod by discriminate. + rewrite ?Z.pow_add_r by (pose proof word.unsigned_range x0; Z.div_mod_to_equations; blia). + rewrite ?Z.pow_twice_r, ?Z.pow_1_r, ?Z.pow_mul_l. + rewrite Z.mul_mod_idemp_r by discriminate. + f_equal; ring. } + { (* metrics correct *) + rewrite msb_shift in H4 by blia. + rewrite MetricArith.mul_sub_distr_r in H4. + rewrite <- MetricArith.add_sub_swap in H4. + rewrite <- MetricArith.le_add_le_sub_r in H4. + eapply MetricArith.le_trans with (2 := H4). + s. + } + } + { + repeat (straightline || (split; trivial; [])). 2: split. all: t. + { (* measure decreases *) + set (word.unsigned x0) in *. (* WHY does blia need this? *) + Z.div_mod_to_equations; blia. } + { (* invariant preserved *) + rewrite H3; clear H3. rename H0 into Hbit. + change (1+1) with 2 in *. + epose proof (Z.div_mod _ 2 ltac:(discriminate)) as Heq; rewrite Hbit in Heq. + rewrite Heq at 2; clear Hbit Heq. + (* rewriting with equivalence modulo ... *) + rewrite !word.unsigned_mul, ?Z.mul_mod_idemp_l by discriminate. + cbv [word.wrap]. + rewrite <-(Z.mul_mod_idemp_r _ (_^_)), Z.pow_mod by discriminate. + rewrite ?Z.add_0_r, Z.pow_twice_r, ?Z.pow_1_r, ?Z.pow_mul_l. + rewrite Z.mul_mod_idemp_r by discriminate. + f_equal; ring. } + { (* metrics correct *) + rewrite msb_shift in H4 by blia. + s. + } + } + } + { (* postcondition *) + rewrite H, Z.pow_0_r, Z.mul_1_r, word.wrap_unsigned. + split; [reflexivity|]. + unfold msb; subst brmc. + s. + } + } + + repeat straightline. + + repeat (split || letexists || t || trivial). + { setoid_rewrite H1; setoid_rewrite Z.mul_1_l; trivial. } + all: s. +Qed. diff --git a/compiler/src/compiler/DeadCodeElim.v b/compiler/src/compiler/DeadCodeElim.v index db0ad51ef..eeb9c86ec 100644 --- a/compiler/src/compiler/DeadCodeElim.v +++ b/compiler/src/compiler/DeadCodeElim.v @@ -10,10 +10,12 @@ Require Import coqutil.Datatypes.ListSet. Local Notation var := String.string (only parsing). Require Import compiler.util.Common. Require Import bedrock2.MetricLogging. -Require Import coqutil.Tactics.fwd. +Require Import bedrock2.MetricCosts. (* below only for of_list_list_diff *) Require Import compiler.DeadCodeElimDef. +Local Notation exec := (exec PreSpill isRegStr). + Section WithArguments1. Context {width: Z}. Context {BW: Bitwidth.Bitwidth width }. @@ -22,7 +24,7 @@ Section WithArguments1. Context {mem: map.map word (Init.Byte.byte : Type) } {mem_ok : map.ok mem } . Context {locals: map.map string word } {locals_ok : map.ok locals }. Context {ext_spec : Semantics.ExtSpec } {ext_spec_ok: Semantics.ext_spec.ok ext_spec } . - + Lemma agree_on_put_existsb_false: forall used_after x (l: locals) lL, map.agree_on (diff (of_list used_after) (singleton_set x)) l lL @@ -36,7 +38,7 @@ Section WithArguments1. intros. propositional idtac. eapply existsb_of_list in H1. rewrite H1 in H0. - discriminate. + discriminate. Qed. Ltac subset_union_solve := @@ -71,7 +73,7 @@ Section WithArguments1. [ idtac | eapply H ]; subset_union_solve | H: map.agree_on ?s ?x ?y |- map.agree_on _ ?y ?x => - eapply agree_on_comm; agree_on_solve + eapply agree_on_comm; agree_on_solve | H: map.agree_on ?s ?mH ?mL, H1: map.putmany_of_list_zip ?lk ?lv ?mH = Some ?mH', H2: map.putmany_of_list_zip ?lk ?lv ?mL = Some ?mL' @@ -107,6 +109,7 @@ Section WithArguments1. rewrite ListSet.of_list_removeb end. + Ltac mcsolve := eexists; split; [|split; cycle 1; [eauto|FlatImp.scost_hammer]]; try assumption. Lemma dce_correct_aux : forall eH eL, @@ -115,7 +118,7 @@ Section WithArguments1. exec eH sH t m lH mcH postH -> forall used_after lL mcL, map.agree_on (of_list (live sH used_after)) lH lL -> - exec eL (dce sH used_after) t m lL mcL (compile_post used_after postH). + exec eL (dce sH used_after) t m lL mcL (compile_post mcH mcL used_after postH). Proof. induction 2; match goal with @@ -139,10 +142,9 @@ Section WithArguments1. * eapply H5. * intros. unfold compile_post. - exists l'. eexists. split. - -- agree_on_solve. repeat listset_to_set. - subset_union_solve. - -- eauto. + exists l'. mcsolve. + agree_on_solve. repeat listset_to_set. + subset_union_solve. - intros. eapply @exec.call; try solve [ eassumption ]. + unfold dce_functions, dce_function in *. @@ -151,37 +153,37 @@ Section WithArguments1. fwd. eassumption. + erewrite agree_on_getmany. * eapply H1. - * listset_to_set. agree_on_solve. + * listset_to_set. agree_on_solve. + eapply IHexec. eapply agree_on_refl. + intros. unfold compile_post in *. - fwd. eapply H4 in H6p1. fwd. + fwd. eapply H4 in H6p2. fwd. let Heq := fresh in - pose proof H6p1p1 as Heq; - eapply map.putmany_of_list_zip_sameLength, map.sameLength_putmany_of_list in H6p1p1. fwd. + pose proof H6p2p1 as Heq; + eapply map.putmany_of_list_zip_sameLength, map.sameLength_putmany_of_list in H6p2p1. fwd. exists retvs. eexists. repeat split. * erewrite agree_on_getmany. - -- eapply H6p1p0. + -- eapply H6p2p0. -- listset_to_set. agree_on_solve. - * eapply H6p1p1. - * do 2 eexists. split; [ | eassumption ]. + * eapply H6p2p1. + * eexists. mcsolve. agree_on_solve. repeat listset_to_set. subset_union_solve. - intros. - eapply agree_on_find in H3; fwd. + eapply agree_on_find in H3; fwd. destr (existsb (eqb x) used_after); fwd. + eapply @exec.load. - * rewrite <- H3p1. eassumption. + * rewrite <- H3p1. eassumption. * eauto. * unfold compile_post. - exists (map.put l x v); eexists; split; [ | eassumption ]. + exists (map.put l x v); mcsolve. repeat listset_to_set. agree_on_solve. + eapply @exec.skip. * unfold compile_post. - exists (map.put l x v); eexists; split; [ | eassumption ]. + exists (map.put l x v); mcsolve. repeat listset_to_set. agree_on_solve. - intros. repeat listset_to_set. @@ -193,17 +195,17 @@ Section WithArguments1. + erewrite <- H4p0; eauto. unfold elem_of; destr (a =? v)%string; [ eapply in_eq | eapply in_cons, in_eq ]. + eassumption. - + unfold compile_post. exists l; eexists; split; eassumption. + + unfold compile_post. exists l; mcsolve. - intros. eapply agree_on_find in H4; fwd. destr (existsb (eqb x) used_after); fwd. + eapply @exec.inlinetable; eauto. - * rewrite <- H4p1. eassumption. - * unfold compile_post; do 2 eexists; split ; [ | eassumption ]. + * rewrite <- H4p1. eassumption. + * unfold compile_post; eexists; mcsolve. repeat listset_to_set; agree_on_solve. + eapply @exec.skip; eauto. unfold compile_post. - do 2 eexists; split; [ | eassumption ]. + eexists; mcsolve. repeat listset_to_set; agree_on_solve. - intros. repeat listset_to_set. @@ -218,17 +220,17 @@ Section WithArguments1. ++ eassumption. ++ split. ** eassumption. - ** do 2 eexists; split; [ eassumption | eapply H6p1p2 ]. + ** eexists; mcsolve; eauto. - intros. destr (existsb (eqb x) used_after). + eapply @exec.lit. unfold compile_post. repeat listset_to_set. - do 2 eexists; split; [ | eassumption ]. + eexists; mcsolve. agree_on_solve. + eapply @exec.skip. unfold compile_post. repeat listset_to_set. - do 2 eexists; split; [ | eassumption ]. + eexists; mcsolve. agree_on_solve. - destr z. + intros. repeat listset_to_set. @@ -246,23 +248,23 @@ Section WithArguments1. ++ eapply in_eq. ++ eapply in_cons, in_eq. -- unfold compile_post. - do 2 eexists; split; [ | eassumption ]. + eexists; mcsolve. agree_on_solve. * eapply @exec.skip. unfold compile_post. - do 2 eexists; split; [ | eassumption ]. + eexists; mcsolve. agree_on_solve. + intros. - eapply agree_on_find in H3; fwd. + eapply agree_on_find in H3; fwd. destr (existsb (eqb x) used_after). * eapply @exec.op. - -- rewrite <- H3p1. eassumption. + -- rewrite <- H3p1. eassumption. -- simpl. constructor. - -- unfold compile_post. simpl in *. inversion H1. fwd. do 2 eexists; split; [ | eassumption ]. + -- unfold compile_post. simpl in *. inversion H1. fwd. eexists; mcsolve. repeat listset_to_set. agree_on_solve. * eapply @exec.skip. unfold compile_post. - do 2 eexists; split ; [ | eassumption ]. + eexists; mcsolve. repeat listset_to_set. agree_on_solve. - intros. @@ -270,12 +272,12 @@ Section WithArguments1. repeat listset_to_set. destr (existsb (eqb x) used_after). { eapply @exec.set. - - rewrite <- H2p1; eassumption. - - unfold compile_post. do 2 eexists; split; [ | eassumption ]. + - rewrite <- H2p1; eassumption. + - unfold compile_post. eexists; mcsolve. agree_on_solve. } { eapply @exec.skip. - - unfold compile_post. do 2 eexists; split; [ | eassumption ]. + - unfold compile_post. eexists; mcsolve. agree_on_solve. } - intros. @@ -285,7 +287,10 @@ Section WithArguments1. eapply @exec.if_true. + erewrite agree_on_eval_bcond; [ eassumption | ]. pose agree_on_comm; eauto. - + eauto. + + eapply @exec.weaken; [eauto|]. + unfold compile_post. + intros * (?&?&?&?&?). + eexists. mcsolve. - intros. repeat listset_to_set. eapply agree_on_union in H2; fwd. @@ -293,14 +298,17 @@ Section WithArguments1. eapply @exec.if_false. + erewrite agree_on_eval_bcond; [ eassumption | ]. pose agree_on_comm; eauto. - + eauto. + + eapply @exec.weaken; [eauto|]. + unfold compile_post. + intros * (?&?&?&?&?). + eexists. mcsolve. - intros. cbn - [live]. rename IHexec into IH1. rename H6 into IH12. rename H4 into IH2. cbn - [live] in IH12. - eapply @exec.loop with (mid2 := compile_post (live (SLoop body1 cond body2) used_after) mid2). + eapply @exec.loop with (mid2 := compile_post mc mcL (live (SLoop body1 cond body2) used_after) mid2). { eapply IH1. eapply agree_on_subset. - let Heq := fresh in @@ -310,8 +318,8 @@ Section WithArguments1. } { intros. unfold compile_post in *. - repeat destr H4. - eapply H1 in H6. + repeat destr H4. destr H6. + eapply H1 in H8. erewrite agree_on_eval_bcond; [ eassumption | ]. eapply agree_on_comm. repeat listset_to_set. @@ -322,15 +330,12 @@ Section WithArguments1. } { intros. unfold compile_post in *. - repeat destr H4. - eapply H2 in H8. - - exists x. - eexists. - split. - + repeat listset_to_set. - eapply agree_on_subset; [ | eapply H4 ]. - subset_union_solve. - + eapply H8. + repeat destr H4. destr H8. + eapply H2 in H9. + - exists x. mcsolve. + repeat listset_to_set. + eapply agree_on_subset; [ | eapply H4 ]. + subset_union_solve. - erewrite agree_on_eval_bcond; [ eassumption | ]. repeat listset_to_set. eapply agree_on_subset; [ | eapply H4 ]. @@ -339,24 +344,30 @@ Section WithArguments1. { intros. unfold compile_post in *. - repeat destr H4. - eapply IH2. - - eapply H8. - - erewrite agree_on_eval_bcond; [ eassumption | ]. + repeat destr H4. destr H8. + assert (eval_bcond x cond = Some true) as Hbcond. + { erewrite agree_on_eval_bcond; [ eassumption | ]. repeat listset_to_set. eapply agree_on_subset; [ | eapply H4 ]. - subset_union_solve. + subset_union_solve. } + eapply @exec.weaken; [eapply IH2|]. + - eapply H9. + - exact Hbcond. - repeat listset_to_set. eapply agree_on_subset; [ | eapply H4 ]. subset_union_solve. + - cbv beta. intros * (?&?&?&?&?). + eexists. mcsolve. } { intros. unfold compile_post in *. - repeat destr H4. - eapply IH12. - - eapply H6. + repeat destr H4. destr H6. + eapply @exec.weaken; [eapply IH12|]. + - eapply H8. - eapply H4. + - cbv beta. intros * (?&?&?&?&?). + eexists. mcsolve. } - intros. eapply @exec.seq. @@ -367,9 +378,9 @@ Section WithArguments1. -- eassumption. -- eassumption. * unfold compile_post. intros. fwd. - do 2 eexists; split; eassumption. + eexists. mcsolve. - intros. eapply @exec.skip. - unfold compile_post. do 2 eexists; split; eassumption. + unfold compile_post. eexists. mcsolve. Qed. End WithArguments1. diff --git a/compiler/src/compiler/DeadCodeElimDef.v b/compiler/src/compiler/DeadCodeElimDef.v index 7ebf13a81..5e905c4c1 100644 --- a/compiler/src/compiler/DeadCodeElimDef.v +++ b/compiler/src/compiler/DeadCodeElimDef.v @@ -518,7 +518,7 @@ Section WithArguments1. Definition compile_post - used_after + mcH mcL used_after (postH: Semantics.trace -> mem -> locals -> MetricLog -> Prop) : Semantics.trace -> mem -> locals -> MetricLog -> Prop @@ -526,6 +526,7 @@ Section WithArguments1. (fun t' m' lL' mcL' => exists lH' mcH', map.agree_on (PropSet.of_list used_after) lH' lL' + /\ metricsLeq (mcL' - mcL) (mcH' - mcH) /\ postH t' m' lH' mcH'). Lemma agree_on_eval_bcond: diff --git a/compiler/src/compiler/FlatImp.v b/compiler/src/compiler/FlatImp.v index 72914d2ab..403c4443a 100644 --- a/compiler/src/compiler/FlatImp.v +++ b/compiler/src/compiler/FlatImp.v @@ -2,6 +2,7 @@ Require Import Coq.Bool.Bool. Require Import Coq.ZArith.ZArith. Require Import Coq.Lists.List. Import ListNotations. Require Import bedrock2.MetricLogging. +Require Import bedrock2.MetricCosts. Require Import coqutil.Macros.unique. Require Import bedrock2.Memory. Require Import compiler.util.Common. @@ -280,6 +281,9 @@ Module exec. {env_ok: map.ok env} {ext_spec_ok: ext_spec.ok ext_spec}. + Variable (phase: compphase). + Variable (isReg: varname -> bool). + Variable (e: env). Local Notation metrics := MetricLog. @@ -293,6 +297,30 @@ Module exec. | Const co => Some (word.of_Z co) end. + (* Helper functions for computing costs of instructions *) + + Definition cost_SOp x y z mc := + cost_op (fun v => match v with | Var vo => isReg vo | Const _ => true end) + (Var x) (Var y) z mc. + + Definition cost_SIf bcond mc := + match bcond with + | CondBinary _ x y => cost_if isReg x (Some y) + | CondNez x => cost_if isReg x None + end mc. + + Definition cost_SLoop_true bcond mc := + match bcond with + | CondBinary _ x y => cost_loop_true isReg x (Some y) + | CondNez x => cost_loop_true isReg x None + end mc. + + Definition cost_SLoop_false bcond mc := + match bcond with + | CondBinary _ x y => cost_loop_false isReg x (Some y) + | CondNez x => cost_loop_false isReg x None + end mc. + (* alternative semantics which allow non-determinism *) Inductive exec: stmt varname -> @@ -308,92 +336,69 @@ Module exec. exists l', map.putmany_of_list_zip resvars resvals l = Some l' /\ forall m', map.split m' mKeep mReceive -> post (((mGive, action, argvals), (mReceive, resvals)) :: t) m' l' - (addMetricInstructions 1 - (addMetricStores 1 - (addMetricLoads 2 mc)))) -> + (cost_interact phase mc)) -> exec (SInteract resvars action argvars) t m l mc post | call: forall t m l mc binds fname args params rets fbody argvs st0 post outcome, map.get e fname = Some (params, rets, fbody) -> map.getmany_of_list l args = Some argvs -> map.putmany_of_list_zip params argvs map.empty = Some st0 -> - exec fbody t m st0 (addMetricInstructions 100 (addMetricJumps 100 (addMetricLoads 100 (addMetricStores 100 mc)))) outcome -> + exec fbody t m st0 mc outcome -> (forall t' m' mc' st1, outcome t' m' st1 mc' -> exists retvs l', map.getmany_of_list st1 rets = Some retvs /\ map.putmany_of_list_zip binds retvs l = Some l' /\ - post t' m' l' (addMetricInstructions 100 (addMetricJumps 100 (addMetricLoads 100 (addMetricStores 100 mc'))))) -> + post t' m' l' (cost_call phase mc')) -> exec (SCall binds fname args) t m l mc post - (* TODO think about a non-fixed bound on the cost of function preamble and postamble *) | load: forall t m l mc sz x a o v addr post, map.get l a = Some addr -> load sz m (word.add addr (word.of_Z o)) = Some v -> - post t m (map.put l x v) - (addMetricLoads 2 - (addMetricInstructions 1 mc)) -> + post t m (map.put l x v) (cost_load isReg x a mc)-> exec (SLoad sz x a o) t m l mc post | store: forall t m m' mc l sz a o addr v val post, map.get l a = Some addr -> map.get l v = Some val -> store sz m (word.add addr (word.of_Z o)) val = Some m' -> - post t m' l - (addMetricLoads 1 - (addMetricInstructions 1 - (addMetricStores 1 mc))) -> + post t m' l (cost_store isReg a v mc) -> exec (SStore sz a v o) t m l mc post | inlinetable: forall sz x table i v index t m l mc post, (* compiled riscv code uses x as a tmp register and this shouldn't overwrite i *) x <> i -> map.get l i = Some index -> load sz (map.of_list_word table) index = Some v -> - post t m (map.put l x v) - (addMetricLoads 4 - (addMetricInstructions 3 - (addMetricJumps 1 mc))) -> + post t m (map.put l x v) (cost_inlinetable isReg x i mc) -> exec (SInlinetable sz x table i) t m l mc post | stackalloc: forall t mSmall l mc x n body post, n mod (bytes_per_word width) = 0 -> (forall a mStack mCombined, anybytes a n mStack -> map.split mCombined mSmall mStack -> - exec body t mCombined (map.put l x a) (addMetricLoads 1 (addMetricInstructions 1 mc)) + exec body t mCombined (map.put l x a) mc (fun t' mCombined' l' mc' => exists mSmall' mStack', anybytes a n mStack' /\ map.split mCombined' mSmall' mStack' /\ - post t' mSmall' l' mc')) -> + post t' mSmall' l' (cost_stackalloc isReg x mc'))) -> exec (SStackalloc x n body) t mSmall l mc post | lit: forall t m l mc x v post, - post t m (map.put l x (word.of_Z v)) - (addMetricLoads 8 - (addMetricInstructions 8 mc)) -> + post t m (map.put l x (word.of_Z v)) (cost_lit isReg x mc) -> exec (SLit x v) t m l mc post | op: forall t m l mc x op y y' z z' post, map.get l y = Some y' -> lookup_op_locals l z = Some z' -> - post t m (map.put l x (interp_binop op y' z')) - (addMetricLoads 2 - (addMetricInstructions 2 mc)) -> + post t m (map.put l x (interp_binop op y' z')) (cost_SOp x y z mc) -> exec (SOp x op y z) t m l mc post | set: forall t m l mc x y y' post, map.get l y = Some y' -> - post t m (map.put l x y') - (addMetricLoads 1 - (addMetricInstructions 1 mc)) -> + post t m (map.put l x y') (cost_set isReg x y mc) -> exec (SSet x y) t m l mc post | if_true: forall t m l mc cond bThen bElse post, eval_bcond l cond = Some true -> - exec bThen t m l - (addMetricLoads 2 - (addMetricInstructions 2 - (addMetricJumps 1 mc))) post -> + exec bThen t m l (cost_SIf cond mc) post -> exec (SIf cond bThen bElse) t m l mc post | if_false: forall t m l mc cond bThen bElse post, eval_bcond l cond = Some false -> - exec bElse t m l - (addMetricLoads 2 - (addMetricInstructions 2 - (addMetricJumps 1 mc))) post -> + exec bElse t m l (cost_SIf cond mc) post -> exec (SIf cond bThen bElse) t m l mc post | loop: forall t m l mc cond body1 body2 mid1 mid2 post, (* This case is carefully crafted in such a way that recursive uses of exec @@ -406,10 +411,7 @@ Module exec. (forall t' m' l' mc', mid1 t' m' l' mc' -> eval_bcond l' cond = Some false -> - post t' m' l' - (addMetricLoads 1 - (addMetricInstructions 1 - (addMetricJumps 1 mc')))) -> + post t' m' l' (cost_SLoop_false cond mc')) -> (forall t' m' l' mc', mid1 t' m' l' mc' -> eval_bcond l' cond = Some true -> @@ -417,9 +419,7 @@ Module exec. (forall t'' m'' l'' mc'', mid2 t'' m'' l'' mc'' -> exec (SLoop body1 cond body2) t'' m'' l'' - (addMetricLoads 2 - (addMetricInstructions 2 - (addMetricJumps 1 mc''))) post) -> + (cost_SLoop_true cond mc'') post) -> exec (SLoop body1 cond body2) t m l mc post | seq: forall t m l mc s1 s2 mid post, exec s1 t m l mc mid -> @@ -451,12 +451,12 @@ Module exec. map.get e fname = Some (params, rets, fbody) -> map.getmany_of_list l args = Some argvs -> map.putmany_of_list_zip params argvs map.empty = Some st -> - exec fbody t m st (addMetricInstructions 100 (addMetricJumps 100 (addMetricLoads 100 (addMetricStores 100 mc)))) + exec fbody t m st mc (fun t' m' st' mc' => exists retvs l', map.getmany_of_list st' rets = Some retvs /\ map.putmany_of_list_zip binds retvs l = Some l' /\ - post t' m' l' (addMetricInstructions 100 (addMetricJumps 100 (addMetricLoads 100 (addMetricStores 100 mc'))))) -> + post t' m' l' (cost_call phase mc')) -> exec (SCall binds fname args) t m l mc post. Proof. intros. eapply call; try eassumption. @@ -466,15 +466,15 @@ Module exec. Lemma loop_cps: forall body1 cond body2 t m l mc post, exec body1 t m l mc (fun t m l mc => exists b, eval_bcond l cond = Some b /\ - (b = false -> post t m l (addMetricLoads 1 (addMetricInstructions 1 (addMetricJumps 1 mc)))) /\ + (b = false -> post t m l (cost_SLoop_false cond mc)) /\ (b = true -> exec body2 t m l mc (fun t m l mc => exec (SLoop body1 cond body2) t m l - (addMetricLoads 2 (addMetricInstructions 2 (addMetricJumps 1 mc))) post))) -> + (cost_SLoop_true cond mc) post))) -> exec (SLoop body1 cond body2) t m l mc post. Proof. intros. eapply loop. 1: eapply H. all: cbv beta; intros; simp. - congruence. - - replace b with false in * by congruence. clear b. eauto. + - replace b with false in * by congruence. clear b. eauto. - replace b with true in * by congruence. clear b. eauto. - assumption. Qed. @@ -624,14 +624,17 @@ Section FlatImp2. {env_ok: map.ok env} {ext_spec_ok: ext_spec.ok ext_spec}. + Variable (phase: compphase). + Variable (isReg: varname -> bool). + Definition SimState: Type := trace * mem * locals * MetricLog. Definition SimExec(e: env)(c: stmt varname): SimState -> (SimState -> Prop) -> Prop := fun '(t, m, l, mc) post => - exec e c t m l mc (fun t' m' l' mc' => post (t', m', l', mc')). + exec phase isReg e c t m l mc (fun t' m' l' mc' => post (t', m', l', mc')). Lemma modVarsSound: forall e s initialT (initialSt: locals) initialM (initialMc: MetricLog) post, - exec e s initialT initialM initialSt initialMc post -> - exec e s initialT initialM initialSt initialMc + exec phase isReg e s initialT initialM initialSt initialMc post -> + exec phase isReg e s initialT initialM initialSt initialMc (fun finalT finalM finalSt _ => map.only_differ initialSt (modVars s) finalSt). Proof. induction 1; @@ -651,7 +654,7 @@ Section FlatImp2. - eapply exec.stackalloc; try eassumption. intros. eapply exec.weaken. - + eapply exec.intersect. + + eapply exec.intersect; try eassumption. * eapply H0; eassumption. * eapply H1; eassumption. + simpl. intros. simp. @@ -671,7 +674,7 @@ Section FlatImp2. + intros. simp. eauto. + intros. simp. simpl. map_solver locals_ok. + intros. simp. simpl in *. - eapply exec.intersect; [eauto|]. + eapply exec.intersect; try eassumption; [eauto|]. eapply exec.weaken. * eapply H3; eassumption. * simpl. intros. map_solver locals_ok. @@ -689,3 +692,20 @@ Section FlatImp2. Qed. End FlatImp2. + +(* various helper tactics extending the ones from MetricCosts *) + +Ltac scost_unfold := + unfold exec.cost_SOp, exec.cost_SIf, exec.cost_SLoop_true, exec.cost_SLoop_false in *; cost_unfold. + +Ltac scost_destr := + repeat match goal with + | x : operand |- _ => destr x + | x : bbinop _ |- _ => destr x + | x : bcond _ |- _ => destr x + | _ => cost_destr + end. + +Ltac scost_solve := scost_unfold; scost_destr; try solve_MetricLog. +Ltac scost_solve_piecewise := scost_unfold; scost_destr; try solve_MetricLog_piecewise. +Ltac scost_hammer := try solve [eauto 3 with metric_arith | scost_solve]. diff --git a/compiler/src/compiler/FlatToRiscvCommon.v b/compiler/src/compiler/FlatToRiscvCommon.v index 1fca28830..7cdb7c49b 100644 --- a/compiler/src/compiler/FlatToRiscvCommon.v +++ b/compiler/src/compiler/FlatToRiscvCommon.v @@ -45,6 +45,7 @@ Require Import compiler.RunInstruction. Require Import compiler.DivisibleBy4. Require Import compiler.MetricsToRiscv. Require Export compiler.regs_initialized. +Require Import bedrock2.MetricCosts. Require Import coqutil.Word.Interface. Local Hint Mode Word.Interface.word - : typeclass_instances. @@ -288,7 +289,8 @@ Section WithParameters. exists pos, map.get finfo f = Some pos /\ pos mod 4 = 0. Local Notation stmt := (stmt Z). - + Local Notation exec := (exec PostSpill isRegZ). + (* note: [e_impl_reduced] and [funnames] will shrink one function at a time each time we enter a new function body, to make sure functions cannot call themselves, while [e_impl] and [e_pos] remain the same throughout because that's mandated by diff --git a/compiler/src/compiler/FlatToRiscvFunctions.v b/compiler/src/compiler/FlatToRiscvFunctions.v index 5aa6f3189..84aacc06e 100644 --- a/compiler/src/compiler/FlatToRiscvFunctions.v +++ b/compiler/src/compiler/FlatToRiscvFunctions.v @@ -32,6 +32,7 @@ Require Import coqutil.Word.DebugWordEq. Require Import compiler.MemoryLayout. Require Import coqutil.Map.MapEauto. Require Import compiler.Registers. +Require Import bedrock2.MetricCosts. Import MetricLogging. @@ -203,6 +204,13 @@ Section Proofs. * eauto. Qed. + Lemma valid_FlatImp_var_isRegZ : forall x, + valid_FlatImp_var x -> isRegZ x = true. + Proof. + unfold valid_FlatImp_var, isRegZ; blia. + Qed. + Hint Resolve valid_FlatImp_var_isRegZ. + Ltac run1done := apply runsToDone; simpl_MetricRiscvMachine_get_set; @@ -212,7 +220,12 @@ Section Proofs. end; ssplit; simpl_word_exprs word_ok; match goal with | |- _ => solve_word_eq word_ok - | |- (_ <= _)%metricsL => MetricsToRiscv.solve_MetricLog + | |- (_ <= _)%metricsL => + scost_unfold; + repeat match goal with + | H : valid_FlatImp_var _ |- _ => apply valid_FlatImp_var_isRegZ in H; rewrite H in * + end; + MetricsToRiscv.solve_MetricLog | |- iff1 ?x ?x => reflexivity (* `exists stack_trash frame_trash, ...` from goodMachine *) | |- exists _ _, _ = _ /\ _ = _ /\ (_ * _)%sep _ => @@ -417,6 +430,13 @@ Section Proofs. Qed. + Local Notation exec := (exec PostSpill isRegZ). + + Definition cost_compile_spec mc := + Platform.MetricLogging.addMetricInstructions 95 + (Platform.MetricLogging.addMetricJumps 95 + (Platform.MetricLogging.addMetricLoads 95 + (Platform.MetricLogging.addMetricStores 95 mc))). Lemma compile_function_body_correct: forall (e_impl_full : env) m l mc (argvs : list word) (st0 : locals) (post outcome : Semantics.trace -> mem -> locals -> MetricLog -> Prop) @@ -481,10 +501,7 @@ Section Proofs. (of_list (list_union Z.eqb (List.firstn binds_count (reg_class.all reg_class.arg)) [])) (singleton_set RegisterNames.ra)) (getRegs finalL) /\ - (getMetrics finalL - Platform.MetricLogging.addMetricInstructions 100 - (Platform.MetricLogging.addMetricJumps 100 - (Platform.MetricLogging.addMetricLoads 100 - (Platform.MetricLogging.addMetricStores 100 (getMetrics mach)))) <= + (getMetrics finalL - cost_compile_spec (getMetrics mach) <= lowerMetrics (finalMetricsH - mc))%metricsL /\ goodMachine finalTrace finalMH finalRegsH g finalL). Proof. @@ -1095,6 +1112,7 @@ Section Proofs. end end. cbn in H2p6. + (* cost_compile_spec constraint: cost_compile_spec >= (...93...) i think? *) blia. + rename l into lH, finalRegsH into lFH', finalRegsH' into lH', st0 into lFH, @@ -1300,8 +1318,7 @@ Section Proofs. eapply List.Forall_filter. intros *. intro E. destr (reg_class.get a); try discriminate E. unfold reg_class.get in E0. fwd. - unfold FlatToRiscvDef.valid_FlatImp_var. - destruct_one_match_hyp. + unfold FlatToRiscvDef.valid_FlatImp_var. destruct_one_match_hyp. -- fwd. blia. -- destruct_one_match_hyp. 1: discriminate. destruct_one_match_hyp; discriminate. @@ -1333,6 +1350,15 @@ Section Proofs. Qed. + Ltac finishcost := + scost_unfold; + repeat match goal with + | H : ForallVars_bcond _ ?cond |- _ => destruct cond eqn:?; unfold ForallVars_bcond in H + | H : valid_FlatImp_var _ |- _ => apply valid_FlatImp_var_isRegZ in H; rewrite H in * + | H : _ /\ _ |- _ => destruct H + end; + MetricsToRiscv.solve_MetricLog. + Lemma compile_stmt_correct: (forall resvars extcall argvars, compiles_FlatToRiscv_correctly compile_ext_call @@ -1597,18 +1623,17 @@ Section Proofs. replace mid_log with t in *. forget (Datatypes.length binds) as binds_count. subst binds. - eapply runsTo_weaken. { - match goal with - | H: (binds_count <= 8)%nat |- _ => rename H into BC - end. - move BC after OC. - repeat match goal with - | x := _ |- _ => clearbody x - end. - clear - word_ok RVM PRParams PR ext_spec word_riscv_ok locals_ok mem_ok pos_map_ok env_ok - IHexec OC BC OL Exb GetMany Ext GE FS C V Mo Mo' Gra RaM GPC A GM. - revert IHexec OC BC OL Exb GetMany Ext GE FS C V Mo Mo' Gra RaM GPC A GM. - eapply compile_function_body_correct. + eapply runsTo_weaken. + 1:{ + match goal with + | H: (binds_count <= 8)%nat |- _ => rename H into BC + end. + move BC after OC. + repeat match goal with + | x := _ |- _ => clearbody x + end. + revert IHexec OC BC OL Exb GetMany Ext GE FS C V Mo Mo' Gra RaM GPC A GM. + eapply compile_function_body_correct. } subst mach. simpl_MetricRiscvMachine_get_set. intros. fwd. eexists. eexists. eexists. eexists. @@ -1616,6 +1641,8 @@ Section Proofs. split; eauto 8 with map_hints. split; eauto 8 with map_hints. split; eauto 8 with map_hints. + (* cost_compile_spec constraint: cost_compile_spec + (1,1,1,0) <= cost_call *) + unfold cost_compile_spec, cost_call in *. MetricsToRiscv.solve_MetricLog. - idtac "Case compile_stmt_correct/SLoad". @@ -1628,6 +1655,7 @@ Section Proofs. inline_iff1. run1det. clear H0. (* <-- TODO this should not be needed *) run1done. + - idtac "Case compile_stmt_correct/SStore". inline_iff1. simpl_MetricRiscvMachine_get_set. @@ -1838,6 +1866,10 @@ Section Proofs. } run1done. cbn. + cost_unfold. + repeat match goal with + | H : valid_FlatImp_var _ |- _ => apply valid_FlatImp_var_isRegZ in H; try rewrite H in * + end. remember (updateMetricsForLiteral v initialL_metrics) as finalMetrics; symmetry in HeqfinalMetrics; pose proof update_metrics_for_literal_bounded (width := width) as Hlit; @@ -1853,7 +1885,7 @@ Section Proofs. match goal with | op: Syntax.bopname.bopname |- _ => destr op end. - all: match goal with + all: scost_unfold; match goal with | y: operand, H: context[Syntax.bopname.eq] |- _ => destr y; simpl in *; [ run1det; run1det; run1done; @@ -1924,7 +1956,7 @@ Section Proofs. { eapply run_Jal0; try safe_sidecond. solve_divisibleBy4. } simpl_MetricRiscvMachine_get_set. - intros. destruct_RiscvMachine mid. fwd. run1done. + intros. destruct_RiscvMachine mid. fwd. run1done. finishcost. - idtac "Case compile_stmt_correct/SIf/Else". (* execute branch instruction, which will jump over then-branch *) @@ -1949,7 +1981,7 @@ Section Proofs. all: try safe_sidecond. * (* at end of else-branch, i.e. also at end of if-then-else, just prove that computed post satisfies required post *) - simpl. intros. destruct_RiscvMachine middle. fwd. subst. run1done. + simpl. intros. destruct_RiscvMachine middle. fwd. subst. run1done. finishcost. - idtac "Case compile_stmt_correct/SLoop". match goal with @@ -2021,7 +2053,8 @@ Section Proofs. all: try safe_sidecond. } (* at end of loop, just prove that computed post satisfies required post *) - simpl. intros. destruct_RiscvMachine middle. fwd. run1done. + simpl. intros. destruct_RiscvMachine middle. fwd. run1done. finishcost. + * (* false: done, jump over body2 *) eapply runsToStep. { eapply compile_bcond_by_inverting_correct with (l := lH') (b := false); @@ -2029,7 +2062,7 @@ Section Proofs. try safe_sidecond. } simpl_MetricRiscvMachine_get_set. - intros. destruct_RiscvMachine mid. fwd. run1done. + intros. destruct_RiscvMachine mid. fwd. run1done. finishcost. - idtac "Case compile_stmt_correct/SSeq". on hyp[(FlatImpConstraints.uses_standard_arg_regs s1); runsTo] @@ -2062,3 +2095,4 @@ Section Proofs. End Proofs. + diff --git a/compiler/src/compiler/FlatToRiscvMetric.v b/compiler/src/compiler/FlatToRiscvMetric.v deleted file mode 100644 index 8703d7c1d..000000000 --- a/compiler/src/compiler/FlatToRiscvMetric.v +++ /dev/null @@ -1,378 +0,0 @@ -Require Import riscv.Spec.Primitives. -Require Import riscv.Platform.RiscvMachine. -Require Import riscv.Platform.MetricRiscvMachine. -Require Import riscv.Platform.MetricLogging. -Require Import riscv.Utility.Utility. -Require Import riscv.Utility.runsToNonDet. -Require Import riscv.Utility.InstructionCoercions. -Require Import compiler.util.Common. -Require Import compiler.eqexact. -Require Import coqutil.Tactics.Simp. -Require Import compiler.on_hyp_containing. -Require Import compiler.SeparationLogic. -Require Export coqutil.Word.SimplWordExpr. -Require Import compiler.GoFlatToRiscv. -Require Import compiler.DivisibleBy4. -Require Import compiler.MetricsToRiscv. -Require Import compiler.FlatImp. -Require Import compiler.RiscvWordProperties. -Require Import compiler.FlatToRiscvDef. -Require Import compiler.FlatToRiscvCommon. -Require Import compiler.FlatToRiscvLiterals. -Require Import coqutil.Tactics.fwd. - -Open Scope ilist_scope. - -Local Arguments Z.mul: simpl never. -Local Arguments Z.add: simpl never. -Local Arguments Z.of_nat: simpl never. -Local Arguments Z.modulo : simpl never. -Local Arguments Z.pow: simpl never. -Local Arguments Z.sub: simpl never. - -Section Proofs. - Context {iset: Decode.InstructionSet}. - Context {pos_map: map.map String.string Z}. - Context (compile_ext_call: pos_map -> Z -> Z -> stmt Z -> list Decode.Instruction). - Context {width: Z} {BW: Bitwidth width} {word: word.word width} {word_ok: word.ok word}. - Context {word_riscv_ok: RiscvWordProperties.word.riscv_ok word}. - Context {locals: map.map Z word} {locals_ok: map.ok locals}. - Context {mem: map.map word byte} {mem_ok: map.ok mem}. - Context {env: map.map String.string (list Z * list Z * stmt Z)} {env_ok: map.ok env}. - Context {M: Type -> Type}. - Context {MM: Monads.Monad M}. - Context {RVM: Machine.RiscvProgram M word}. - Context {PRParams: PrimitivesParams M MetricRiscvMachine}. - Context {ext_spec: Semantics.ExtSpec} {ext_spec_ok: Semantics.ext_spec.ok ext_spec}. - Context {PR: MetricPrimitives.MetricPrimitives PRParams}. - Context {BWM: bitwidth_iset width iset}. - - Add Ring wring : (word.ring_theory (word := word)) - (preprocess [autorewrite with rew_word_morphism], - morphism (word.ring_morph (word := word)), - constants [word_cst]). - - Local Notation RiscvMachineL := (MetricRiscvMachine (width := width)). - - Ltac run1done := - apply runsToDone; - simpl_MetricRiscvMachine_get_set; - rewrite ?word.sru_ignores_hibits, - ?word.slu_ignores_hibits, - ?word.srs_ignores_hibits, - ?word.mulhuu_simpl, - ?word.divu0_simpl, - ?word.modu0_simpl; - simpl in *; - eexists; (* finalMH *) - eexists; (* finalMetricsH *) - repeat split; - simpl_word_exprs word_ok; - first - [ solve [eauto] - | solve_MetricLog - | solve_word_eq word_ok - | solve [wcancel_assumption] - | eapply rearrange_footpr_subset; [ eassumption | wwcancel ] - | solve [solve_valid_machine word_ok] - | idtac ]. - - Ltac IH_sidecondition := - simpl_word_exprs word_ok; - try solve - [ reflexivity - | auto - | solve_word_eq word_ok - | simpl; solve_divisibleBy4 - | solve_valid_machine word_ok - | eapply rearrange_footpr_subset; [ eassumption | wwcancel ] - | wcancel_assumption ]. - - Hypothesis no_ext_calls: forall t mGive action argvals outcome, - ext_spec t mGive action argvals outcome -> False. - - Hypothesis stackalloc_always_0: forall x n body t m (l: locals) mc post, - FlatImp.exec map.empty (SStackalloc x n body) t m l mc post -> n = 0. - - Hypothesis sp_always_set: forall l: locals, - map.get l RegisterNames.sp = Some (word.of_Z 42). - - (* not needed any more *) - Definition stmt_not_too_big(s: stmt Z): Prop := True. - - Lemma compile_stmt_correct: - forall (s: stmt Z) t initialMH initialRegsH postH initialMetricsH, - FlatImp.exec map.empty s t initialMH initialRegsH initialMetricsH postH -> - forall R Rexec (initialL: RiscvMachineL) insts pos, - compile_stmt iset compile_ext_call map.empty pos 12345678 s = insts -> - stmt_not_too_big s -> - valid_FlatImp_vars s -> - divisibleBy4 initialL.(getPc) -> - initialL.(getRegs) = initialRegsH -> - subset (footpr (program iset initialL.(getPc) insts * Rexec)%sep) (of_list initialL.(getXAddrs)) -> - (program iset initialL.(getPc) insts * Rexec * eq initialMH * R)%sep initialL.(getMem) -> - initialL.(getLog) = t -> - initialL.(getNextPc) = add initialL.(getPc) (word.of_Z 4) -> - valid_machine initialL -> - runsTo initialL (fun finalL => exists finalMH finalMetricsH, - postH finalL.(getLog) finalMH finalL.(getRegs) finalMetricsH /\ - subset (footpr (program iset initialL.(getPc) insts * Rexec)%sep) - (of_list finalL.(getXAddrs)) /\ - (program iset initialL.(getPc) insts * Rexec * eq finalMH * R)%sep finalL.(getMem) /\ - finalL.(getPc) = add initialL.(getPc) - (word.mul (word.of_Z 4) (word.of_Z (Z.of_nat (length insts)))) /\ - finalL.(getNextPc) = word.add finalL.(getPc) (word.of_Z 4) /\ - (finalL.(getMetrics) - initialL.(getMetrics) <= - lowerMetrics (finalMetricsH - initialMetricsH))%metricsL /\ - valid_machine finalL). - Proof. - induction 1; intros; - repeat match goal with - | m: _ |- _ => destruct_RiscvMachine m; simpl_MetricRiscvMachine_get_set - end; - simpl in *; - subst; - simp. - - - (* SInteract *) - exfalso. eapply no_ext_calls. eassumption. - - - (* SCall *) - lazymatch goal with - | A: map.get map.empty _ = Some _ |- _ => - exfalso; simpl in *; - rewrite map.get_empty in A - end. - discriminate. - - - (* SLoad *) - unfold Memory.load, Memory.load_Z in *. simp. subst_load_bytes_for_eq. - run1det. run1done. - - - (* SStore *) - simpl_MetricRiscvMachine_get_set. - assert ((eq m * (program iset initialL_pc [[compile_store iset sz a v o]] * Rexec * R))%sep - initialL_mem) as A by ecancel_assumption. - match goal with - | H: _ |- _ => pose proof (store_bytes_frame H A) as P; move H at bottom; - unfold Memory.store, Memory.store_Z, Memory.store_bytes in H - end. - destruct P as (finalML & P1 & P2). - simp. - destruct (eq_sym (LittleEndianList.length_le_split (Memory.bytes_per(width:=width) sz) (word.unsigned val))) in t0, E. - subst_load_bytes_for_eq. - run1det. run1done. - eapply preserve_subset_of_xAddrs. 1: assumption. - ecancel_assumption. - - - (* SInlinetable *) - run1det. - assert (map.get (map.put l x (word.add initialL_pc (word.of_Z 4))) i = Some index). { - rewrite map.get_put_diff by congruence. assumption. - } - run1det. - assert (Memory.load sz initialL_mem - (word.add (word.add (word.add initialL_pc (word.of_Z 4)) index) (word.of_Z 0)) - = Some v). { - rewrite word.add_0_r. - eapply load_from_compile_byte_list. 1: eassumption. - wcancel_assumption. - } - run1det. - rewrite !map.put_put_same in *. - run1done. - - - (* SStackalloc *) - assert (valid_register RegisterNames.sp) by (cbv; auto). - specialize (stackalloc_always_0 x n body t mSmall l mc post). move stackalloc_always_0 at bottom. - assert (n = 0). { - eapply stackalloc_always_0. econstructor; eauto. - } - subst n. - run1det. - eapply runsTo_weaken. { - eapply H1 with (mStack := map.empty) (mCombined := mSmall). - { unfold Memory.anybytes. exists nil. reflexivity. } - { rewrite map.split_empty_r. reflexivity. } - all: IH_sidecondition. - } - simpl. - intros. - unfold Memory.anybytes, Memory.ftprint, map.of_disjoint_list_zip in *. simpl in *. - simp. - rewrite map.split_empty_r in H6p0p1. subst mSmall'. - repeat match goal with - | m: _ |- _ => destruct_RiscvMachine m; simpl_MetricRiscvMachine_get_set - end. - eexists. eexists. - split; [eassumption|]. - split; [solve [sidecondition]|]. - split; [solve [sidecondition]|]. - split. { - subst. solve_word_eq word_ok. - } - split; [solve [sidecondition]|]. - split. { - solve_MetricLog. - } - assumption. - - - (* SLit *) - RunInstruction.get_runsTo_valid_for_free. - eapply compile_lit_correct_full. - + sidecondition. - + use_sep_assumption. cbn. - (* ecancel. (* The term "Tree.Leaf (subset (footpr (program iset initialL_pc (compile_lit x v) * Rexec)%sep))" has type "Tree.Tree (set word -> Prop)" while it is expected to have type "Tree.Tree (?map -> Prop)". *) *) - eapply RelationClasses.reflexivity. - + unfold compile_stmt. simpl. ecancel_assumption. - + sidecondition. - + assumption. - + simpl. run1done; - remember (updateMetricsForLiteral v initialL_metrics) as finalMetrics; - symmetry in HeqfinalMetrics; - pose proof update_metrics_for_literal_bounded (width := width) as Hlit; - specialize Hlit with (1 := HeqfinalMetrics); - solve_MetricLog. - - - (* SOp *) - match goal with - | o: Syntax.bopname.bopname |- _ => destruct o - end; - simpl in *. - all: match goal with - | y: operand, H: context[Syntax.bopname.eq] |- _ => - destr y; simpl in *; - [ run1det; simpl_MetricRiscvMachine_get_set; run1det; run1done - | ]; try fwd - | y: operand |- _ => - destr y; simpl in *; - [ run1det; run1done - | ]; try fwd - end; simpl in *; fwd. - - all: try match goal with - | H: context[Decode.InvalidInstruction] |- _ => - assert (Encode.verify (Decode.InvalidInstruction (-1)) iset \/ - valid_InvalidInstruction (Decode.InvalidInstruction (-1))) by - ( eapply invert_ptsto_instr; ecancel_assumption) - | H: _ |- _ => run1det; run1done - end. - - all: - try match goal with - | H: Encode.verify (Decode.InvalidInstruction (-1)) iset \/ - valid_InvalidInstruction (Decode.InvalidInstruction (-1)) |- _ => - exfalso; destruct H; - [ unfold Encode.verify in H; simpl in H; - destruct H; assumption | unfold valid_InvalidInstruction in H]; fwd - end. - all: - try match goal with - | H: 0 <= -1 < 2^32 |- False - => destruct H; - match goal with - | H: 0 <= -1 |- False => destruct H; simpl; reflexivity - end - end. - all: simpl in *; fwd. - all: try match goal with - | H: ?post _ _ _ |- ?post _ _ _ => eqexact H - end. - all : rewrite ?word.srs_ignores_hibits, - ?word.sru_ignores_hibits, - ?word.slu_ignores_hibits, - ?word.mulhuu_simpl, - ?word.divu0_simpl, - ?word.modu0_simpl; trivial. - all: try solve_MetricLog. - simpl. rewrite reduce_eq_to_sub_and_lt. symmetry. apply map.put_put_same. - - - (* SSet *) - run1det. run1done. - - - (* SIf/Then *) - (* execute branch instruction, which will not jump *) - eapply runsTo_det_step_with_valid_machine; simpl in *; subst. - + assumption. - + simulate'. simpl_MetricRiscvMachine_get_set. - destruct cond; [destruct op | ]; - simpl in *; simp; repeat (simulate'; simpl_bools; simpl); try reflexivity. - + intro V. eapply runsTo_trans; simpl_MetricRiscvMachine_get_set. - * (* use IH for then-branch *) - eapply IHexec; IH_sidecondition. - * (* jump over else-branch *) - simpl. intros. destruct_RiscvMachine middle. simp. subst. - run1det. run1done. - - - (* SIf/Else *) - (* execute branch instruction, which will jump over then-branch *) - eapply runsTo_det_step_with_valid_machine; simpl in *; subst. - + assumption. - + simulate'. - destruct cond; [destruct op | ]; - simpl in *; simp; repeat (simulate'; simpl_bools; simpl); try reflexivity. - + intro V. eapply runsTo_trans; simpl_MetricRiscvMachine_get_set. - * (* use IH for else-branch *) - eapply IHexec; IH_sidecondition. - * (* at end of else-branch, i.e. also at end of if-then-else, just prove that - computed post satisfies required post *) - simpl. intros. destruct_RiscvMachine middle. simp. subst. run1done. - - - (* SLoop/again *) - on hyp[(stmt_not_too_big body1); runsTo] do (fun H => rename H into IH1). - on hyp[(stmt_not_too_big body2); runsTo] do (fun H => rename H into IH2). - on hyp[(stmt_not_too_big (SLoop body1 cond body2)); runsTo] do (fun H => rename H into IH12). - eapply runsTo_trans; simpl_MetricRiscvMachine_get_set. - + (* 1st application of IH: part 1 of loop body *) - eapply IH1; IH_sidecondition. - + simpl in *. simpl. intros. destruct_RiscvMachine middle. simp. subst. - destruct (eval_bcond middle_regs cond) as [condB|] eqn: E. - 2: exfalso; - match goal with - | H: context [_ <> None] |- _ => solve [eapply H; eauto] - end. - destruct condB. - * (* true: iterate again *) - eapply runsTo_det_step_with_valid_machine; simpl in *; subst. - { assumption. } - { simulate'. - destruct cond; [destruct op | ]; - simpl in *; simp; repeat (simulate'; simpl_bools; simpl); try reflexivity. } - { intro V. eapply runsTo_trans; simpl_MetricRiscvMachine_get_set. - - (* 2nd application of IH: part 2 of loop body *) - eapply IH2; IH_sidecondition; simpl_MetricRiscvMachine_get_set; - try eassumption; IH_sidecondition. - - simpl in *. simpl. intros. destruct_RiscvMachine middle. simp. subst. - (* jump back to beginning of loop: *) - run1det. - eapply runsTo_trans; simpl_MetricRiscvMachine_get_set. - + (* 3rd application of IH: run the whole loop again *) - eapply IH12 with (pos := pos); IH_sidecondition; simpl_MetricRiscvMachine_get_set; - try eassumption; IH_sidecondition. - + (* at end of loop, just prove that computed post satisfies required post *) - simpl. intros. destruct_RiscvMachine middle. simp. subst. - run1done. } - * (* false: done, jump over body2 *) - eapply runsTo_det_step_with_valid_machine; simpl in *; subst. - { assumption. } - { simulate'. - destruct cond; [destruct op | ]; - simpl in *; simp; repeat (simulate'; simpl_bools; simpl); try reflexivity. } - { intro V. simpl in *. run1done. } - - - (* SSeq *) - on hyp[(stmt_not_too_big s1); runsTo] do (fun H => rename H into IH1). - on hyp[(stmt_not_too_big s2); runsTo] do (fun H => rename H into IH2). - eapply runsTo_trans. - + eapply IH1; IH_sidecondition. - + simpl. intros. destruct_RiscvMachine middle. simp. subst. - eapply runsTo_trans. - * eapply IH2; IH_sidecondition; simpl_MetricRiscvMachine_get_set; - try eassumption; IH_sidecondition. - * simpl. intros. destruct_RiscvMachine middle. simp. subst. run1done. - - - (* SSkip *) - run1done. - Qed. - -End Proofs. diff --git a/compiler/src/compiler/FlattenExpr.v b/compiler/src/compiler/FlattenExpr.v index f30cf1465..516c483e3 100644 --- a/compiler/src/compiler/FlattenExpr.v +++ b/compiler/src/compiler/FlattenExpr.v @@ -7,6 +7,7 @@ Require Import coqutil.Decidable. Require Import coqutil.Word.Bitwidth. Require Import bedrock2.Syntax. Require Import bedrock2.MetricLogging. +Require Import bedrock2.MetricCosts. Require Import bedrock2.Semantics bedrock2.MetricSemantics. Require Import coqutil.Macros.unique. Require Import Coq.Bool.Bool. @@ -343,13 +344,15 @@ Section FlattenExpr1. simpl (disjoint _ _) in *; map_solver locals_ok. + Local Notation exec := (FlatImp.exec PreSpill isRegStr). + Lemma seq_with_modVars: forall env t m (l: locals) mc s1 s2 mid post, - FlatImp.exec env s1 t m l mc mid -> + exec env s1 t m l mc mid -> (forall t' m' l' mc', mid t' m' l' mc' -> map.only_differ l (FlatImp.modVars s1) l' -> - FlatImp.exec env s2 t' m' l' mc' post) -> - FlatImp.exec env (FlatImp.SSeq s1 s2) t m l mc post. + exec env s2 t' m' l' mc' post) -> + exec env (FlatImp.SSeq s1 s2) t m l mc post. Proof. intros *. intros E1 E2. eapply @FlatImp.exec.seq. - eapply FlatImp.exec.intersect. @@ -371,25 +374,25 @@ Section FlattenExpr1. map.undef_on initialH (allFreshVars ngs1) -> disjoint (union (ExprImp.allVars_expr e) (of_option oResVar)) (allFreshVars ngs1) -> eval_expr initialM initialH e initialMcH = Some (res, finalMcH) -> - FlatImp.exec fenv s t initialM initialL initialMcL (fun t' finalM finalL finalMcL => + exec fenv s t initialM initialL initialMcL (fun t' finalM finalL finalMcL => t' = t /\ finalM = initialM /\ map.get finalL resVar = Some res /\ (finalMcL - initialMcL <= finalMcH - initialMcH)%metricsH). Proof. induction e; intros *; intros F Ex U D Ev; simpl in *; simp. - (* expr.literal *) - eapply @FlatImp.exec.lit; t_safe; solve_MetricLog. + eapply @FlatImp.exec.lit; t_safe. cost_hammer. - (* expr.var *) destruct oResVar; simp. - + eapply @FlatImp.exec.set; t_safe; [maps | solve_MetricLog]. - + eapply @FlatImp.exec.skip; t_safe. solve_MetricLog. + + eapply @FlatImp.exec.set; t_safe; [maps | cost_hammer]. + + eapply @FlatImp.exec.skip; t_safe. cost_hammer. - (* expr.load *) eapply @FlatImp.exec.seq. + eapply IHe; try eassumption. maps. + intros. simpl in *. simp. - eapply @FlatImp.exec.load; t_safe; rewrite ?word.add_0_r; try eassumption; solve_MetricLog. + eapply @FlatImp.exec.load; t_safe; rewrite ?word.add_0_r; try eassumption. cost_hammer. - (* expr.inlinetable *) repeat match goal with @@ -400,7 +403,7 @@ Section FlattenExpr1. + eapply IHe; try eassumption. 1: maps. set_solver; destr (String.eqb s0 x); subst; tauto. (* TODO improve set_solver? *) + intros. simpl in *. simp. - eapply @FlatImp.exec.inlinetable; t_safe; try eassumption. 2: solve_MetricLog. + eapply @FlatImp.exec.inlinetable; t_safe; try eassumption. 2: cost_hammer. apply_in_hyps flattenExpr_uses_Some_resVar. subst s0. intro C. subst s2. destruct oResVar. @@ -416,7 +419,8 @@ Section FlattenExpr1. * eapply IHe2. 1: eassumption. 4: eassumption. 1,2: solve [maps]. clear IHe1 IHe2. pose_flatten_var_ineqs. set_solver. * intros. simpl in *. simp. clear IHe1 IHe2. - eapply @FlatImp.exec.op; t_safe; t_safe. 2 : solve_MetricLog. + + eapply @FlatImp.exec.op; t_safe; t_safe. 2: FlatImp.scost_hammer. eapply flattenExpr_valid_resVar in E1; simpl; maps. - (* expr.ite *) @@ -458,7 +462,7 @@ Section FlattenExpr1. eapply subset_union_rr. eapply subset_refl. } - ++ cbv beta. intros. simp. t_safe. 2: solve_MetricLog. + ++ cbv beta. intros. simp. t_safe. 2: FlatImp.scost_hammer. clear IHe1 IHe2 IHe3. apply_in_hyps flattenExpr_uses_Some_resVar. subst. assumption. * eapply FlatImp.exec.if_true. @@ -492,7 +496,7 @@ Section FlattenExpr1. eapply subset_union_rl. eapply subset_refl. } - ++ cbv beta. intros. simp. t_safe. 2: solve_MetricLog. + ++ cbv beta. intros. simp. t_safe. 2: FlatImp.scost_hammer. clear IHe1 IHe2 IHe3. apply_in_hyps flattenExpr_uses_Some_resVar. subst. assumption. Qed. @@ -504,7 +508,7 @@ Section FlattenExpr1. map.undef_on lH (allFreshVars ngs1) -> disjoint (union (ExprImp.allVars_expr e) (of_option oResVar)) (allFreshVars ngs1) -> eval_expr m lH e initialMcH = Some (res, finalMcH) -> - FlatImp.exec fenv s t m lL initialMcL (fun t' m' lL' finalMcL => + exec fenv s t m lL initialMcL (fun t' m' lL' finalMcL => map.only_differ lL (FlatImp.modVars s) lL' /\ t' = t /\ m' = m /\ map.get lL' resVar = Some res /\ (finalMcL - initialMcL <= finalMcH - initialMcH)%metricsH). @@ -523,7 +527,7 @@ Section FlattenExpr1. disjoint (ExprImp.allVars_exprs es) (allFreshVars ngs1) -> eval_call_args m lH es initialMcH = Some (resVals, finalMcH) -> (* List.option_all (List.map (eval_expr m lH) es) = Some resVals -> *) - FlatImp.exec fenv s t m lL initialMcL (fun t' m' lL' finalMcL => + exec fenv s t m lL initialMcL (fun t' m' lL' finalMcL => t' = t /\ m' = m /\ map.getmany_of_list lL' resVars = Some resVals /\ map.only_differ lL (FlatImp.modVars s) lL' /\ @@ -596,7 +600,7 @@ Section FlattenExpr1. map.undef_on initialH (allFreshVars ngs1) -> disjoint (ExprImp.allVars_expr e) (allFreshVars ngs1) -> eval_expr initialM initialH e initialMcH = Some (res, finalMcH) -> - FlatImp.exec fenv s t initialM initialL initialMcL (fun t' finalM finalL finalMcL => + exec fenv s t initialM initialL initialMcL (fun t' finalM finalL finalMcL => t' = t /\ finalM = initialM /\ FlatImp.eval_bcond finalL resCond = Some (negb (word.eqb res (word.of_Z 0))) /\ (finalMcL - initialMcL <= finalMcH - initialMcH)%metricsH). @@ -617,7 +621,7 @@ Section FlattenExpr1. | intros; simpl in *; simp; default_flattenBooleanExpr ]. - 2, 4, 6 : solve_MetricLog. + 2, 4, 6 : cost_hammer. all: rewrite bool_to_word_to_bool_id; destruct_one_match; @@ -633,7 +637,7 @@ Section FlattenExpr1. map.undef_on initialH (allFreshVars ngs1) -> disjoint (ExprImp.allVars_expr e) (allFreshVars ngs1) -> eval_expr initialM initialH e initialMcH = Some (res, finalMcH) -> - FlatImp.exec fenv s t initialM initialL initialMcL (fun t' finalM finalL finalMcL => + exec fenv s t initialM initialL initialMcL (fun t' finalM finalL finalMcL => (t' = t /\ finalM = initialM /\ FlatImp.eval_bcond finalL resCond = Some (negb (word.eqb res (word.of_Z 0))) /\ (finalMcL - initialMcL <= finalMcH - initialMcH)%metricsH) /\ @@ -690,7 +694,7 @@ Section FlattenExpr1. map.extends lL lH -> map.undef_on lH (allFreshVars ngs) -> disjoint (ExprImp.allVars_cmd sH) (allFreshVars ngs) -> - FlatImp.exec eL sL t m lL mcL (fun t' m' lL' mcL' => exists lH' mcH', + exec eL sL t m lL mcL (fun t' m' lL' mcL' => exists lH' mcH', post t' m' lH' mcH' /\ (* <-- put first so that eassumption will instantiate lH' correctly *) map.extends lL' lH' /\ (* this one is a property purely about ExprImp (it's the conclusion of @@ -720,10 +724,10 @@ Section FlattenExpr1. specialize P with (1 := E). subst. maps. * maps. - * solve_MetricLog. - * solve_MetricLog. - * solve_MetricLog. - * solve_MetricLog. + * cost_hammer. + * cost_hammer. + * cost_hammer. + * cost_hammer. - (* exec.unset *) eapply @FlatImp.exec.skip. @@ -738,7 +742,7 @@ Section FlattenExpr1. * intros. simpl in *. simp. eapply @FlatImp.exec.store; rewrite ?word.add_0_r; try eassumption. { eapply flattenExpr_valid_resVar in E; maps. } - { repeat eexists; repeat (split || eassumption || solve_MetricLog); maps. } + { repeat eexists; repeat (split || eassumption || solve_MetricLog); try maps. all: cost_hammer. } - (* exec.stackalloc *) eapply @FlatImp.exec.stackalloc. 1: eassumption. @@ -746,7 +750,7 @@ Section FlattenExpr1. eapply @FlatImp.exec.weaken. { eapply IHexec; try reflexivity; try eassumption; maps. } { intros. simpl in *. simp. do 2 eexists. ssplit; try eassumption. - do 2 eexists. ssplit; try eassumption; try solve_MetricLog. maps. } + do 2 eexists. ssplit; try eassumption; try solve_MetricLog; try maps. cost_hammer. } - (* if_true *) eapply @FlatImp.exec.seq. @@ -763,7 +767,7 @@ Section FlattenExpr1. eapply @FlatImp.exec.weaken. { eapply IHexec; try reflexivity; try eassumption; maps. } { intros. simpl in *. simp. - repeat eexists; repeat (split || eassumption || solve_MetricLog). maps. } + repeat eexists; repeat (split || eassumption || solve_MetricLog); try maps. all: FlatImp.scost_hammer. } - (* if_false *) eapply @FlatImp.exec.seq. @@ -780,7 +784,7 @@ Section FlattenExpr1. eapply @FlatImp.exec.weaken. { eapply IHexec; try reflexivity; try eassumption; maps. } { intros. simpl in *. simp. - repeat eexists; repeat (split || eassumption || solve_MetricLog). maps. } + repeat eexists; repeat (split || eassumption || solve_MetricLog); try maps. all: FlatImp.scost_hammer. } - (* seq *) eapply seq_with_modVars. @@ -812,7 +816,7 @@ Section FlattenExpr1. | intros; simpl in *; simp .. ]. + maps. + congruence. - + repeat eexists; repeat (split || eassumption || solve_MetricLog); maps. + + repeat eexists; repeat (split || eassumption || solve_MetricLog); try maps. all: FlatImp.scost_hammer. + exfalso. match goal with | H: context [word.eqb _ _] |- _ => rewrite word.eqb_eq in H @@ -857,7 +861,7 @@ Section FlattenExpr1. 1,3: solve [maps]. pose proof (ExprImp.modVars_subset_allVars c). maps. * simpl. intros. simp. - repeat eexists; repeat (split || eassumption || solve_MetricLog). maps. + repeat eexists; repeat (split || eassumption || solve_MetricLog); try maps. all: FlatImp.scost_hammer. - (* call *) unfold flattenCall in *. simp. @@ -901,7 +905,7 @@ Section FlattenExpr1. -- eassumption. -- do 2 eexists. ssplit; try eassumption. ++ simple eapply map.only_differ_putmany; eassumption. - ++ solve_MetricLog. + ++ FlatImp.scost_hammer. - (* interact *) unfold flattenInteract in *. simp. @@ -922,7 +926,7 @@ Section FlattenExpr1. split; eauto. simple apply conj; [eassumption|]. split; [simple eapply map.only_differ_putmany; eassumption|]. - solve_MetricLog. + cost_hammer. Qed. Goal True. idtac "FlattenExpr: flattenStmt_correct_aux done". Abort. @@ -930,7 +934,7 @@ Section FlattenExpr1. flatten_functions eH = Success eL -> ExprImp2FlatImp sH = sL -> MetricSemantics.exec eH sH t m map.empty mc post -> - FlatImp.exec eL sL t m lL mc (fun t' m' lL' mcL' => exists lH' mcH', + exec eL sL t m lL mc (fun t' m' lL' mcL' => exists lH' mcH', post t' m' lH' mcH' /\ map.extends lL' lH' /\ (mcL' - mc <= mcH' - mc)%metricsH). diff --git a/compiler/src/compiler/LowerPipeline.v b/compiler/src/compiler/LowerPipeline.v index b76e3afcf..37ed158ff 100644 --- a/compiler/src/compiler/LowerPipeline.v +++ b/compiler/src/compiler/LowerPipeline.v @@ -21,6 +21,7 @@ Require Import compiler.FlatToRiscvDef. Require Import compiler.FlatToRiscvCommon. Require Import compiler.FlatToRiscvFunctions. Require Import bedrock2.MetricLogging. +Require Import bedrock2.MetricCosts. Require Import compiler.FitsStack. Require Import compiler.Registers. Require Import riscv.Utility.InstructionCoercions. @@ -32,6 +33,17 @@ Local Arguments Z.modulo : simpl never. Local Arguments Z.pow: simpl never. Local Arguments Z.sub: simpl never. +Lemma raise_metrics_ineq : forall m1 m2, + (m1 <= m2)%metricsL -> (raiseMetrics m1 <= raiseMetrics m2)%metricsH. +Proof. + intros. + destr m1; destr m2. + destruct H as (?&?&?&?). + unfold_MetricLog. + simpl in *. + repeat split; assumption. +Qed. + Section WithWordAndMem. Context {width: Z} {word: word.word width} {mem: map.map word byte}. @@ -376,14 +388,15 @@ Section LowerPipeline. (FlatImp.SInteract resvars extcall argvars). Definition riscv_call(p: list Instruction * pos_map * Z) - (f_name: string)(t: Semantics.trace)(mH: mem)(argvals: list word) - (post: Semantics.trace -> mem -> list word -> Prop): Prop := + (f_name: string)(t: Semantics.trace)(mH: mem)(argvals: list word)(mc: MetricLog) + (post: Semantics.trace -> mem -> list word -> MetricLog -> Prop): Prop := let '(instrs, finfo, req_stack_size) := p in exists f_rel_pos, map.get finfo f_name = Some f_rel_pos /\ forall p_funcs stack_start stack_pastend ret_addr Rdata Rexec (initial: MetricRiscvMachine), map.get initial.(getRegs) RegisterNames.ra = Some ret_addr -> initial.(getLog) = t -> + raiseMetrics (cost_compile_spec initial.(getMetrics)) = mc -> word.unsigned ret_addr mod 4 = 0 -> arg_regs_contain initial.(getRegs) argvals -> req_stack_size <= word.unsigned (word.sub stack_pastend stack_start) / bytes_per_word -> @@ -392,7 +405,7 @@ Section LowerPipeline. machine_ok p_funcs stack_start stack_pastend instrs mH Rdata Rexec initial -> runsTo initial (fun final => exists mH' retvals, arg_regs_contain final.(getRegs) retvals /\ - post final.(getLog) mH' retvals /\ + post final.(getLog) mH' retvals (raiseMetrics final.(getMetrics)) /\ map.only_differ initial.(getRegs) reg_class.caller_saved final.(getRegs) /\ final.(getPc) = ret_addr /\ machine_ok p_funcs stack_start stack_pastend instrs mH' Rdata Rexec final). @@ -468,16 +481,18 @@ Section LowerPipeline. Qed. Lemma flat_to_riscv_correct: forall p1 p2, - map.forall_values FlatToRiscvDef.valid_FlatImp_fun p1 -> + map.forall_values FlatToRiscvDef.valid_FlatImp_fun p1 -> riscvPhase p1 = Success p2 -> - forall fname t m argvals post, + forall fname t m argvals mcH post, (exists argnames retnames fbody l, map.get p1 fname = Some (argnames, retnames, fbody) /\ map.of_list_zip argnames argvals = Some l /\ - forall mc, FlatImp.exec p1 fbody t m l mc (fun t' m' l' mc' => + FlatImp.exec PostSpill isRegZ p1 fbody t m l mcH (fun t' m' l' mc' => exists retvals, map.getmany_of_list l' retnames = Some retvals /\ - post t' m' retvals)) -> - riscv_call p2 fname t m argvals post. + post t' m' retvals mc')) -> + forall mcL, + riscv_call p2 fname t m argvals mcL (fun t m a mcL' => + exists mcH', metricsLeq (mcL' - mcL) (mcH' - mcH) /\ post t m a mcH'). Proof. unfold riscv_call. intros. destruct p2 as ((finstrs & finfo) & req_stack_size). @@ -517,7 +532,7 @@ Section LowerPipeline. (l := l) (post := fun t' m' l' mc' => (exists retvals, - map.getmany_of_list l' retnames = Some retvals /\ post t' m' retvals)). + map.getmany_of_list l' retnames = Some retvals /\ post t' m' retvals mc')). eapply Q with (g := {| rem_stackwords := word.unsigned (word.sub stack_pastend stack_start) / bytes_per_word; @@ -676,7 +691,13 @@ Section LowerPipeline. symmetry. eapply map.getmany_of_list_length. exact GM. - + eassumption. + + eexists. split; [|eassumption]. + subst. + apply raise_metrics_ineq in H10p3. + unfold_MetricLog. + cbn in H10p3. + cbn. + solve_MetricLog. + eapply only_differ_subset. 1: eassumption. rewrite ListSet.of_list_list_union. rewrite ?singleton_set_eq_of_list. @@ -770,8 +791,6 @@ Section LowerPipeline. + assumption. + assumption. + assumption. - Unshelve. - all: try exact EmptyMetricLog. Qed. End LowerPipeline. diff --git a/compiler/src/compiler/MMIO.v b/compiler/src/compiler/MMIO.v index 52bfedd47..df348f136 100644 --- a/compiler/src/compiler/MMIO.v +++ b/compiler/src/compiler/MMIO.v @@ -369,7 +369,7 @@ Section MMIO1. split; eauto. split; [unfold map.only_differ; eauto|]. split. { - unfold id. MetricsToRiscv.solve_MetricLog. + unfold id, MetricCosts.cost_interact. MetricsToRiscv.solve_MetricLog. } split; eauto. split; eauto. @@ -522,7 +522,7 @@ Section MMIO1. destruct_one_match; auto. } split. { - unfold id. MetricsToRiscv.solve_MetricLog. + unfold id, MetricCosts.cost_interact. MetricsToRiscv.solve_MetricLog. } split. { eapply map.put_extends. eassumption. diff --git a/compiler/src/compiler/MetricsToRiscv.v b/compiler/src/compiler/MetricsToRiscv.v index ac5423442..3f57d77ba 100644 --- a/compiler/src/compiler/MetricsToRiscv.v +++ b/compiler/src/compiler/MetricsToRiscv.v @@ -23,6 +23,13 @@ Section MetricsToRiscv. storesL := storesH mh; |}. + Definition raiseMetrics (ml: metricsL): metricsH := {| + instructionsH := instructionsL ml; + jumpsH := jumpsL ml; + loadsH := loadsL ml; + storesH := storesL ml; + |}. + End MetricsToRiscv. Ltac solve_MetricLog := diff --git a/compiler/src/compiler/Pipeline.v b/compiler/src/compiler/Pipeline.v index d57b1767c..f91f428e8 100644 --- a/compiler/src/compiler/Pipeline.v +++ b/compiler/src/compiler/Pipeline.v @@ -15,7 +15,6 @@ Require Import riscv.Utility.runsToNonDet. Require Export riscv.Platform.MetricRiscvMachine. Require Import coqutil.Z.Lia. Require Import coqutil.Tactics.fwd. -Require Import bedrock2.MetricLogging. Require Import compiler.ExprImp. Require Import compiler.FlattenExprDef. Require Import compiler.FlattenExpr. @@ -39,7 +38,9 @@ Require Import compiler.SeparationLogic. Require Import compiler.Spilling. Require Import compiler.RegAlloc. Require Import compiler.RiscvEventLoop. +Require Import compiler.MetricsToRiscv. Require Import bedrock2.MetricLogging. +Require Import bedrock2.MetricCosts. Require Import compiler.FlatToRiscvCommon. Require Import compiler.FlatToRiscvFunctions. Require Import compiler.DivisibleBy4. @@ -47,10 +48,11 @@ Require Export coqutil.Word.SimplWordExpr. Require Export compiler.Registers. Require Import compiler.ForeverSafe. Require Import FunctionalExtensionality. +Require Import PropExtensionality. Require Import coqutil.Tactics.autoforward. Require Import compiler.FitsStack. Require Import compiler.LowerPipeline. -Require Import bedrock2.WeakestPreconditionProperties. +Require Import bedrock2.MetricWeakestPreconditionProperties. Require Import compiler.UseImmediateDef. Require Import compiler.UseImmediate. Require Import compiler.DeadCodeElimDef. @@ -64,8 +66,8 @@ Section WithWordAndMem. Program: Type; Valid: Program -> Prop; Call(p: Program)(funcname: string) - (t: trace)(m: mem)(argvals: list word) - (post: trace -> mem -> list word -> Prop): Prop; + (t: trace)(m: mem)(argvals: list word)(mc: MetricLog) + (post: trace -> mem -> list word -> MetricLog -> Prop): Prop; }. Record phase_correct{L1 L2: Lang} @@ -78,9 +80,12 @@ Section WithWordAndMem. phase_preserves_post: forall p1 p2, L1.(Valid) p1 -> compile p1 = Success p2 -> - forall fname t m argvals post, - L1.(Call) p1 fname t m argvals post -> - L2.(Call) p2 fname t m argvals post; + forall fname t m argvals mcH post, + L1.(Call) p1 fname t m argvals mcH post -> + forall mcL, + L2.(Call) p2 fname t m argvals mcL (fun t m a mcL' => + exists mcH', metricsLeq (mcL' - mcL) (mcH' - mcH) /\ post t m a mcH' + ); }. Arguments phase_correct : clear implicits. @@ -92,6 +97,10 @@ Section WithWordAndMem. | Failure e => Failure e end. + Ltac post_ext := + repeat (eapply functional_extensionality; intro); + apply propositional_extensionality. + Lemma compose_phases_correct{L1 L2 L3: Lang} {compile12: L1.(Program) -> result L2.(Program)} {compile23: L2.(Program) -> result L3.(Program)}: @@ -101,7 +110,10 @@ Section WithWordAndMem. Proof. unfold compose_phases. intros [V12 C12] [V23 C23]. - split; intros; fwd; eauto. + split; intros; fwd; [eauto|]. + erewrite f_equal; [eauto|]. + post_ext. + split; [|destruct 1 as (?&?&?&?&?)]; eauto with metric_arith. Qed. Section WithMoreParams. @@ -147,14 +159,29 @@ Section WithWordAndMem. Cmd -> trace -> mem -> locals -> MetricLog -> (trace -> mem -> locals -> MetricLog -> Prop) -> Prop) (e: string_keyed_map (list Var * list Var * Cmd)%type)(f: string) - (t: trace)(m: mem)(argvals: list word) - (post: trace -> mem -> list word -> Prop): Prop := + (t: trace)(m: mem)(argvals: list word)(mc: MetricLog) + (post: trace -> mem -> list word -> MetricLog -> Prop): Prop := + exists argnames retnames fbody l, + map.get e f = Some (argnames, retnames, fbody) /\ + map.of_list_zip argnames argvals = Some l /\ + Exec e fbody t m l (cost_spill_spec mc) (fun t' m' l' mc' => + exists retvals, map.getmany_of_list l' retnames = Some retvals /\ + post t' m' retvals mc'). + + Definition locals_based_call_spec_spilled{Var Cmd: Type}{locals: map.map Var word} + {string_keyed_map: forall T: Type, map.map string T} + (Exec: string_keyed_map (list Var * list Var * Cmd)%type -> + Cmd -> trace -> mem -> locals -> MetricLog -> + (trace -> mem -> locals -> MetricLog -> Prop) -> Prop) + (e: string_keyed_map (list Var * list Var * Cmd)%type)(f: string) + (t: trace)(m: mem)(argvals: list word)(mc: MetricLog) + (post: trace -> mem -> list word -> MetricLog -> Prop): Prop := exists argnames retnames fbody l, map.get e f = Some (argnames, retnames, fbody) /\ map.of_list_zip argnames argvals = Some l /\ - forall mc, Exec e fbody t m l mc (fun t' m' l' mc' => + Exec e fbody t m l mc (fun t' m' l' mc' => exists retvals, map.getmany_of_list l' retnames = Some retvals /\ - post t' m' retvals). + post t' m' retvals mc'). Definition ParamsNoDup{Var: Type}: (list Var * list Var * FlatImp.stmt Var) -> Prop := fun '(argnames, retnames, body) => NoDup argnames /\ NoDup retnames. @@ -170,7 +197,7 @@ Section WithWordAndMem. Definition FlatWithStrVars: Lang := {| Program := string_keyed_map (list string * list string * FlatImp.stmt string); Valid := map.forall_values ParamsNoDup; - Call := locals_based_call_spec FlatImp.exec; + Call := locals_based_call_spec (FlatImp.exec PreSpill isRegStr); |}. (* | *) @@ -189,7 +216,7 @@ Section WithWordAndMem. Definition FlatWithZVars: Lang := {| Program := string_keyed_map (list Z * list Z * FlatImp.stmt Z); Valid := map.forall_values ParamsNoDup; - Call := locals_based_call_spec FlatImp.exec; + Call := locals_based_call_spec (FlatImp.exec PreSpill isRegZ); |}. (* | *) (* | Spilling *) @@ -197,7 +224,7 @@ Section WithWordAndMem. Definition FlatWithRegs: Lang := {| Program := string_keyed_map (list Z * list Z * FlatImp.stmt Z); Valid := map.forall_values FlatToRiscvDef.valid_FlatImp_fun; - Call := locals_based_call_spec FlatImp.exec; + Call := locals_based_call_spec_spilled (FlatImp.exec PostSpill isRegZ); |}. (* | *) (* | FlatToRiscv *) @@ -243,7 +270,7 @@ Section WithWordAndMem. eexists _, _, _, _. split. 1: eassumption. split. 1: eassumption. intros. eapply FlatImp.exec.weaken. - - eapply flattenStmt_correct_aux with (mcH := mc). + - eapply flattenStmt_correct_aux. + eassumption. + eauto. + reflexivity. @@ -264,7 +291,9 @@ Section WithWordAndMem. eapply start_state_spec. 2: exact A. eapply ListSet.In_list_union_l. eapply ListSet.In_list_union_l. assumption. + eapply @freshNameGenState_disjoint_fbody. - - simpl. intros. fwd. eauto using map.getmany_of_list_extends. + - simpl. intros. fwd. eexists. split. + + eauto using map.getmany_of_list_extends. + + eexists. split; [|eassumption]. unfold cost_spill_spec in *; solve_MetricLog. Qed. Lemma useimmediate_functions_NoDup: forall funs funs', @@ -302,9 +331,10 @@ Section WithWordAndMem. eexists _, _, _, _. split. 1: eassumption. split. 1: eassumption. intros. eapply exec.weaken. - - eapply useImmediate_correct_aux. - all: eauto. - - eauto. + - eapply useImmediate_correct_aux; eauto. + - simpl. destruct 1 as (?&?&?&?&?). + repeat (eexists; split; try eassumption). + unfold cost_spill_spec in *; solve_MetricLog. Qed. @@ -346,8 +376,7 @@ Section WithWordAndMem. exists retvals. split. + erewrite MapEauto.agree_on_getmany; [ eauto | eapply MapEauto.agree_on_comm; [ eassumption ] ]. - + eassumption. - Unshelve. eauto. + + eexists; split; eauto. unfold cost_spill_spec in *; solve_MetricLog. Qed. Lemma regalloc_functions_NoDup: forall funs funs', @@ -401,8 +430,9 @@ Section WithWordAndMem. edestruct putmany_of_list_zip_states_compat as (lL' & P' & Cp); try eassumption. 1: eapply states_compat_empty. rewrite H1 in P'. inversion P'. exact Cp. - - simpl. intros. fwd. eexists. split. 2: eassumption. - eauto using states_compat_getmany. + - simpl. intros. fwd. eexists. split. + + eauto using states_compat_getmany. + + eexists. split; [|eassumption]. unfold cost_spill_spec in *; solve_MetricLog. Qed. Ltac debool := @@ -456,7 +486,7 @@ Section WithWordAndMem. Proof. unfold FlatWithZVars, FlatWithRegs. split; cbn. 1: exact spilling_preserves_valid. - unfold locals_based_call_spec. intros. fwd. + unfold locals_based_call_spec, locals_based_call_spec_spilled. intros. fwd. pose proof H0 as GL. unfold spill_functions in GL. eapply map.try_map_values_fw in GL. 2: eassumption. @@ -472,8 +502,11 @@ Section WithWordAndMem. fwd. exists argnames2, retnames2, fbody2, l'. split. 1: exact G2. split. 1: eassumption. - intros. eapply spill_fun_correct; try eassumption. - unfold call_spec. intros * E. rewrite E in *. fwd. eauto. + intros. + eapply FlatImp.exec.weaken. + - eapply spill_fun_correct; try eassumption. + unfold call_spec. intros * E. rewrite E in *. fwd. eauto. + - simpl. intros. fwd. repeat (eexists; split; eauto with metric_arith). Qed. Lemma riscv_phase_correct: phase_correct FlatWithRegs RiscvLang (riscvPhase compile_ext_call). @@ -550,18 +583,19 @@ Section WithWordAndMem. (* function we choose to call: *) (fname: string) (* high-level initial state & post on final state: *) - (t: trace) (mH: mem) (argvals: list word) (post: trace -> mem -> list word -> Prop), + (t: trace) (mH: mem) (argvals: list word) (mc: MetricLog) (post: trace -> mem -> list word -> MetricLog -> Prop), valid_src_funs functions = true -> compile functions = Success (instrs, finfo, req_stack_size) -> (exists (argnames retnames: list string) (fbody: cmd) l, map.get (map.of_list (map := Semantics.env) functions) fname = Some (argnames, retnames, fbody) /\ map.of_list_zip argnames argvals = Some l /\ - forall mc, - MetricSemantics.exec (map.of_list functions) fbody t mH l mc + MetricSemantics.exec (map.of_list functions) fbody t mH l + (cost_spill_spec mc) (fun t' m' l' mc' => exists retvals: list word, map.getmany_of_list l' retnames = Some retvals /\ - post t' m' retvals)) -> + post t' m' retvals mc')) -> + forall mcL, exists (f_rel_pos: Z), map.get (map.of_list finfo) fname = Some f_rel_pos /\ forall (* low-level machine on which we're going to run the compiled program: *) @@ -575,11 +609,14 @@ Section WithWordAndMem. word.unsigned ret_addr mod 4 = 0 -> arg_regs_contain initial.(getRegs) argvals -> initial.(getLog) = t -> + raiseMetrics (cost_compile_spec initial.(getMetrics)) = mcL -> machine_ok p_funcs stack_lo stack_hi instrs mH Rdata Rexec initial -> runsTo initial (fun final : MetricRiscvMachine => exists mH' retvals, arg_regs_contain (getRegs final) retvals /\ - post final.(getLog) mH' retvals /\ + (exists mcH' : MetricLog, + ((raiseMetrics final.(getMetrics)) - mcL <= mcH' - mc)%metricsH /\ + post (getLog final) mH' retvals mcH') /\ map.only_differ initial.(getRegs) reg_class.caller_saved final.(getRegs) /\ final.(getPc) = ret_addr /\ machine_ok p_funcs stack_lo stack_hi instrs mH' Rdata Rexec final). @@ -596,6 +633,7 @@ Section WithWordAndMem. } specialize C with (1 := H0'). specialize C with (1 := H1). + specialize (C mcL). cbv iota in C. fwd. eauto 10. Qed. @@ -619,7 +657,7 @@ Section WithWordAndMem. (* function we choose to call: *) (fname: string) (f_rel_pos: Z) (* high-level initial state & post on final state: *) - (t: trace) (mH: mem) (argvals: list word) (post: trace -> mem -> list word -> Prop) + (t: trace) (mH: mem) (argvals: list word) (mc: MetricLog) (post: trace -> mem -> list word -> MetricLog -> Prop) (* ghost vars that help describe the low-level machine: *) (stack_lo stack_hi ret_addr p_funcs: word) (Rdata Rexec: mem -> Prop) (* low-level machine on which we're going to run the compiled program: *) @@ -627,7 +665,9 @@ Section WithWordAndMem. valid_src_funs fs = true -> NoDup (map fst fs) -> compile fs = Success (instrs, finfo, req_stack_size) -> - WeakestPrecondition.call (map.of_list fs) fname t mH argvals post -> + MetricWeakestPrecondition.call (map.of_list fs) fname t mH argvals + (cost_spill_spec mc) post -> + forall mcL, map.get (map.of_list finfo) fname = Some f_rel_pos -> req_stack_size <= word.unsigned (word.sub stack_hi stack_lo) / bytes_per_word -> word.unsigned (word.sub stack_hi stack_lo) mod bytes_per_word = 0 -> @@ -636,24 +676,27 @@ Section WithWordAndMem. word.unsigned ret_addr mod 4 = 0 -> arg_regs_contain initial.(getRegs) argvals -> initial.(getLog) = t -> + raiseMetrics (cost_compile_spec initial.(getMetrics)) = mcL -> machine_ok p_funcs stack_lo stack_hi instrs mH Rdata Rexec initial -> runsTo initial (fun final : MetricRiscvMachine => exists mH' retvals, arg_regs_contain (getRegs final) retvals /\ - post final.(getLog) mH' retvals /\ + (exists mcH' : MetricLog, + ((raiseMetrics final.(getMetrics)) - mcL <= mcH' - mc)%metricsH /\ + post (getLog final) mH' retvals mcH') /\ map.only_differ initial.(getRegs) reg_class.caller_saved final.(getRegs) /\ final.(getPc) = ret_addr /\ machine_ok p_funcs stack_lo stack_hi instrs mH' Rdata Rexec final). Proof. intros. - let H := hyp WeakestPrecondition.call in rename H into WP. + let H := hyp MetricWeakestPrecondition.call in rename H into WP. edestruct compiler_correct with (fname := fname) (argvals := argvals) (post := post) as (f_rel_pos' & G & C); try eassumption. 2: { eapply C; clear C; try assumption; try congruence; try eassumption. } intros. - unfold Semantics.call in WP. fwd. + unfold MetricSemantics.call in WP. fwd. do 5 eexists. 1: eassumption. split. 1: eassumption. - intros. eapply MetricSemantics.of_metrics_free. assumption. + intros. assumption. Qed. End WithMoreParams. diff --git a/compiler/src/compiler/RegAlloc.v b/compiler/src/compiler/RegAlloc.v index c5ed8d750..20ceb8882 100644 --- a/compiler/src/compiler/RegAlloc.v +++ b/compiler/src/compiler/RegAlloc.v @@ -9,6 +9,8 @@ Require Import coqutil.Datatypes.ListSet. Require Import coqutil.Tactics.fwd. Require Import coqutil.Tactics.autoforward. Require Import compiler.Registers. +Require Import bedrock2.MetricLogging. +Require Import bedrock2.MetricCosts. Open Scope Z_scope. @@ -578,9 +580,22 @@ Definition assert(b: bool)(els: result unit): result unit := if b then Success t Definition mapping_eqb: srcvar * impvar -> srcvar * impvar -> bool := fun '(x, x') '(y, y') => andb (String.eqb x y) (Z.eqb x' y'). +Definition check_regs (x: srcvar) (x': impvar) : bool := + negb (andb (isRegStr x) (negb (isRegZ x'))). + +Definition check_regs_op (x: @operand srcvar) (x': @operand impvar) : bool := + match x, x' with + | Var vx, Var vx' => check_regs vx vx' + | Const cx, Const cx' => Z.eqb cx cx' + | _, _ => false + end. + Definition assert_in(y: srcvar)(y': impvar)(m: list (srcvar * impvar)): result unit := match List.find (mapping_eqb (y, y')) m with - | Some _ => Success tt + | Some _ => if check_regs y y' then Success tt else + error:("The register allocator found a mapping of source register variable" y + "to target stack variable" y' + "but source register variables must be in registers in the target.") | None => error:("The register allocator replaced source variable" y "by target variable" y' "but when the checker encountered this pair," @@ -616,7 +631,12 @@ Definition assert_ins(args: list srcvar)(args': list impvar)(m: list (srcvar * i (List.combine args args')) error:("Register allocation checker got a source variable list" args "and a target variable list" args' - "that are incompatible with its current mapping of source to target variables" m). + "that are incompatible with its current mapping of source to target variables" m);; + assert (List.forallb (fun '(x, x') => if check_regs x x' then true else false) + (List.combine args args')) + error:("Register allocation checker got a source variable list" args + "and a target variable list" args' + "in which at least one source register variable is on the stack in the target."). Definition check_bcond(m: list (srcvar * impvar))(c: bcond)(c': bcond'): result unit := match c, c' with @@ -630,13 +650,20 @@ Definition check_bcond(m: list (srcvar * impvar))(c: bcond)(c': bcond'): result | _, _ => error:("Register allocation checker cannot match" c "and" c') end. -Definition assignment(m: list (srcvar * impvar))(x: srcvar)(x': impvar): list (srcvar * impvar) := - (x, x') :: (remove_by_snd Z.eqb x' (remove_by_fst String.eqb x m)). +Definition assignment(m: list (srcvar * impvar))(x: srcvar)(x': impvar): result (list (srcvar * impvar)) := + if check_regs x x' then + Success ((x, x') :: (remove_by_snd Z.eqb x' (remove_by_fst String.eqb x m))) + else + error:("Register allocation checker got an assignment of source register variable" x + "to target stack variable" x' + "but source register variables must be in registers in the target."). Fixpoint assignments(m: list (srcvar * impvar))(xs: list srcvar)(xs': list impvar): result (list (srcvar * impvar)) := match xs, xs' with - | x :: xs0, x' :: xs0' => assignments (assignment m x x') xs0 xs0' + | x :: xs0, x' :: xs0' => + a <- assignment m x x';; + assignments a xs0 xs0' | nil, nil => Success m | _, _ => error:("Register allocator checker got variable lists of different length") end. @@ -672,7 +699,7 @@ Fixpoint check(m: list (srcvar * impvar))(s: stmt)(s': stmt'){struct s}: result assert_in y y' m;; assert (Z.eqb ofs ofs') err;; assert (access_size_beq sz sz') err;; - Success (assignment m x x') + assignment m x x' | SStore sz x y ofs, SStore sz' x' y' ofs' => assert_in x x' m;; assert_in y y' m;; @@ -685,20 +712,22 @@ Fixpoint check(m: list (srcvar * impvar))(s: stmt)(s': stmt'){struct s}: result assert (negb (Z.eqb x' y')) err;; assert (access_size_beq sz sz') err;; assert (List.list_eqb Byte.eqb bs bs') err;; - Success (assignment m x x') + assignment m x x' | SStackalloc x n body, SStackalloc x' n' body' => assert (Z.eqb n n') err;; - check (assignment m x x') body body' + a <- assignment m x x';; + check a body body' | SLit x z, SLit x' z' => assert (Z.eqb z z') err;; - Success (assignment m x x') + assignment m x x' | SOp x op y z, SOp x' op' y' z' => assert_in y y' m;; assert_in_op z z' m;; assert (bopname_beq op op') err;; - Success (assignment m x x') + assignment m x x' | SSet x y, SSet x' y' => - assert_in y y' m;; Success (assignment m x x') + assert_in y y' m;; + assignment m x x' | SIf c s1 s2, SIf c' s1' s2' => check_bcond m c c';; m1 <- check m s1 s1';; @@ -775,8 +804,8 @@ Lemma extends_cons: forall a l1 l2, Proof. unfold extends, assert_in. simpl. intros. fwd. destruct_one_match_hyp. 1: reflexivity. - epose proof (H _ _ _) as A. destruct_one_match_hyp. 2: discriminate. rewrite E1. reflexivity. - Unshelve. rewrite E. reflexivity. + epose proof (H _ _ _) as A. destruct_one_match_hyp. 2: discriminate. rewrite E2. reflexivity. + Unshelve. rewrite E. rewrite E0. reflexivity. Qed. Lemma extends_cons_l: forall a l, @@ -794,11 +823,12 @@ Lemma extends_cons_r: forall a l1 l2, Proof. unfold extends, assert_in. simpl. intros. fwd. destruct_one_match_hyp. 2: { - eapply H0. rewrite E. reflexivity. + specialize (H0 x x'). rewrite E in H0. rewrite E0 in H0. + eapply H0. reflexivity. } destruct_one_match. 1: reflexivity. - eapply find_none in E1. 2: eassumption. - simpl in E1. exfalso. fwd. intuition congruence. + eapply find_none in E2. 2: eassumption. + simpl in E2. exfalso. congruence. Qed. Lemma extends_intersect_l: forall l1 l2, @@ -1014,21 +1044,22 @@ Section CheckerCorrect. inversion E3. clear E3. cbn in *. fwd. erewrite states_compat_get; try eassumption. 2: { - unfold assert_in. unfold mapping_eqb. rewrite E1. reflexivity. + unfold assert_in. unfold mapping_eqb. rewrite E1. destruct H1p0 as [H1|]; [now rewrite H1|discriminate]. } unfold map.getmany_of_list in *. erewrite IHys; eauto. - unfold assert_ins. rewrite H1. rewrite Nat.eqb_refl. simpl. - unfold assert. rewrite E2p1. reflexivity. + unfold assert_ins. rewrite H1p1. rewrite H0. rewrite Nat.eqb_refl. simpl. + unfold assert. rewrite E4p1. reflexivity. Qed. - Lemma states_compat_put: forall lH corresp lL x x' v, + Lemma states_compat_put: forall lH corresp lL x x' v m, + assignment corresp x x' = Success m -> states_compat lH corresp lL -> - states_compat (map.put lH x v) (assignment corresp x x') (map.put lL x' v). + states_compat (map.put lH x v) m (map.put lL x' v). Proof. intros. unfold states_compat in *. intros k k'. intros. - rewrite map.get_put_dec. rewrite map.get_put_dec in H1. - unfold assert_in, assignment in H0. fwd. simpl in E. + rewrite map.get_put_dec. rewrite map.get_put_dec in H2. + unfold assert_in, assignment in H, H0, H1. fwd. simpl in E. rewrite String.eqb_sym, Z.eqb_sym in E. destr (Z.eqb x' k'). - destr (String.eqb x k). @@ -1047,10 +1078,10 @@ Section CheckerCorrect. eapply In_remove_by_fst in E. destruct E. destr (String.eqb x k). + exfalso. congruence. - + eapply H. 2: eassumption. unfold assert_in. - destruct_one_match. 1: reflexivity. - eapply find_none in E1. 2: eassumption. - simpl in E1. rewrite String.eqb_refl, Z.eqb_refl in E1. discriminate. + + eapply H0. 2: eassumption. unfold assert_in. + destruct_one_match. 1: rewrite E0; trivial. + eapply find_none in E3. 2: eassumption. + simpl in E3. rewrite String.eqb_refl, Z.eqb_refl in E3. discriminate. Qed. @@ -1069,7 +1100,7 @@ Section CheckerCorrect. - simpl in *. fwd. specialize IHbinds with (1 := H). rename l' into lH'. - edestruct IHbinds as (lL' & P & C). 1: eassumption. 1: eapply states_compat_put. 1: eassumption. + edestruct IHbinds as (lL' & P & C). 1: eassumption. 1: eapply states_compat_put. 1: eassumption. 1: eassumption. simpl. rewrite P. eauto. Qed. @@ -1078,7 +1109,8 @@ Section CheckerCorrect. Proof. induction xs; intros; destruct xs'; try discriminate. - reflexivity. - - simpl in *. f_equal. eapply IHxs. eassumption. + - simpl in *. f_equal. destruct_one_match_hyp; try discriminate. + specialize (IHxs xs' a0 m2). eauto. Qed. Lemma assert_ins_same_length: forall xs xs' m u, @@ -1087,20 +1119,89 @@ Section CheckerCorrect. unfold assert_ins, assert. intros. fwd. assumption. Qed. + Opaque isRegStr. + Opaque isRegZ. + + Lemma assert_in_then_check_regs_op : forall x x' corresp, + assert_in_op x x' corresp = Success tt -> check_regs_op x x' = true. + Proof. + unfold check_regs_op, assert_in_op, assert_in; intros; destruct x'; fwd; reflexivity. + Qed. + + Ltac discr_match_success := + match goal with + | H: match ?expr with _ => _ end = Success _ |- _ => destr expr; discriminate + end. + + Lemma check_regs_cost_SIf: + forall (cond : bcond) (bThen bElse : stmt) (corresp : list (string * Z)) + (cond0 : bcond') (s'1 s'2 : stmt') (a0 a1 : list (string * Z)) + (mc mcL : MetricLog), + check_bcond corresp cond cond0 = Success tt -> + check corresp bThen s'1 = Success a0 -> + check corresp bElse s'2 = Success a1 -> + forall mc' mcH' : MetricLog, + (mc' - exec.cost_SIf isRegZ cond0 mcL <= mcH' - exec.cost_SIf isRegStr cond mc)%metricsH -> + (mc' - mcL <= mcH' - mc)%metricsH. + Proof. + intros. + repeat (unfold check_bcond, assert_in, assignment in *; fwd). + destr cond; destr cond0; + repeat match goal with + | H: match ?expr with _ => _ end = Success _ |- _ => destr expr; try discriminate + end; + unfold check_regs, exec.cost_SIf in *; + cost_solve. + Qed. + + Lemma check_regs_cost_SLoop_false: + forall (cond : bcond) (body1 body2 : stmt) (corresp' : list (string * Z)) (s'1 : stmt') + (cond0 : bcond') (s'2 : stmt') (mc mcL : MetricLog) (a : list (string * Z)), + check a body1 s'1 = Success corresp' -> + check_bcond corresp' cond cond0 = Success tt -> + forall a2 : list (string * Z), + check corresp' body2 s'2 = Success a2 -> + forall mc' mcH' : MetricLog, + (mc' - mcL <= mcH' - mc)%metricsH -> + (exec.cost_SLoop_false isRegZ cond0 mc' - mcL <= exec.cost_SLoop_false isRegStr cond mcH' - mc)%metricsH. + Proof. + intros. + repeat (unfold check_bcond, assert_in, assignment in *; fwd). + destr cond; destr cond0; + repeat match goal with + | H: match ?expr with _ => _ end = Success _ |- _ => destr expr; try discriminate + end; + unfold check_regs, exec.cost_SLoop_false in *; + cost_solve. + Qed. Hint Constructors exec.exec : checker_hints. Hint Resolve states_compat_get : checker_hints. Hint Resolve states_compat_put : checker_hints. Hint Resolve states_compat_get_op : checker_hints. Hint Resolve states_compat_then_op : checker_hints. - Lemma checker_correct: forall (e: srcEnv) (e': impEnv) s t m lH mc post, + (*Hint Resolve*) + (* check_regs_cost_SLoad check_regs_cost_SStore check_regs_cost_SInlinetable*) + (* check_regs_cost_SStackalloc check_regs_cost_SLit (*check_regs_cost_SOp*)*) + (* check_regs_cost_SSet check_regs_cost_SIf check_regs_cost_SLoop_false*) + (* : checker_hints.*) + + Ltac a := + repeat match goal with | |- metricsLeq _ _ => fail 1 | _ => econstructor end; + eauto 10 with checker_hints. + + Ltac b := unfold assert_in, assignment, check_regs in *; cost_hammer. + + Lemma checker_correct: forall (e: srcEnv) (e': impEnv) s t m lH mcH post, check_funcs e e' = Success tt -> - exec e s t m lH mc post -> - forall lL corresp corresp' s', + exec PreSpill isRegStr e s t m lH mcH post -> + forall lL corresp corresp' s' mcL, check corresp s s' = Success corresp' -> states_compat lH (precond corresp s s') lL -> - exec e' s' t m lL mc (fun t' m' lL' mc' => - exists lH', states_compat lH' corresp' lL' /\ post t' m' lH' mc'). + exec PreSpill isRegZ e' s' t m lL mcL (fun t' m' lL' mcL' => + exists lH' mcH', states_compat lH' corresp' lL' /\ + (mcL' - mcL <= mcH' - mcH)%metricsH /\ + post t' m' lH' mcH'). Proof. induction 2; intros; match goal with @@ -1125,6 +1226,7 @@ Section CheckerCorrect. intros. edestruct H3 as (l' & P & F). 1: eassumption. eapply putmany_of_list_zip_states_compat in P. 2-3: eassumption. destruct P as (lL' & P & SC). eexists. split. 1: eassumption. intros. eauto. + repeat eexists; eauto; cost_solve. - (* Case exec.call *) rename binds0 into binds', args0 into args'. unfold check_funcs in H. @@ -1149,41 +1251,46 @@ Section CheckerCorrect. do 2 eexists. ssplit. * eapply states_compat_getmany; eassumption. * exact L4. - * eexists. split. 2: eassumption. exact SC. + * repeat eexists. 6: eassumption. 1: exact SC. + all: cost_solve. - (* Case exec.load *) - eauto 10 with checker_hints. + a. b. - (* Case exec.store *) - eauto 10 with checker_hints. + a. b. - (* Case exec.inlinetable *) - eauto 10 with checker_hints. + a. b. - (* Case exec.stackalloc *) eapply exec.stackalloc. 1: assumption. intros. eapply exec.weaken. + eapply H2; try eassumption. - eapply states_compat_precond. eapply states_compat_put. assumption. - + cbv beta. intros. fwd. eauto 10 with checker_hints. + eapply states_compat_precond. eapply states_compat_put; eassumption. + + cbv beta. intros. fwd. + eexists. eexists. do 2 (split; try eassumption). + eexists. eexists. do 2 (split; try eassumption). + b. - (* Case exec.lit *) - eauto 10 with checker_hints. + a. b. - (* Case exec.op *) - eauto 10 with checker_hints. - + a. + unfold assert_in_op, assert_in, assignment, check_regs in *; scost_hammer. (* could be faster *) - (* Case exec.set *) - eauto 10 with checker_hints. - + a. b. - (* Case exec.if_true *) eapply exec.if_true. 1: eauto using states_compat_eval_bcond. eapply exec.weaken. + eapply IHexec. 1: eassumption. eapply states_compat_precond. eassumption. - + cbv beta. intros. fwd. eexists. split. 2: eassumption. - eapply states_compat_extends. 2: eassumption. eapply extends_intersect_l. + + cbv beta. intros. fwd. eexists. eexists. split. 2: split. 3: eassumption. + 1: eapply states_compat_extends. 2: eassumption. 1: eapply extends_intersect_l. + eapply check_regs_cost_SIf; eauto. - (* Case exec.if_false *) eapply exec.if_false. 1: eauto using states_compat_eval_bcond. eapply exec.weaken. + eapply IHexec. 1: eassumption. eapply states_compat_precond. eassumption. - + cbv beta. intros. fwd. eexists. split. 2: eassumption. - eapply states_compat_extends. 2: eassumption. eapply extends_intersect_r. + + cbv beta. intros. fwd. eexists. eexists. split. 2: split. 3: eassumption. + 1: eapply states_compat_extends. 2: eassumption. 1: eapply extends_intersect_r. + eapply check_regs_cost_SIf; eauto. - (* Case exec.loop *) rename H4 into IH2, IHexec into IH1, H6 into IH12. match goal with @@ -1192,33 +1299,55 @@ Section CheckerCorrect. pose proof SC as SC0. unfold loop_inv in SC. rewrite E in SC. - eapply exec.loop. + eapply exec.loop with + (mid2 := (fun (t'0 : Semantics.trace) (m'0 : mem) (lL' : impLocals) (mcL' : MetricLog) => + exists (lH' : srcLocals) (mcH' : MetricLog), + states_compat lH' a1 lL' /\ + (exists mcHmid mcLmid, + mcLmid - mcL <= mcHmid - mc /\ + mcL' - mcLmid <= mcH' - mcHmid)%metricsH /\ + mid2 t'0 m'0 lH' mcH')). + eapply IH1. 1: eassumption. eapply states_compat_precond. exact SC. + cbv beta. intros. fwd. eauto using states_compat_eval_bcond_None. - + cbv beta. intros. fwd. eexists. split. 2: eauto using states_compat_eval_bcond_bw. assumption. - + cbv beta. intros. fwd. eapply IH2; eauto using states_compat_eval_bcond_bw. - eapply states_compat_precond. assumption. - + cbv beta. intros. fwd. eapply IH12. 1: eassumption. 1: eassumption. - eapply states_compat_extends. 2: eassumption. - pose proof defuel_loop_inv as P. - specialize P with (2 := E0). - specialize P with (2 := E2). - specialize (P corresp). - unfold loop_inv in P|-*. - rewrite E in P. rewrite E. - specialize (P eq_refl). - rewrite P. - eapply extends_intersect_r. + + cbv beta. intros. fwd. eexists. eexists. (* exists (exec.cost_SLoop_false isRegStr cond mcH'). *) + split. 2: split. 3: eauto using states_compat_eval_bcond_bw. 1: assumption. + eapply check_regs_cost_SLoop_false. 1: apply E0. all: eauto. + + cbv beta. intros. fwd. eapply exec.weaken. 1: eapply IH2; eauto using states_compat_eval_bcond_bw. + 1: eapply states_compat_precond; eassumption. + cbv beta. intros. fwd. eexists. eexists. split. 2: split. 1,3: eauto. + exists mcH'. exists mc'. + split; eauto. + + cbv beta. intros. fwd. eapply exec.weaken. 1: eapply IH12. 1: eassumption. 1: eassumption. + * eapply states_compat_extends. 2: eassumption. + pose proof defuel_loop_inv as P. + specialize P with (2 := E0). + specialize P with (2 := E2). + specialize (P corresp). + unfold loop_inv in P|-*. + rewrite E in P. rewrite E. + specialize (P eq_refl). + rewrite P. + eapply extends_intersect_r. + * cbv beta. intros. fwd. eexists. eexists. split. 2: split. 1: eauto. 2: eauto. + intros. + repeat (unfold check_bcond, assert_in, assignment in *; fwd). + clear -E0 E1 E2 H4p1p0 H4p1p1 H4p1 H4p3. + intros; unfold check_regs in *; cbn in *; unfold exec.cost_SLoop_true in *; try discr_match_success; + destr cond; destr cond0; destr (isRegStr x); destr (isRegZ x0); try (destr (isRegStr y)); try (destr (isRegZ y0)); + try discriminate; try discr_match_success. + all: cost_solve. - (* Case exec.seq *) rename H2 into IH2, IHexec into IH1. eapply exec.seq. + eapply IH1. 1: eassumption. eapply states_compat_precond. assumption. - + cbv beta. intros. fwd. - eapply IH2. 1: eassumption. 1: eassumption. - eapply states_compat_precond. assumption. - - (* Case exec.skip *) - eapply exec.skip. eauto. + + cbv beta. intros t' m' l' mcblah ?. fwd. + eapply IH2 in H2p2. 2,3: eauto using states_compat_precond. + eapply exec.weaken; eauto. + cbv beta. intros. fwd. exists lH'0. exists mcH'0. split. 2:split. 1,3: eauto. + b. + - (* case exec.skip *) + a. b. Qed. End CheckerCorrect. diff --git a/compiler/src/compiler/Spilling.v b/compiler/src/compiler/Spilling.v index b91c0beb5..14c656809 100644 --- a/compiler/src/compiler/Spilling.v +++ b/compiler/src/compiler/Spilling.v @@ -11,6 +11,7 @@ Require Import compiler.Registers. Require Import compiler.SeparationLogic. Require Import compiler.SpillingMapGoals. Require Import bedrock2.MetricLogging. +Require Import bedrock2.MetricCosts. Require Import compiler.FlatImpConstraints. Require Import coqutil.Tactics.autoforward. Require Import coqutil.Tactics.fwd. @@ -20,6 +21,8 @@ Open Scope Z_scope. Section Spilling. Notation stmt := (stmt Z). + Notation execpre := (exec PreSpill isRegZ). + Notation execpost := (exec PostSpill isRegZ). Definition zero := 0. Definition ra := 1. @@ -127,9 +130,9 @@ Section Spilling. | SLit x n => SLit (ires_reg x) n;; save_ires_reg x - | SOp x op y oz => - load_iarg_reg 1 y;; - match oz with + | SOp x op y oz => + load_iarg_reg 1 y;; + match oz with | Var z => load_iarg_reg 2 z;; SOp (ires_reg x) op (iarg_reg 1 y) (Var (iarg_reg 2 z)) | Const _ => SOp (ires_reg x) op (iarg_reg 1 y) oz end;; save_ires_reg x @@ -353,7 +356,7 @@ Section Spilling. valid_vars_src m s -> valid_vars_tgt (spill_stmt s). Proof. - + unfold valid_vars_src, valid_vars_tgt. induction s; simpl; intros; repeat match goal with @@ -456,13 +459,15 @@ Section Spilling. related maxvar frame fpval t1 m1 l1 t2 m2 l2 -> fp < r <= maxvar /\ (r < a0 \/ a7 < r) -> map.get l1 r = Some v -> - (forall mc2, - related maxvar frame fpval t1 m1 l1 t2 m2 (map.put l2 (iarg_reg i r) v) -> - post t2 m2 (map.put l2 (iarg_reg i r) v) mc2) -> - exec e2 (load_iarg_reg i r) t2 m2 l2 mc2 post. + (related maxvar frame fpval t1 m1 l1 t2 m2 (map.put l2 (iarg_reg i r) v) -> + post t2 m2 (map.put l2 (iarg_reg i r) v) + (if isRegZ r then mc2 else (mkMetricLog 1 0 2 0 + mc2)%metricsH)) -> + execpost e2 (load_iarg_reg i r) t2 m2 l2 mc2 post. Proof. intros. unfold load_iarg_reg, stack_loc, iarg_reg, related in *. fwd. + assert (isRegZ (9 + i) = true) by (unfold isRegZ; blia). + assert (isRegZ fp = true) by (unfold isRegZ; (assert (fp = 5) by auto); blia). destr (32 <=? r). - eapply exec.load. + eapply get_sep. ecancel_assumption. @@ -472,7 +477,10 @@ Section Spilling. eapply map.get_split_r. 1,3: eassumption. destr (map.get mp r); [exfalso|reflexivity]. specialize H0p2 with (1 := E0). blia. - + eapply H3. + + unfold cost_load. + assert (isRegZ r = false) by (unfold isRegZ; blia); rewrite H4 in H3. + unfold spill_tmp in H3. rewrite H0; rewrite H1. + eapply H3. repeat match goal with | |- exists _, _ => eexists | |- _ /\ _ => split @@ -489,6 +497,7 @@ Section Spilling. destr (map.get lStack r); [exfalso|reflexivity]. specialize H0p3 with (1 := E0). blia. } + assert (isRegZ r = true) by (unfold isRegZ; blia); rewrite H4 in H3. eapply H3. repeat match goal with | |- exists _, _ => eexists @@ -497,49 +506,49 @@ Section Spilling. end. Qed. - Lemma load_iarg_reg_correct'(i: Z): forall r e2 t1 t2 m1 m2 l1 l2 mc1 mc2 post frame maxvar v fpval, - i = 1 \/ i = 2 -> - related maxvar frame fpval t1 m1 l1 t2 m2 l2 -> - fp < r <= maxvar /\ (r < a0 \/ a7 < r) -> - map.get l1 r = Some v -> - post t1 m1 l1 mc1 -> - exec e2 (load_iarg_reg i r) t2 m2 l2 mc2 - (fun t2' m2' l2' mc2' => exists t1' m1' l1' mc1', - related maxvar frame fpval t1' m1' l1' t2' m2' l2' /\ post t1' m1' l1' mc1'). - Proof. - intros. - unfold load_iarg_reg, stack_loc, iarg_reg, related in *. fwd. - destr (32 <=? r). - - eapply exec.load. - + eapply get_sep. ecancel_assumption. - + eapply load_from_word_array. 1: ecancel_assumption. 2: blia. - eapply H0p6. 1: blia. - unfold sep in H0p4. fwd. - eapply map.get_split_r. 1,3: eassumption. - destr (map.get mp r); [exfalso|reflexivity]. - specialize H0p2 with (1 := E0). blia. - + repeat match goal with - | |- exists _, _ => eexists - | |- _ /\ _ => split - | |- _ => eassumption || reflexivity - end. - eapply put_tmp; eassumption. - - eapply exec.skip. - replace l2 with (map.put l2 r v) in H0p5|-*. 2: { - apply map.put_idemp. - edestruct (eq_sep_to_split l2) as (l2Rest & S22 & SP22). 1: ecancel_assumption. - eapply map.get_split_grow_r. 1: eassumption. - unfold sep in H0p4. destruct H0p4 as (lRegs' & lStack' & S2 & ? & ?). subst lRegs' lStack'. - eapply map.get_split_l. 1: exact S2. 2: assumption. - destr (map.get lStack r); [exfalso|reflexivity]. - specialize H0p3 with (1 := E0). blia. - } - repeat match goal with - | |- exists _, _ => eexists - | |- _ /\ _ => split - | |- _ => eassumption || reflexivity - end. - Qed. + (* Lemma load_iarg_reg_correct'(i: Z): forall r e2 t1 t2 m1 m2 l1 l2 mc1 mc2 post frame maxvar v fpval, *) + (* i = 1 \/ i = 2 -> *) + (* related maxvar frame fpval t1 m1 l1 t2 m2 l2 -> *) + (* fp < r <= maxvar /\ (r < a0 \/ a7 < r) -> *) + (* map.get l1 r = Some v -> *) + (* post t1 m1 l1 mc1 -> *) + (* exec e2 (load_iarg_reg i r) t2 m2 l2 mc2 *) + (* (fun t2' m2' l2' mc2' => exists t1' m1' l1' mc1', *) + (* related maxvar frame fpval t1' m1' l1' t2' m2' l2' /\ post t1' m1' l1' mc1'). *) + (* Proof. *) + (* intros. *) + (* unfold load_iarg_reg, stack_loc, iarg_reg, related in *. fwd. *) + (* destr (32 <=? r). *) + (* - eapply exec.load. *) + (* + eapply get_sep. ecancel_assumption. *) + (* + eapply load_from_word_array. 1: ecancel_assumption. 2: blia. *) + (* eapply H0p6. 1: blia. *) + (* unfold sep in H0p4. fwd. *) + (* eapply map.get_split_r. 1,3: eassumption. *) + (* destr (map.get mp r); [exfalso|reflexivity]. *) + (* specialize H0p2 with (1 := E0). blia. *) + (* + repeat match goal with *) + (* | |- exists _, _ => eexists *) + (* | |- _ /\ _ => split *) + (* | |- _ => eassumption || reflexivity *) + (* end. *) + (* eapply put_tmp; eassumption. *) + (* - eapply exec.skip. *) + (* replace l2 with (map.put l2 r v) in H0p5|-*. 2: { *) + (* apply map.put_idemp. *) + (* edestruct (eq_sep_to_split l2) as (l2Rest & S22 & SP22). 1: ecancel_assumption. *) + (* eapply map.get_split_grow_r. 1: eassumption. *) + (* unfold sep in H0p4. destruct H0p4 as (lRegs' & lStack' & S2 & ? & ?). subst lRegs' lStack'. *) + (* eapply map.get_split_l. 1: exact S2. 2: assumption. *) + (* destr (map.get lStack r); [exfalso|reflexivity]. *) + (* specialize H0p3 with (1 := E0). blia. *) + (* } *) + (* repeat match goal with *) + (* | |- exists _, _ => eexists *) + (* | |- _ /\ _ => split *) + (* | |- _ => eassumption || reflexivity *) + (* end. *) + (* Qed. *) (* Note: if we wanted to use this lemma in subgoals created by exec.loop, new postcondition must not mention the original t2, m2, l2, mc2, (even though @@ -552,9 +561,10 @@ Section Spilling. related maxvar frame fpval t1 m1 l1 t2 m2 l2 -> fp < r <= maxvar /\ (r < a0 \/ a7 < r) -> map.get l1 r = Some v -> - exec e2 (load_iarg_reg i r) t2 m2 l2 mc2 (fun t2' m2' l2' mc2' => + execpost e2 (load_iarg_reg i r) t2 m2 l2 mc2 (fun t2' m2' l2' mc2' => t2' = t2 /\ m2' = m2 /\ l2' = map.put l2 (iarg_reg i r) v /\ - related maxvar frame fpval t1 m1 l1 t2' m2' l2'). + related maxvar frame fpval t1 m1 l1 t2' m2' l2' /\ + (mc2' <= (if isRegZ r then mc2 else mkMetricLog 1 0 2 0 + mc2))%metricsH). Proof. intros. unfold load_iarg_reg, stack_loc, iarg_reg, related in *. fwd. @@ -572,7 +582,10 @@ Section Spilling. | |- _ /\ _ => split | |- _ => eassumption || reflexivity end. - eapply put_tmp; eassumption. + 1: eapply put_tmp; eassumption. + unfold cost_load. assert (isRegZ (9+i) = true) by (unfold isRegZ; blia); rewrite H0. + assert (fp = 5) by auto; rewrite H1; cbn. + destr (isRegZ r); solve_MetricLog. - eapply exec.skip. assert (l2 = map.put l2 r v) as F. { symmetry. apply map.put_idemp. @@ -588,19 +601,23 @@ Section Spilling. | |- _ /\ _ => split | |- _ => eassumption || reflexivity end. + destr (isRegZ r); solve_MetricLog. Qed. (* SOp does not create an up-to-date `related` before we invoke this one, because after SOp, `related` does not hold: the result is already in l1 and lStack, but not yet in stackwords. So we request the `related` that held *before* SOp, i.e. the one where the result is not yet in l1 and l2. *) - Lemma save_ires_reg_correct: forall e t1 t2 m1 m2 l1 l2 mc1 mc2 x v maxvar frame post fpval, - post t1 m1 (map.put l1 x v) mc1 -> + Lemma save_ires_reg_correct: forall e t1 t2 m1 m2 l1 l2 mc1 mc1' mc2 mc2' x v maxvar frame post fpval, + post t1 m1 (map.put l1 x v) mc1' -> related maxvar frame fpval t1 m1 l1 t2 m2 l2 -> fp < x <= maxvar /\ (x < a0 \/ a7 < x) -> - exec e (save_ires_reg x) t2 m2 (map.put l2 (ires_reg x) v) mc2 - (fun t2' m2' l2' mc2' => exists t1' m1' l1' mc1', - related maxvar frame fpval t1' m1' l1' t2' m2' l2' /\ post t1' m1' l1' mc1'). + (mc2' - mc2 <= mc1' - (if isRegZ x then mc1 else mkMetricLog 1 1 1 0 + mc1))%metricsH -> + execpost e (save_ires_reg x) t2 m2 (map.put l2 (ires_reg x) v) mc2' + (fun t2' m2' l2' mc2'' => exists t1' m1' l1' mc1'', + related maxvar frame fpval t1' m1' l1' t2' m2' l2' /\ + post t1' m1' l1' mc1'' /\ + (mc2'' - mc2 <= mc1'' - mc1)%metricsH). Proof. intros. unfold save_ires_reg, stack_loc, ires_reg, related in *. fwd. @@ -645,12 +662,14 @@ Section Spilling. - eapply Nj. 1: blia. eauto. } 1: { unfold spill_tmp. eapply put_tmp; eauto. } - blia. + 1: blia. + unfold cost_store. unfold spill_tmp; cbn. + destr (isRegZ x); solve_MetricLog. } blia. - eapply exec.skip. - (* even though we did nothing, we have to reconstruct the `related` from the `related` that - held *before* the SOp *) + (* even though we did nothing, we have to reconstruct the `related` from the `related` that *) + (* held *before* the SOp *) repeat match goal with | |- exists _, _ => eexists | |- _ /\ _ => split @@ -686,19 +705,20 @@ Section Spilling. specialize H0p8 with (1 := H1). blia. } all: try eassumption. + destr (isRegZ x); solve_MetricLog. Qed. - (* SOp does not create an up-to-date `related` before we invoke this one, because after SOp, - `related` does not hold: the result is already in l1 and lStack, but not yet in stackwords. - So we request the `related` that held *before* SOp, i.e. the one where the result is not - yet in l1 and l2. *) + (* SOp does not create an up-to-date `related` before we invoke this one, because after SOp, *) + (* `related` does not hold: the result is already in l1 and lStack, but not yet in stackwords. *) + (* So we request the `related` that held *before* SOp, i.e. the one where the result is not *) + (* yet in l1 and l2. *) Lemma save_ires_reg_correct'': forall e t1 t2 m1 m2 l1 l2 mc2 x v maxvar frame post fpval, related maxvar frame fpval t1 m1 l1 t2 m2 l2 -> fp < x <= maxvar /\ (x < a0 \/ a7 < x) -> - (forall t2' m2' l2' mc2', + (forall t2' m2' l2', related maxvar frame fpval t1 m1 (map.put l1 x v) t2' m2' l2' -> - post t2' m2' l2' mc2') -> - exec e (save_ires_reg x) t2 m2 (map.put l2 (ires_reg x) v) mc2 post. + post t2' m2' l2' (if isRegZ x then mc2 else mkMetricLog 1 1 1 0 + mc2)%metricsH) -> + execpost e (save_ires_reg x) t2 m2 (map.put l2 (ires_reg x) v) mc2 post. Proof. intros. unfold save_ires_reg, stack_loc, ires_reg, related in *. fwd. @@ -711,7 +731,12 @@ Section Spilling. eapply get_sep. ecancel_assumption. - rewrite map.get_put_same. reflexivity. - exact St. - - eapply H1. + - unfold cost_store. + assert (isRegZ (spill_tmp 1) = true) by auto; rewrite H. + assert (isRegZ fp = true) by auto; rewrite H0. + clear H H0. + destr (isRegZ x); try blia. + eapply H1. repeat match goal with | |- exists _, _ => eexists | |- _ /\ _ => split @@ -747,9 +772,10 @@ Section Spilling. } blia. - eapply exec.skip. + destr (isRegZ x); try blia. eapply H1. - (* even though we did nothing, we have to reconstruct the `related` from the `related` that - held *before* the SOp *) + (* even though we did nothing, we have to reconstruct the `related` from the `related` that *) + (* held *before* the SOp *) repeat match goal with | |- exists _, _ => eexists | |- _ /\ _ => split @@ -831,6 +857,22 @@ Section Spilling. intros. apply H. eapply hide_ll_arg_reg_ptsto_core; eassumption. Qed. + Fixpoint cost_set_vars_to_reg_range (args: list Z) (start : Z) (mc : MetricLog) : MetricLog := + match args with + | [] => mc + | x :: xs => (if isRegZ x then mkMetricLog 1 0 1 0 else mkMetricLog 1 1 1 0) + + cost_set_vars_to_reg_range xs (start + 1) mc + end. + + Lemma cost_set_vars_to_reg_range_commutes: + forall args start n m, + (cost_set_vars_to_reg_range args start (n + m) = n + cost_set_vars_to_reg_range args start m)%metricsH. + Proof. + induction args; trivial. + intros; cbn; destr (isRegZ a); cbn; rewrite IHargs; + do 2 rewrite MetricArith.add_assoc; rewrite (MetricArith.add_comm n); reflexivity. + Qed. + Lemma set_vars_to_reg_range_correct: forall args start argvs e t1 t2 m1 m2 l1 l1' l2 mc2 maxvar frame post fpval, related maxvar frame fpval t1 m1 l1 t2 m2 l2 -> @@ -842,8 +884,9 @@ Section Spilling. Forall (fun x => fp < x <= maxvar /\ (x < a0 \/ a7 < x)) args -> (forall m2' l2' mc2', related maxvar frame fpval t1 m1 l1' t2 m2' l2' -> + mc2' = (cost_set_vars_to_reg_range args start mc2) -> post t2 m2' l2' mc2') -> - exec e (set_vars_to_reg_range args start) t2 m2 l2 mc2 post. + execpost e (set_vars_to_reg_range args start) t2 m2 l2 mc2 post. Proof. induction args; intros. - simpl. eapply exec.skip. fwd. eauto. @@ -861,50 +904,70 @@ Section Spilling. { eassumption. } eapply IHargs; try eassumption; try blia. (* establish related for IH: *) - unfold related. - eexists (map.put lStack a v), lRegs, _. - ssplit. - { reflexivity. } - { ecancel_assumption. } - { eassumption. } - { intros. rewrite map.get_put_dec in H. destr (a =? x0). 1: blia. eauto. } - { apply sep_comm. eapply sep_eq_put. 1: apply sep_comm; assumption. - intros lRegs' w ? G. subst lRegs'. - match goal with H: _ |- _ => specialize H with (1 := G) end. blia. } - { eassumption. } - { intros b A0 w B0. + * unfold related. + eexists (map.put lStack a v), lRegs, _. + ssplit. + { reflexivity. } + { ecancel_assumption. } + { eassumption. } + { intros. rewrite map.get_put_dec in H. destr (a =? x0). 1: blia. eauto. } + { apply sep_comm. eapply sep_eq_put. 1: apply sep_comm; assumption. + intros lRegs' w ? G. subst lRegs'. + match goal with H: _ |- _ => specialize H with (1 := G) end. blia. } + { eassumption. } + { intros b A0 w B0. rewrite map.get_put_dec in B0. destr (a =? b). 1: congruence. match goal with H: _ |- _ => eapply H end. 1: blia. match goal with H: _ |- _ => eapply H end. 1: blia. assumption. } - { blia. } + { blia. } + * intros. apply H6; auto. + cbn in *. destr (isRegZ start); try blia; destr (isRegZ a); try blia. + rewrite H0. unfold cost_store, isRegZ. cbn. + rewrite cost_set_vars_to_reg_range_commutes. + rewrite (proj2 (Z.leb_le start 31)) by assumption. + reflexivity. + eapply exec.set. { eassumption. } eapply IHargs; try eassumption; try blia. 2: { eapply map.getmany_of_list_put_diff. 2: eassumption. eapply List.not_In_Z_seq. blia. } - unfold related. eexists lStack, (map.put lRegs a v), _. - ssplit. - { reflexivity. } - { ecancel_assumption. } - { intros. rewrite map.get_put_dec in H. destr (a =? x). 1: blia. eauto. } - { eassumption. } - { eapply sep_eq_put. 1: assumption. - intros lStack' w ? G. subst lStack'. - match goal with H: _ |- _ => specialize H with (1 := G) end. blia. } - { apply sep_assoc. eapply sep_eq_put. 1: ecancel_assumption. - unfold ptsto, arg_regs. - intros l w (l_arg_regs & l_fpval & (? & ?) & ? & ?) G. subst. + * unfold related. eexists lStack, (map.put lRegs a v), _. + ssplit. + { reflexivity. } + { ecancel_assumption. } + { intros. rewrite map.get_put_dec in H. destr (a =? x). 1: blia. eauto. } + { eassumption. } + { eapply sep_eq_put. 1: assumption. + intros lStack' w ? G. subst lStack'. + match goal with H: _ |- _ => specialize H with (1 := G) end. blia. } + { apply sep_assoc. eapply sep_eq_put. 1: ecancel_assumption. + unfold ptsto, arg_regs. + intros l w (l_arg_regs & l_fpval & (? & ?) & ? & ?) G. subst. rewrite map.get_putmany_dec, map.get_put_dec, map.get_empty in G. destr (fp =? a). 1: unfold fp; blia. match goal with H: _ |- _ => specialize H with (1 := G) end. unfold a0, a7 in *. blia. } - { assumption. } - { assumption. } + { assumption. } + { assumption. } + * intros. apply H6; auto. + cbn in *. destr (isRegZ start); try blia; destr (isRegZ a); try blia. + rewrite H0. unfold cost_set, isRegZ. cbn. + rewrite cost_set_vars_to_reg_range_commutes. + rewrite (proj2 (Z.leb_le a 31)) by assumption. + rewrite (proj2 (Z.leb_le start 31)) by assumption. + reflexivity. Qed. + Fixpoint cost_set_reg_range_to_vars (start : Z) (args: list Z) (mc : MetricLog) : MetricLog := + match args with + | [] => mc + | x :: xs => (if isRegZ x then mkMetricLog 1 0 1 0 else mkMetricLog 1 0 2 0) + + cost_set_reg_range_to_vars (start + 1) xs mc + end. + Lemma set_reg_range_to_vars_correct: forall args argvs start e t1 t2 m1 m2 l1 l2 mc2 maxvar frame post fpval, related maxvar frame fpval t1 m1 l1 t2 m2 l2 -> @@ -916,12 +979,14 @@ Section Spilling. (forall l2' mc2', related maxvar frame fpval t1 m1 l1 t2 m2 l2' -> map.getmany_of_list l2' (List.unfoldn (Z.add 1) (List.length args) start) = Some argvs -> + mc2' = (cost_set_reg_range_to_vars start args mc2) -> post t2 m2 l2' mc2') -> - exec e (set_reg_range_to_vars start args) t2 m2 l2 mc2 post. + execpost e (set_reg_range_to_vars start args) t2 m2 l2 mc2 post. Proof. induction args; intros. - - simpl. eapply exec.skip. eapply H5. 1: eassumption. simpl. - destruct argvs. 1: reflexivity. discriminate. + - simpl. eapply exec.skip. eapply H5. 1: eassumption. + + simpl. destruct argvs. 1: reflexivity. discriminate. + + trivial. - simpl. unfold set_reg_to_var, stack_loc. destruct argvs as [|v vs]. { unfold map.getmany_of_list in H4. cbn in H4. simp. @@ -956,6 +1021,10 @@ Section Spilling. ++ rewrite Z.add_comm. eapply map.getmany_of_list_put_diff. 2: eassumption. eauto using List.not_In_Z_seq with zarith. + -- cbn. destr (isRegZ start); destr (isRegZ a); cbn in *; try blia. + rewrite H6; unfold cost_load, isRegZ; cbn. + rewrite (proj2 (Z.leb_le start 31)) by assumption. + reflexivity. + eapply exec.seq_cps. eapply IHargs; try eassumption; try blia. intros. @@ -975,15 +1044,71 @@ Section Spilling. - rewrite Z.add_comm. eapply map.getmany_of_list_put_diff. 2: eassumption. eauto using List.not_In_Z_seq with zarith. } - unfold related. - repeat match goal with - | |- exists _, _ => eexists - | |- _ /\ _ => split - | |- _ => eassumption || reflexivity - end. - eapply put_arg_reg; try eassumption. blia. + -- unfold related. + repeat match goal with + | |- exists _, _ => eexists + | |- _ /\ _ => split + | |- _ => eassumption || reflexivity + end. + eapply put_arg_reg; try eassumption. blia. + -- cbn. destr (isRegZ start); destr (isRegZ a); cbn in *; try blia. + rewrite H6; unfold cost_set, isRegZ; cbn. + rewrite (proj2 (Z.leb_le a 31)) by assumption. + rewrite (proj2 (Z.leb_le start 31)) by assumption. + reflexivity. + Qed. + + Lemma cost_set_reg_range_to_vars_bound : forall args start mc len, + Z.of_nat (Datatypes.length args) <= len -> + (cost_set_reg_range_to_vars start args mc <= addMetricInstructions len (addMetricLoads (2 * len) mc))%metricsH. + Proof. + induction args. + - unfold cost_set_reg_range_to_vars, Z.of_nat. intros. simpl in H. solve_MetricLog. + - intros. + cbn [cost_set_reg_range_to_vars]. + specialize (IHargs (start+1) mc (Z.of_nat (Datatypes.length args)) (Z.le_refl _)). + subst. + simpl in H. + destruct (isRegZ a); solve_MetricLog. Qed. + Lemma cost_set_vars_to_reg_range_bound : forall args start mc len, + Z.of_nat (Datatypes.length args) <= len -> + (cost_set_vars_to_reg_range args start mc <= addMetricInstructions len (addMetricLoads len (addMetricStores len mc)))%metricsH. + Proof. + induction args. + - unfold cost_set_vars_to_reg_range, Z.of_nat. intros. simpl in H. solve_MetricLog. + - intros. + cbn [cost_set_vars_to_reg_range]. + specialize (IHargs (start+1) mc (Z.of_nat (Datatypes.length args)) (Z.le_refl _)). + subst. + simpl in H. + destruct (isRegZ a); solve_MetricLog. + Qed. + + (* pulled and modified from Coq.Program.Tactics *) + Ltac add_hypothesis p := + match type of p with + ?X => match goal with + | [ H : X |- _ ] => fail 1 + | _ => pose proof p + end + end. + + Ltac add_bounds := + repeat match goal with + | _: context[cost_set_reg_range_to_vars ?x ?y ?z] |- _ => + add_hypothesis (cost_set_reg_range_to_vars_bound y x z 8 ltac:(blia)) + | _: context[cost_set_vars_to_reg_range ?x ?y ?z] |- _ => + add_hypothesis (cost_set_vars_to_reg_range_bound x y z 8 ltac:(blia)) + | |- context[cost_set_reg_range_to_vars ?x ?y ?z] => + add_hypothesis (cost_set_reg_range_to_vars_bound y x z 8 ltac:(blia)) + | |- context[cost_set_vars_to_reg_range ?x ?y ?z] => + add_hypothesis (cost_set_vars_to_reg_range_bound x y z 8 ltac:(blia)) + end. + + (* end silly seeming section *) + Lemma grow_related_mem: forall maxvar frame t1 mSmall1 l1 t2 mSmall2 l2 mStack mCombined2 fpval, related maxvar frame fpval t1 mSmall1 l1 t2 mSmall2 l2 -> map.split mCombined2 mSmall2 mStack -> @@ -1150,24 +1275,37 @@ Section Spilling. Definition spilling_correct_for(e1 e2 : env)(s1 : stmt): Prop := forall (t1 : Semantics.trace) (m1 : mem) (l1 : locals) (mc1 : MetricLog) (post : Semantics.trace -> mem -> locals -> MetricLog -> Prop), - exec e1 s1 t1 m1 l1 mc1 post -> + execpre e1 s1 t1 m1 l1 mc1 post -> forall (frame : mem -> Prop) (maxvar : Z), valid_vars_src maxvar s1 -> forall (t2 : Semantics.trace) (m2 : mem) (l2 : locals) (mc2 : MetricLog) (fpval : word), related maxvar frame fpval t1 m1 l1 t2 m2 l2 -> - exec e2 (spill_stmt s1) t2 m2 l2 mc2 + execpost e2 (spill_stmt s1) t2 m2 l2 mc2 (fun (t2' : Semantics.trace) (m2' : mem) (l2' : locals) (mc2' : MetricLog) => exists t1' m1' l1' mc1', related maxvar frame fpval t1' m1' l1' t2' m2' l2' /\ - post t1' m1' l1' mc1'). + post t1' m1' l1' mc1' /\ + (mc2' - mc2 <= mc1' - mc1)%metricsH). + + (* TODO tighter / non-fixed bound *) + Definition cost_spill_spec mc := + (mkMetricLog 100 100 100 100 + mc)%metricsH. Definition call_spec(e: env) '(argnames, retnames, fbody) - (t: Semantics.trace)(m: mem)(argvals: list word) - (post: Semantics.trace -> mem -> list word -> Prop): Prop := - forall l mc, map.of_list_zip argnames argvals = Some l -> - exec e fbody t m l mc (fun t' m' l' mc' => + (t: Semantics.trace)(m: mem)(argvals: list word)(mc: MetricLog) + (post: Semantics.trace -> mem -> list word -> MetricLog -> Prop): Prop := + forall l, map.of_list_zip argnames argvals = Some l -> + execpre e fbody t m l (cost_spill_spec mc) (fun t' m' l' mc' => + exists retvals, map.getmany_of_list l' retnames = Some retvals /\ + post t' m' retvals mc'). + + Definition call_spec_spilled(e: env) '(argnames, retnames, fbody) + (t: Semantics.trace)(m: mem)(argvals: list word)(mc: MetricLog) + (post: Semantics.trace -> mem -> list word -> MetricLog -> Prop): Prop := + forall l, map.of_list_zip argnames argvals = Some l -> + execpost e fbody t m l mc (fun t' m' l' mc' => exists retvals, map.getmany_of_list l' retnames = Some retvals /\ - post t' m' retvals). + post t' m' retvals mc'). (* In exec.call, there are many maps of locals involved: @@ -1195,14 +1333,16 @@ Section Spilling. what happens in the callee. TODO: actually use that lemma in case exec.call. Moreover, this lemma will also be used in the pipeline, where phases are composed based on the semantics of function calls. *) + Lemma spill_fun_correct_aux: forall e1 e2 argnames1 retnames1 body1 argnames2 retnames2 body2, spill_fun (argnames1, retnames1, body1) = Success (argnames2, retnames2, body2) -> spilling_correct_for e1 e2 body1 -> - forall argvals t m (post: Semantics.trace -> mem -> list word -> Prop), - call_spec e1 (argnames1, retnames1, body1) t m argvals post -> - call_spec e2 (argnames2, retnames2, body2) t m argvals post. + forall argvals t m mcH mcL (post: Semantics.trace -> mem -> list word -> MetricLog -> Prop), + call_spec e1 (argnames1, retnames1, body1) t m argvals mcH post -> + call_spec_spilled e2 (argnames2, retnames2, body2) t m argvals mcL + (fun t' m' l' mcL' => exists mcH', metricsLeq (mcL' - mcL) (mcH' - mcH) /\ post t' m' l' mcH'). Proof. - unfold call_spec, spilling_correct_for. intros * Sp IHexec * Ex lFL3 mc OL2. + unfold call_spec, spilling_correct_for. intros * Sp IHexec * Ex lFL3 OL2. unfold spill_fun in Sp. fwd. apply_in_hyps @map.getmany_of_list_length. apply_in_hyps @map.putmany_of_list_zip_sameLength. @@ -1271,7 +1411,7 @@ Section Spilling. 2: eapply Forall_le_max. cbv beta. subst maxvar'. clear. blia. } - intros mL4 lFL4 mcL4 R. + intros mL4 lFL4 mcL4 R Hcost. eapply exec.seq_cps. eapply exec.weaken. { eapply IHexec. 1: apply Ex. 2: exact R. @@ -1307,7 +1447,7 @@ Section Spilling. subst maxvar'. clear. blia. } { eassumption. } rename R into R0. - intros lFL6 mcL6 R GM. + intros lFL6 mcL6 R GM HCost. (* prove that if we remove the additional stack provided by exec.stackalloc and store the result vars back into the arg registers, the postcondition holds *) unfold related in R. fwd. rename lStack into lStack5, lRegs into lRegs5. @@ -1321,7 +1461,7 @@ Section Spilling. | |- exists _, _ => eexists | |- _ /\ _ => split end. - 4: eassumption. + 5: eassumption. 2: { unfold map.split. eauto. } @@ -1337,10 +1477,39 @@ Section Spilling. } blia. } { eassumption. } - Unshelve. - all: try assumption. + { + add_bounds. + unfold cost_stackalloc, cost_spill_spec in *. (* TODO XXX *) + destruct (isRegZ fp); solve_MetricLog. + } Qed. + + Lemma iarg_reg_isReg: forall i a, + (i <= 20) -> + (isRegZ (iarg_reg i a) = true). + Proof. + intros. unfold isRegZ, iarg_reg. destr (32 <=? a); unfold spill_tmp; blia. + Qed. + + Lemma ires_reg_isReg: forall r, + (isRegZ (ires_reg r) = true). + Proof. + intros. unfold isRegZ, ires_reg. destr (32 <=? r); unfold spill_tmp; blia. + Qed. + + Ltac isReg_helper := + match goal with + | |- context[(isRegZ (iarg_reg _ _))] => rewrite iarg_reg_isReg by blia + | H: context[(isRegZ (iarg_reg _ _))] |- _ => rewrite iarg_reg_isReg in H by blia + | |- context[(isRegZ (ires_reg _))] => rewrite ires_reg_isReg + | H: context[(isRegZ (ires_reg _))] |- _ => rewrite ires_reg_isReg in H + end. + + Ltac irs := cost_unfold; repeat isReg_helper; cost_solve. + Ltac sirs := scost_unfold; repeat isReg_helper; scost_solve. + + Lemma spilling_correct (e1 e2 : env) (Ev : spill_functions e1 = Success e2) (s1 : stmt) (t1 : Semantics.trace) @@ -1348,22 +1517,23 @@ Section Spilling. (l1 : locals) (mc1 : MetricLog) (post : Semantics.trace -> mem -> locals -> MetricLog -> Prop): - exec e1 s1 t1 m1 l1 mc1 post -> + execpre e1 s1 t1 m1 l1 mc1 post -> forall (frame : mem -> Prop) (maxvar : Z), valid_vars_src maxvar s1 -> forall (t2 : Semantics.trace) (m2 : mem) (l2 : locals) (mc2 : MetricLog) (fpval : word), related maxvar frame fpval t1 m1 l1 t2 m2 l2 -> - exec e2 (spill_stmt s1) t2 m2 l2 mc2 + execpost e2 (spill_stmt s1) t2 m2 l2 mc2 (fun (t2' : Semantics.trace) (m2' : mem) (l2' : locals) (mc2' : MetricLog) => exists t1' m1' l1' mc1', related maxvar frame fpval t1' m1' l1' t2' m2' l2' /\ - post t1' m1' l1' mc1'). + post t1' m1' l1' mc1' /\ + (mc2' - mc2 <= mc1' - mc1)%metricsH). Proof. induction 1; intros; cbn [spill_stmt valid_vars_src Forall_vars_stmt] in *; fwd. - (* exec.interact *) eapply exec.seq_cps. eapply set_reg_range_to_vars_correct; try eassumption; try (unfold a0, a7; blia). - intros *. intros R GM. clear l2 mc2 H4. + intros *. intros R GM CSet. clear l2 H4. unfold related in R. fwd. spec (subst_split (ok := mem_ok) m) as A. 1: eassumption. 1: ecancel_assumption. @@ -1401,12 +1571,18 @@ Section Spilling. { reflexivity. } { unfold a0, a7. blia. } { eassumption. } - { intros. do 4 eexists. split. 1: eassumption. eapply H2p1. - unfold map.split. split; [reflexivity|]. - move C at bottom. - unfold sep at 1 in C. destruct C as (mKeepL' & mRest & SC & ? & _). subst mKeepL'. - move H2 at bottom. unfold map.split in H2. fwd. - eapply map.shrink_disjoint_l; eassumption. } + { intros. do 4 eexists. split. 1: eassumption. split. + { eapply H2p1. + unfold map.split. split; [reflexivity|]. + move C at bottom. + unfold sep at 1 in C. destruct C as (mKeepL' & mRest & SC & ? & _). subst mKeepL'. + move H2 at bottom. unfold map.split in H2. fwd. + eapply map.shrink_disjoint_l; eassumption. } + cbn in *. subst. + add_bounds. + cost_solve. + (* cost_SInteract constraint: prespill - postspill >= (...32...) i think? *) + } (* related for set_vars_to_reg_range_correct: *) unfold related. eexists _, _, _. ssplit. @@ -1452,7 +1628,7 @@ Section Spilling. apply_in_hyps @map.getmany_of_list_length. apply_in_hyps @map.putmany_of_list_zip_sameLength. eapply set_reg_range_to_vars_correct; try eassumption || (unfold a0, a7 in *; blia). - intros lCL2 ? ? ?. + intros lCL2 ? ? ? ?. assert (bytes_per_word = 4 \/ bytes_per_word = 8) as B48. { unfold bytes_per_word. destruct width_cases as [E' | E']; rewrite E'; cbv; auto. } @@ -1521,7 +1697,7 @@ Section Spilling. 2: eapply Forall_le_max. cbv beta. subst maxvar'. clear. blia. } - intros mL4 lFL4 mcL4 R. + intros mL4 lFL4 mcL4 R CSet. eapply exec.seq_cps. eapply exec.weaken. { eapply IHexec. 2: exact R. @@ -1537,7 +1713,7 @@ Section Spilling. } cbv beta. subst maxvar'. blia. } - cbv beta. intros tL5 mL5 lFL5 mcL5 (tH5 & mH5 & lFH5 & mcH5 & R5 & OC). + cbv beta. intros tL5 mL5 lFL5 mcL5 (tH5 & mH5 & lFH5 & mcH5 & R5 & OC & Hmetrics). match goal with | H: context[outcome], A: context[outcome] |- _ => specialize H with (1 := A); move H at bottom; rename H into Q @@ -1561,7 +1737,7 @@ Section Spilling. subst maxvar'. clear. blia. } { eassumption. } rename R into R0. - intros lFL6 mcL6 R GM. + intros lFL6 mcL6 R GM ?. (* prove that if we remove the additional stack provided by exec.stackalloc and store the result vars back into the caller's registers, states are still related and postcondition holds *) @@ -1621,12 +1797,18 @@ Section Spilling. { unfold a0, a7. blia. } { eassumption. } { intros m22 l22 mc22 R22. do 4 eexists. split. 1: eassumption. - eassumption. } + split; try eassumption. + subst. + move Hmetrics at bottom. + add_bounds. + cost_solve. + (* cost_SCall constraint: prespill - postspill >= (...66...) i think? *) + } - (* exec.load *) eapply exec.seq_cps. eapply load_iarg_reg_correct; (blia || eassumption || idtac). - clear mc2 H3. intros. + intros. eapply exec.seq_cps. pose proof H2 as A. unfold related in A. fwd. unfold Memory.load, Memory.load_Z, Memory.load_bytes in *. fwd. @@ -1640,12 +1822,11 @@ Section Spilling. + eassumption. + eassumption. + blia. + + irs. - (* exec.store *) - eapply exec.seq_cps. eapply load_iarg_reg_correct; (blia || eassumption || idtac). - clear mc2 H4. intros. - eapply exec.seq_cps. eapply load_iarg_reg_correct; (blia || eassumption || idtac). - clear mc2 H3. intros. - pose proof H3 as A. unfold related in A. fwd. + eapply exec.seq_cps. eapply load_iarg_reg_correct; (blia || eassumption || idtac). intros. + eapply exec.seq_cps. eapply load_iarg_reg_correct; (blia || eassumption || idtac). intros. + pose proof H5 as A. unfold related in A. fwd. unfold Memory.store, Memory.store_Z, Memory.store_bytes in *. fwd. edestruct (@sep_def _ _ _ m2 (eq m)) as (m' & m2Rest & Sp & ? & ?). 1: ecancel_assumption. unfold map.split in Sp. subst. fwd. @@ -1655,18 +1836,18 @@ Section Spilling. { unfold Memory.store, Memory.store_Z, Memory.store_bytes. unfold Memory.load_bytes in *. erewrite map.getmany_of_tuple_in_disjoint_putmany; eauto. } - do 4 eexists. split. 2: eassumption. - unfold related. - repeat match goal with - | |- exists _, _ => eexists - | |- _ /\ _ => split - end. - all: try eassumption || reflexivity. - spec store_bytes_sep_hi2lo as A. 1: eassumption. - all: ecancel_assumption. + do 4 eexists. split. 2: split. 2: eassumption. + + unfold related. + repeat match goal with + | |- exists _, _ => eexists + | |- _ /\ _ => split + end. + all: try eassumption || reflexivity. + spec store_bytes_sep_hi2lo as A. 1: eassumption. + all: ecancel_assumption. + + irs. - (* exec.inlinetable *) - eapply exec.seq_cps. eapply load_iarg_reg_correct; (blia || eassumption || idtac). - clear mc2 H4. intros. + eapply exec.seq_cps. eapply load_iarg_reg_correct; (blia || eassumption || idtac). intros. eapply exec.seq_cps. eapply exec.inlinetable. { unfold ires_reg, iarg_reg, spill_tmp, fp, a0, a7 in *. destr (32 <=? x); destr (32 <=? i); try blia. } @@ -1676,6 +1857,7 @@ Section Spilling. + eassumption. + eassumption. + blia. + + irs. - (* exec.stackalloc *) rename H1 into IH. eapply exec.stackalloc. 1: assumption. @@ -1693,73 +1875,83 @@ Section Spilling. | |- _ /\ _ => split end. 1,4,3,2: eassumption. + irs. - (* exec.lit *) eapply exec.seq_cps. eapply exec.lit. eapply save_ires_reg_correct. + eassumption. + eassumption. + blia. + + irs. - (* exec.op *) unfold exec.lookup_op_locals in *. eapply exec.seq_cps. eapply load_iarg_reg_correct; (blia || eassumption || idtac). - clear mc2 H3. intros. destruct_one_match; fwd. + clear H3. intros. destruct_one_match; fwd. { eapply exec.seq_cps. eapply exec.seq_cps. eapply load_iarg_reg_correct; (blia || eassumption || idtac). - clear mc2 H2. intros. + clear H2. intros. eapply exec.op. { eapply get_iarg_reg_1; eauto with zarith. } { unfold exec.lookup_op_locals in *. apply map.get_put_same. } - { eapply save_ires_reg_correct; (eassumption || blia). } + { eapply save_ires_reg_correct; (try eassumption || blia). + sirs. } } { eapply exec.seq_cps. eapply exec.op. { apply map.get_put_same. } { unfold exec.lookup_op_locals in *. reflexivity. } - { eapply save_ires_reg_correct; (eassumption || blia). } + { eapply save_ires_reg_correct; (try eassumption || blia). + sirs. } } - (* exec.set *) - eapply exec.seq_cps. eapply load_iarg_reg_correct; (blia || eassumption || idtac). - clear mc2 H2. intros. + eapply exec.seq_cps. eapply load_iarg_reg_correct; (blia || eassumption || idtac). intros. eapply exec.seq_cps. eapply exec.set. 1: apply map.get_put_same. eapply save_ires_reg_correct. + eassumption. + eassumption. + blia. + + irs. - (* exec.if_true *) unfold prepare_bcond. destr cond; cbn [ForallVars_bcond eval_bcond spill_bcond] in *; fwd. + eapply exec.seq_assoc. - eapply exec.seq_cps. eapply load_iarg_reg_correct; (blia || eassumption || idtac). - clear mc2 H2. intros. - eapply exec.seq_cps. eapply load_iarg_reg_correct; (blia || eassumption || idtac). - clear mc2. intros. + eapply exec.seq_cps. eapply load_iarg_reg_correct; (blia || eassumption || idtac). intros. + eapply exec.seq_cps. eapply load_iarg_reg_correct; (blia || eassumption || idtac). intros. eapply exec.if_true. { cbn. erewrite get_iarg_reg_1 by eauto with zarith. rewrite map.get_put_same. congruence. } - eapply IHexec; eassumption. - + eapply exec.seq_cps. eapply load_iarg_reg_correct; (blia || eassumption || idtac). - clear mc2 H2. intros. + eapply exec.weaken. + * eapply IHexec; eassumption. + * cbv beta; intros; fwd. exists t1', m1', l1', mc1'. split. 2: split. all: try eassumption. + sirs. + + eapply exec.seq_cps. eapply load_iarg_reg_correct; (blia || eassumption || idtac). intros. eapply exec.if_true. { cbn. rewrite map.get_put_same. rewrite word.eqb_ne by assumption. reflexivity. } - eapply IHexec; eassumption. + eapply exec.weaken. + * eapply IHexec; eassumption. + * cbv beta; intros; fwd. exists t1', m1', l1', mc1'. split. 2: split. all: try eassumption. + sirs. - (* exec.if_false *) unfold prepare_bcond. destr cond; cbn [ForallVars_bcond eval_bcond spill_bcond] in *; fwd. + eapply exec.seq_assoc. - eapply exec.seq_cps. eapply load_iarg_reg_correct; (blia || eassumption || idtac). - clear mc2 H2. intros. - eapply exec.seq_cps. eapply load_iarg_reg_correct; (blia || eassumption || idtac). - clear mc2. intros. + eapply exec.seq_cps. eapply load_iarg_reg_correct; (blia || eassumption || idtac). intros. + eapply exec.seq_cps. eapply load_iarg_reg_correct; (blia || eassumption || idtac). intros. eapply exec.if_false. { cbn. erewrite get_iarg_reg_1 by eauto with zarith. rewrite map.get_put_same. congruence. } - eapply IHexec; eassumption. - + eapply exec.seq_cps. eapply load_iarg_reg_correct; (blia || eassumption || idtac). - clear mc2 H2. intros. + eapply exec.weaken. + * eapply IHexec; eassumption. + * cbv beta; intros; fwd. exists t1', m1', l1', mc1'. split. 2: split. all: try eassumption. + sirs. + + eapply exec.seq_cps. eapply load_iarg_reg_correct; (blia || eassumption || idtac). intros. eapply exec.if_false. { cbn. rewrite map.get_put_same. rewrite word.eqb_eq; reflexivity. } - eapply IHexec; eassumption. + eapply exec.weaken. + * eapply IHexec; eassumption. + * cbv beta; intros; fwd. exists t1', m1', l1', mc1'. split. 2: split. all: try eassumption. + sirs. - (* exec.loop *) rename IHexec into IH1, H3 into IH2, H5 into IH12. eapply exec.loop_cps. @@ -1779,15 +1971,19 @@ Section Spilling. erewrite get_iarg_reg_1 by eauto with zarith. rewrite map.get_put_same. eexists. split; [reflexivity|]. split; intros. - * do 4 eexists. split. - -- exact H3p6. + * do 4 eexists. split. 2: split. + -- exact H3p8. -- eapply H1. 1: eassumption. cbn. rewrite E, E0. congruence. + -- sirs. * eapply exec.weaken. 1: eapply IH2. -- eassumption. -- cbn. rewrite E, E0. congruence. -- eassumption. -- eassumption. - -- cbv beta. intros. fwd. eauto 10. (* IH12 *) + -- cbv beta. intros. fwd. eapply exec.weaken. + ++ eapply IH12; try eassumption. repeat split; eauto; blia. + ++ cbv beta; intros; fwd. exists t1'1, m1'1, l1'1, mc1'1. split. 2:split. all: try eassumption. + sirs. + specialize H0 with (1 := H3p1). cbn in H0. fwd. eapply exec.weaken. { eapply load_iarg_reg_correct''; (blia || eassumption || idtac). @@ -1795,31 +1991,39 @@ Section Spilling. cbv beta. intros. fwd. cbn [eval_bcond spill_bcond]. rewrite map.get_put_same. eexists. split; [reflexivity|]. split; intros. - * do 4 eexists. split. - -- exact H3p5. + * do 4 eexists. split. 2: split. + -- exact H3p6. -- eapply H1. 1: eassumption. cbn. rewrite E. congruence. + -- sirs. * eapply exec.weaken. 1: eapply IH2. -- eassumption. -- cbn. rewrite E. congruence. -- eassumption. -- eassumption. - -- cbv beta. intros. fwd. eauto 10. (* IH12 *) + -- cbv beta. intros. fwd. eapply exec.weaken. + ++ eapply IH12; try eassumption. repeat split; eauto; blia. + ++ cbv beta; intros; fwd. exists t1'1, m1'1, l1'1, mc1'1. split. 2:split. all: try eassumption. + sirs. - (* exec.seq *) cbn in *. fwd. rename H1 into IH2, IHexec into IH1. eapply exec.seq. + eapply IH1. 1: eassumption. eauto 15. - + cbn. intros. fwd. eapply IH2. 1,2: eassumption. eauto 15. + + cbn. intros. fwd. eapply exec.weaken. + * eapply IH2; eassumption. + * cbv beta. intros. fwd. exists t1'0, m1'0, l1'0, mc1'0. split. 2:split. all: try eassumption. + solve_MetricLog. - (* exec.skip *) - eapply exec.skip. eauto 20. + eapply exec.skip. exists t, m, l, mc. repeat split; eauto; solve_MetricLog. Qed. Lemma spill_fun_correct: forall e1 e2 argnames1 retnames1 body1 argnames2 retnames2 body2, spill_functions e1 = Success e2 -> spill_fun (argnames1, retnames1, body1) = Success (argnames2, retnames2, body2) -> - forall argvals t m (post: Semantics.trace -> mem -> list word -> Prop), - call_spec e1 (argnames1, retnames1, body1) t m argvals post -> - call_spec e2 (argnames2, retnames2, body2) t m argvals post. + forall argvals t m mcH mcL (post: Semantics.trace -> mem -> list word -> MetricLog -> Prop), + call_spec e1 (argnames1, retnames1, body1) t m argvals mcH post -> + call_spec_spilled e2 (argnames2, retnames2, body2) t m argvals mcL + (fun t' m' l' mcL' => exists mcH', metricsLeq (mcL' - mcL) (mcH' - mcH) /\ post t' m' l' mcH'). Proof. intros. eapply spill_fun_correct_aux; try eassumption. unfold spilling_correct_for. diff --git a/compiler/src/compiler/ToplevelLoop.v b/compiler/src/compiler/ToplevelLoop.v index 7b82edff5..99a63f567 100644 --- a/compiler/src/compiler/ToplevelLoop.v +++ b/compiler/src/compiler/ToplevelLoop.v @@ -297,16 +297,15 @@ Section Pipeline1. compile_ext_call_length_ignores_positions as P. unfold runsTo in P. specialize P with (argvals := []) - (post := fun t' m' retvals => isReady spec t' m' /\ goodTrace spec t') + (post := fun t' m' retvals mc' => isReady spec t' m' /\ goodTrace spec t') (fname := "init"%string). edestruct P as (init_rel_pos & G & P'); clear P; cycle -1. 1: eapply P' with (p_funcs := word.add loop_pos (word.of_Z 8)) (Rdata := R). all: simpl_MetricRiscvMachine_get_set. - 11: { + 12: { unfold hl_inv in init_code_correct. move init_code_correct at bottom. do 4 eexists. split. 1: eassumption. split. 1: reflexivity. - intros mc. eapply ExprImp.weaken_exec. - refine (init_code_correct _ _ _). replace (datamem_start spec) with (heap_start ml) by congruence. @@ -314,7 +313,7 @@ Section Pipeline1. exact HMem. - cbv beta. intros * _ HP. exists []. split. 1: reflexivity. exact HP. } - 10: { unfold compile. rewrite_match. reflexivity. } + 11: { unfold compile. rewrite_match. reflexivity. } all: try eassumption. { apply stack_length_divisible. } { cbn. clear CP. @@ -324,6 +323,7 @@ Section Pipeline1. { destruct mlOk. solve_divisibleBy4. } { reflexivity. } { reflexivity. } + { reflexivity. } unfold machine_ok. clear P'. rewrite GetPos in G. fwd. @@ -409,6 +409,7 @@ Section Pipeline1. * eapply iff1ToEq. unfold init_sp_insts, init_insts, loop_insts, backjump_insts. wwcancel. + Unshelve. exact MetricLogging.EmptyMetricLog. Qed. Lemma ll_inv_is_invariant: forall (st: MetricRiscvMachine), @@ -505,14 +506,13 @@ Section Pipeline1. unfold runsTo in P. specialize P with (argvals := []) (fname := "loop"%string) - (post := fun t' m' retvals => isReady spec t' m' /\ goodTrace spec t'). + (post := fun t' m' retvals mc' => isReady spec t' m' /\ goodTrace spec t'). edestruct P as (loop_rel_pos & G & P'); clear P; cycle -1. 1: eapply P' with (p_funcs := word.add loop_pos (word.of_Z 8)) (Rdata := R) (ret_addr := word.add loop_pos (word.of_Z 4)). - 11: { + 12: { move loop_body_correct at bottom. do 4 eexists. split. 1: eassumption. split. 1: reflexivity. - intros mc. eapply ExprImp.weaken_exec. - eapply loop_body_correct; eauto. - cbv beta. intros * _ HP. exists []. split. 1: reflexivity. exact HP. @@ -526,6 +526,7 @@ Section Pipeline1. { subst loop_pos init_pos. destruct mlOk. solve_divisibleBy4. } { reflexivity. } { reflexivity. } + { reflexivity. } unfold loop_pos, init_pos. unfold machine_ok. unfold_RiscvMachine_get_set. @@ -588,6 +589,7 @@ Section Pipeline1. * wcancel_assumption. * eapply rearrange_footpr_subset. 1: eassumption. wwcancel. + Unshelve. exact MetricLogging.EmptyMetricLog. Qed. Lemma ll_inv_implies_prefix_of_good: forall st, diff --git a/compiler/src/compiler/UseImmediate.v b/compiler/src/compiler/UseImmediate.v index 71f4a7e81..7d9114c2e 100644 --- a/compiler/src/compiler/UseImmediate.v +++ b/compiler/src/compiler/UseImmediate.v @@ -6,6 +6,7 @@ Require Import coqutil.Tactics.fwd. Require Import String. Require Import compiler.UseImmediateDef. Require Import bedrock2.MetricLogging. +Require Import bedrock2.MetricCosts. Local Notation var := String.string (only parsing). @@ -25,39 +26,121 @@ Section WithArguments. morphism (word.ring_morph (word := word)), constants [word_cst]). - Local Hint Constructors exec: core. + Local Notation exec := (exec PreSpill isRegStr). + + Open Scope MetricH_scope. + + Ltac tandem H := ( + repeat match goal with + | |- exists _, _ => let x := fresh in destruct H as (x&?); exists x + | |- _ /\ _ => destruct H as (?&H); split; try eassumption + | _ => let x := fresh in intro x; specialize (H x) + end + ). + + Ltac finish := ( + simpl; + intros *; + repeat match goal with + | |- (exists _, _ /\ _) -> _ => intros (?&?&?) + | |- _ /\ _ => split; eauto + | |- exists _, _ => eexists + end; + repeat match goal with + | H : _ <= _ |- _ => revert H + end; + clear; + FlatImp.scost_hammer + ). + + (* TODO these two lemmas are somewhat slow *) + Lemma op_cost_y : forall x0 y v0 mcH' mc v mcL lit, + exec.cost_SOp isRegStr x0 y (Var v0) EmptyMetricLog - EmptyMetricLog <= + mcH' - cost_lit isRegStr lit mc -> + exec.cost_SOp isRegStr x0 y (Const v) (cost_lit isRegStr lit mcL) - mcL <= + mcH' - mc. + Proof. finish. Qed. + + Lemma op_cost_v0 : forall x0 y v0 mcH' mc v mcL lit, + exec.cost_SOp isRegStr x0 y (Var v0) EmptyMetricLog - EmptyMetricLog <= + mcH' - cost_lit isRegStr lit mc -> + exec.cost_SOp isRegStr x0 v0 (Const v) (cost_lit isRegStr lit mcL) - mcL <= + mcH' - mc. + Proof. finish. Qed. Lemma useImmediate_correct_aux: forall eH eL, (useimmediate_functions is5BitImmediate is12BitImmediate) eH = Success eL -> forall sH t m mcH lH post, exec eH sH t m lH mcH post -> - exec eL (useImmediate is5BitImmediate is12BitImmediate sH) t m lH mcH post. + forall mcL, + exec eL (useImmediate is5BitImmediate is12BitImmediate sH) t m lH mcL + (fun t' m' l' mcL' => exists mcH', metricsLeq (mcL' - mcL) (mcH' - mcH) /\ post t' m' l' mcH'). Proof. - induction 2. - (* most cases stay the same *) - all: try solve [simpl; eauto]. - - (* SCall *) - { simpl. - eapply @exec.call; try eassumption. - assert (exists v2, (useimmediate_function is5BitImmediate is12BitImmediate) (params, rets, fbody) = Success v2 /\ map.get eL fname = Some v2). - { eapply map.try_map_values_fw. - - simpl in H; eapply H. - - eassumption. - } - destruct H5. destruct H5. simpl in H5. inversion H5. fwd. - eassumption. - } - - (* SSeq *) - { simpl. - repeat (match goal with - | |- context[match ?x with _ => _ end] => destr x - | |- context[if ?x then _ else _ ] => destr x - end; - try solve [eapply @exec.seq; eassumption]); - simpl in *. + induction 2; try solve [ + simpl; econstructor; eauto; + tandem H3; + finish + ]. + + - (* SCall *) + simpl; econstructor; eauto. + { unfold useimmediate_functions in H. + destruct (map.try_map_values_fw _ _ _ H _ _ H0) as (?&[=Huseimm]&Hmap). + rewrite <- Huseimm in Hmap. + exact Hmap. } + cbv beta. + intros * (?&?&Houtcome). + destruct (H4 _ _ _ _ Houtcome) as (retvs&l'&Hpost). + exists retvs, l'. + tandem Hpost. + finish. + + - (* SStackalloc *) + simpl; econstructor; eauto. + tandem H2. + eapply exec.weaken; [eauto|]. + simpl; intros * (?&?&?&?&?&?&?). + finish. + + - (* SIf true *) + simpl; econstructor; eauto. + eapply exec.weaken; [eauto|]. + finish. + + - (* SIf false *) + simpl; intro; eapply exec.if_false; eauto. + eapply exec.weaken; [eauto|]. + finish. + + - (* SLoop *) + simpl; econstructor; eauto; simpl. + { intros * (?&?&?); eauto. } + { intros * (?&?&?) **. finish. } + { intros * (?&?&?) **. + eapply exec.weaken; [eauto|]. + simpl; intros * (?&?&?). + instantiate (1 := fun t m l MC1 => exists MC2, MC1 - mcL <= MC2 - mc /\ mid2 t m l MC2). + finish. } + { intros * (?&?&?). + eapply exec.weaken; [eauto|]. + finish. } + + - (* SSeq *) + simpl. intro. + + repeat ( + match goal with + | |- context[match ?x with _ => _ end] => destr x + | |- context[if ?x then _ else _ ] => destr x + end; + try solve [ + eapply @exec.seq; eauto; simpl; + intros * (?&?&?); + eapply @exec.weaken; [eauto|]; + finish + ] + ). all: eapply @exec.seq_cps; eapply @exec.lit. @@ -70,10 +153,10 @@ Section WithArguments. all: match goal with | H: ?mid _ _ _ _, - H0: forall t m l mc, - ?mid t m l mc -> exec ?eL _ _ _ _ _ ?post - |- exec ?eL _ _ _ _ _ ?post - => apply H0 in H; inversion H + H0: forall t m l mc, + ?mid t m l mc -> forall mcL, exec ?eL _ _ _ _ mcL _ + |- exec ?eL _ _ _ _ _ _ + => specialize (H0 _ _ _ _ H EmptyMetricLog); inversion H0 end. all: simpl in *; @@ -82,13 +165,17 @@ Section WithArguments. => rewrite map.get_put_same in H; fwd end. - all: eapply @exec.op; simpl in *; [ eassumption | reflexivity | try eassumption ]. + all: eapply @exec.op; simpl in *; [ eassumption | reflexivity | ]. + + all: exists mcH'; split; [solve [eapply op_cost_y; eauto | eapply op_cost_v0; eauto]|eauto]. - { rewrite word.add_comm. assumption. } - { replace (word.add y' (word.of_Z (- v))) with (word.sub y' (word.of_Z v)) by ring. assumption. } - { rewrite word.and_comm. assumption. } - { rewrite word.or_comm. assumption. } - { rewrite word.xor_comm. assumption. } - } + + rewrite word.add_comm. assumption. + + replace (word.add y' (word.of_Z (-v))) with (word.sub y' (word.of_Z v)) by ring. assumption. + + rewrite word.and_comm. assumption. + + rewrite word.or_comm. assumption. + + rewrite word.xor_comm. assumption. + + (* this is slightly slow also *) Qed. + End WithArguments. diff --git a/compiler/src/compiler/memory_mapped_ext_calls_compiler.v b/compiler/src/compiler/memory_mapped_ext_calls_compiler.v index 5a421ec6a..e9cb05b57 100644 --- a/compiler/src/compiler/memory_mapped_ext_calls_compiler.v +++ b/compiler/src/compiler/memory_mapped_ext_calls_compiler.v @@ -422,7 +422,7 @@ Section MMIO. { eapply MapEauto.only_differ_union_l. eapply MapEauto.only_differ_put. cbn. left. reflexivity. } - { MetricsToRiscv.solve_MetricLog. } + { unfold MetricCosts.cost_interact in *; MetricsToRiscv.solve_MetricLog. } { eapply map.put_extends. eassumption. } { eapply map.forall_keys_put; assumption. } { rewrite map.get_put_diff; eauto. unfold RegisterNames.sp. @@ -537,7 +537,7 @@ Section MMIO. eapply map.split_empty_r. reflexivity. } { reflexivity. } { eapply MapEauto.only_differ_refl. } - { MetricsToRiscv.solve_MetricLog. } + { unfold MetricCosts.cost_interact in *; MetricsToRiscv.solve_MetricLog. } { eassumption. } { assumption. } { assumption. }