Skip to content

Commit

Permalink
wip prettifying end-to-end metrics theorem
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgruetter committed Nov 14, 2024
1 parent 2cdebfe commit 7c80fde
Showing 1 changed file with 158 additions and 31 deletions.
189 changes: 158 additions & 31 deletions compiler/src/compilerExamples/MetricLightbulbInvariant.v
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,79 @@ Definition loop_progress t t' dmc :=
(lan9250_recv_packet_too_long _ ioh) \/
((TracePredicate.any +++ (spi_timeout _)) ioh).

(* We want to prove a theorem that's purely about RISC-V execution, so we want
to express everything in terms of RISC-V metrics (metricL notation scope),
but that doesn't have all the operators yet *)
Notation RiscvMetrics := Platform.MetricLogging.MetricLog.

Definition metricAdd(metric: RiscvMetrics -> Z) finalM initialM: Z :=
Z.sub (metric finalM) (metric initialM).
Definition metricsAdd := Platform.MetricLogging.metricsOp metricAdd.
Infix "+" := metricsAdd : MetricL_scope.

Definition metricsMul (n : Z) (m : RiscvMetrics) :=
riscv.Platform.MetricLogging.mkMetricLog
(n * riscv.Platform.MetricLogging.instructions m)
(n * riscv.Platform.MetricLogging.stores m)
(n * riscv.Platform.MetricLogging.loads m)
(n * riscv.Platform.MetricLogging.jumps m).
Infix "*" := metricsMul : MetricL_scope.

Definition mc_spi_write_const :=
riscv.Platform.MetricLogging.mkMetricLog 348 227 381 204.
Definition mc_spi_read_const :=
riscv.Platform.MetricLogging.mkMetricLog 199 116 217 103.
Definition mc_spi_xchg_const :=
(mc_spi_write_const + mc_spi_read_const +
riscv.Platform.MetricLogging.mkMetricLog 410 402 414 401)%metricsL.
Definition mc_spi_mul := riscv.Platform.MetricLogging.mkMetricLog 157 109 169 102.
Definition lightbulb_handle_cost :=
riscv.Platform.MetricLogging.mkMetricLog 552 274 639 203.

(*
let dmc := metricsSub (MetricsToRiscv.raiseMetrics dmc) (cost_call PreSpill (cost_lit (prefix "reg_") ""%string loop_overhead)) in
*)

(* TODO fix implicit-ness status at definition site *)
Arguments OP {word}.
Arguments lan9250_recv {word}.
Arguments lightbulb_packet_rep {word}.
Arguments gpio_set {word}.
Arguments lan9250_recv_packet_too_long {word}.
Arguments spi_timeout {word}.
Arguments lan9250_recv_no_packet {word}.

(*------desired code width for papers: max 88 columns---------------------------------*)

Section WithMetricsScope. Open Scope metricsL.

Definition handle_request_spec(t t': trace)(mc mc': RiscvMetrics) :=
exists dt, t' = dt ++ t /\
exists ioh, metric_SPI.mmio_trace_abstraction_relation ioh dt /\ (
(* Case 1: Received packet with valid command: *)
(exists packet cmd,
(lan9250_recv packet +++ gpio_set 23 cmd) ioh /\
lightbulb_packet_rep cmd packet /\
((mc-mc' <= (60+7*length packet)*mc_spi_xchg_const +
lightbulb_handle_cost + (length ioh)*mc_spi_mul))) \/
(* Case 2: Received invalid packet: *)
(exists packet,
(lan9250_recv packet) ioh /\
not (exists cmd, lightbulb_packet_rep cmd packet) /\
((mc-mc' <= (60+7*length packet)*mc_spi_xchg_const +
lightbulb_handle_cost + (length ioh)*mc_spi_mul))) \/
(* Case 3: Polled, but no new packet was available: *)
(lan9250_recv_no_packet ioh) \/
(* Case 4: Received too long packet *)
(lan9250_recv_packet_too_long ioh) \/
(* Case 5: SPI protocol timeout *)
((TracePredicate.any +++ spi_timeout) ioh)).

End WithMetricsScope.

Compute
let length_packet := 1520%nat in
metricsAdd (cost_call PreSpill (cost_lit (prefix "reg_") ""%string loop_overhead)) (((60+7*length_packet)*mc_spi_xchg_const + lightbulb_handle_cost + 0*mc_spi_mul)).
bedrock2.MetricLogging.metricsAdd (cost_call PreSpill (cost_lit (prefix "reg_") ""%string loop_overhead)) (((60+7*length_packet)*bedrock2Examples.metric_SPI.mc_spi_xchg_const + bedrock2Examples.metric_lightbulb.lightbulb_handle_cost + 0*bedrock2Examples.metric_SPI.mc_spi_mul)).

(*
= {|
Expand Down Expand Up @@ -138,25 +208,58 @@ Proof.
rewrite List.firstn_length. Lia.lia.
Qed.

Derive _tt SuchThat (
forall (initial : MetricRiscvMachine) R,
valid_machine initial ->
getLog initial = [] ->
regs_initialized.regs_initialized (getRegs initial) ->
getNextPc initial = word.add (getPc initial) (word.of_Z 4) ->
getPc initial = code_start ml ->
(program RV32IM (code_start ml) (fst (fst out)) * R *
LowerPipeline.mem_available (heap_start ml) (heap_pastend ml) *
LowerPipeline.mem_available (stack_start ml) (stack_pastend ml))%sep
(getMem initial) ->
subset (footpr (program RV32IM (code_start ml) (fst (fst out))))
(of_list (getXAddrs initial)) ->
_)
As metric_lightbulb_correct.
Module riscv.
(* just implicit arguments, as inferred by Coq in previous version
that used Ltac... :P *)
Definition run1 :=
(@ToplevelLoop.run1 (2 ^ Nat.log2 32) BW32 (Naive.word (2 ^ Nat.log2 32))
(@map (2 ^ Nat.log2 32) (Naive.word (2 ^ Nat.log2 32)) word32_ok Init.Byte.byte)
(Z_keyed_SortedListMap.Zkeyed_map (Naive.word (2 ^ Nat.log2 32)))
(FreeMonad.free.free
(@MetricMaterializeRiscvProgram.action 32 BW32 (Naive.word (2 ^ Nat.log2 32)))
(@MetricMaterializeRiscvProgram.result 32 BW32 (Naive.word (2 ^ Nat.log2 32))))
(@FreeMonad.free.Monad_free
(@MetricMaterializeRiscvProgram.action 32 BW32 (Naive.word (2 ^ Nat.log2 32)))
(@MetricMaterializeRiscvProgram.result 32 BW32 (Naive.word (2 ^ Nat.log2 32))))
(@MetricMaterializeRiscvProgram.MetricMaterialize 32 BW32
(Naive.word (2 ^ Nat.log2 32)))
(@MetricMinimalMMIO.MetricMinimalMMIOPrimitivesParams 32 BW32
(Naive.word (2 ^ Nat.log2 32))
(@map (2 ^ Nat.log2 32) (Naive.word (2 ^ Nat.log2 32)) word32_ok
Init.Byte.byte)
(Z_keyed_SortedListMap.Zkeyed_map (Naive.word (2 ^ Nat.log2 32)))
(@FE310ExtSpec.FE310_mmio 32 BW32 (Naive.word (2 ^ Nat.log2 32))
(@map (2 ^ Nat.log2 32) (Naive.word (2 ^ Nat.log2 32)) word32_ok
Init.Byte.byte))) RV32IM).
End riscv.

Import coqutil.Semantics.OmniSmallstepCombinators.

Axiom TODO: False.

Lemma metric_lightbulb_correct: forall (initial : MetricRiscvMachine) R,
valid_machine initial ->
getLog initial = [] ->
regs_initialized.regs_initialized (getRegs initial) ->
getNextPc initial = word.add (getPc initial) (word.of_Z 4) ->
getPc initial = code_start ml ->
(program RV32IM (code_start ml) (fst (fst out)) * R *
LowerPipeline.mem_available (heap_start ml) (heap_pastend ml) *
LowerPipeline.mem_available (stack_start ml) (stack_pastend ml))%sep
(getMem initial) ->
subset (footpr (program RV32IM (code_start ml) (fst (fst out))))
(of_list (getXAddrs initial)) ->
eventually riscv.run1
(successively riscv.run1
(fun s s' : MetricRiscvMachine =>
handle_request_spec (getLog s) (getLog s')
(getMetrics s) (getMetrics s'))) initial.
Proof.

intros.
unshelve refine (

eapply eventually_weaken.
1: refine (
let bedrock2_invariant t m := exists buf R,
(Separation.sep (Array.array Scalars.scalar8 (word.of_Z 1) (word.of_Z buffer_addr) buf) R) m /\
Z.of_nat (Datatypes.length buf) = 1520 in
Expand Down Expand Up @@ -186,9 +289,8 @@ _
_
_
ltac:(unfold initial_conditions; Tactics.ssplit)
)
;
shelve_unifiable.
).

1:vm_compute; reflexivity.
1:vm_compute; reflexivity.
all : cycle 1.
Expand All @@ -212,6 +314,12 @@ all : cycle -1.
1:eassumption.
1:apply Z.lt_le_incl; vm_compute; reflexivity.
1:apply Z.lt_le_incl; vm_compute; reflexivity.
{
unfold loop_progress, handle_request_spec.
intros.
(* needs weakening for successively and some more massaging *)
case TODO.
}
{
intros.
edestruct link_lightbulb_init as (?&?&?&D&?&E&X).
Expand All @@ -220,8 +328,8 @@ all : cycle -1.
eapply MetricSemantics.exec.weaken; [exact X|].
cbv beta; intros * ?; intros (?&?&?&?&?&?&?&?&?); subst.
cbv [bedrock2_invariant].
match goal with H: LowerPipeline.mem_available _ _ _ |- _ => revert H end. generalize m0. eassumption. }
{ cbv [bedrock2_invariant loop_progress].
eapply mem_available_to_seplog. assumption. }
{ cbv [bedrock2_invariant handle_request_spec].
intros; eapply MetricWeakestPreconditionProperties.sound_cmd.
repeat MetricProgramLogic.straightline.
{ eapply MetricSemantics.weaken_call; [eapply link_lightbulb_loop|cbv beta]; try eassumption.
Expand Down Expand Up @@ -250,8 +358,6 @@ all :
MetricLogging.unfold_MetricLog; MetricLogging.simpl_MetricLog;
unfold_MetricLog; simpl_MetricLog);
intuition try blia. } }

Unshelve. exact mem_available_to_seplog.
Qed.

Import OmniSmallstepCombinators.
Expand All @@ -260,7 +366,7 @@ Check metric_lightbulb_correct.
(*
metric_lightbulb_correct
: forall (initial : MetricRiscvMachine)
(R : map word32 Init.Byte.byte -> Prop),
(R : map (Naive.word (2 ^ Nat.log2 32)) Init.Byte.byte -> Prop),
valid_machine initial ->
getLog initial = [] ->
regs_initialized.regs_initialized (getRegs initial) ->
Expand All @@ -272,11 +378,11 @@ metric_lightbulb_correct
(getMem initial) ->
subset (footpr (program RV32IM (code_start ml) (fst (fst out))))
(of_list (getXAddrs initial)) ->
eventually ToplevelLoop.run1
(successively ToplevelLoop.run1
eventually riscv.run1
(successively riscv.run1
(fun s s' : MetricRiscvMachine =>
loop_progress (getLog s) (getLog s')
(getMetrics s' - getMetrics s))) initial
handle_request_spec (getLog s) (getLog s') (getMetrics s) (getMetrics s')))
initial
*)

Print Assumptions metric_lightbulb_correct.
Expand All @@ -287,9 +393,30 @@ PropExtensionality.propositional_extensionality :
functional_extensionality_dep :
forall (A : Type) (B : A -> Type) (f g : forall x : A, B x),
(forall x : A, f x = g x) -> f = g
TODO : False
used in metric_lightbulb_correct to prove
successively riscv.run1
(fun s s' : MetricRiscvMachine =>
exists dt : list Semantics.LogItem,
getLog s' = dt ++ getLog s /\
(exists ioh : list OP,
metric_SPI.mmio_trace_abstraction_relation ioh dt /\
((exists (packet : list Init.Byte.byte) (cmd : bool),
(lan9250_recv packet +++ gpio_set 23 cmd) ioh /\
lightbulb_packet_rep cmd packet /\
(getMetrics s - getMetrics s' <=
(60 + 7 * Datatypes.length packet) * mc_spi_xchg_const +
lightbulb_handle_cost + Datatypes.length ioh * mc_spi_mul)%metricsL) \/
(exists packet : list Init.Byte.byte,
lan9250_recv packet ioh /\
~ (exists cmd : bool, lightbulb_packet_rep cmd packet) /\
(getMetrics s - getMetrics s' <=
(60 + 7 * Datatypes.length packet) * mc_spi_xchg_const +
lightbulb_handle_cost + Datatypes.length ioh * mc_spi_mul)%metricsL) \/
lan9250_recv_no_packet ioh \/
lan9250_recv_packet_too_long ioh \/ (any +++ spi_timeout) ioh))) final
*)


(*------desired code width for papers: max 85 columns------------------------------*)
(* OLD STUFF BELOW, may be useful for instantiating some assumptions
Expand Down

0 comments on commit 7c80fde

Please sign in to comment.