diff --git a/silveroak-opentitan/hmac/hw/Sha256Properties.v b/silveroak-opentitan/hmac/hw/Sha256Properties.v index 85a94347c..ed1118d03 100644 --- a/silveroak-opentitan/hmac/hw/Sha256Properties.v +++ b/silveroak-opentitan/hmac/hw/Sha256Properties.v @@ -320,11 +320,12 @@ Definition sha256_padder_invariant /\ done = padder_done (* ...and if msg_complete is true, we must have processed the whole message *) /\ (if msg_complete then length msg <= index else index = length msg) - /\ (if out_valid - then - (* if output is valid, we must have processed at least one word *) - 4 <= index - else True) + (* ...and if output is valid, we must have processed at least one word *) + /\ (if out_valid then 4 <= index else True) + (* ...and if we're in the emit_bit state, it must be the case that the message + length is 0 mod 4 (otherwise we attach the 1 bit to the final byte and move + straight to flushing) *) + /\ (if (state =? padder_emit_bit_value)%N then length msg mod 4 = 0 else True) /\ (if done then (* if we're done, we must be in the padder_waiting state *) @@ -409,17 +410,12 @@ Lemma step_bvslice {n start len} (x : denote_type (BitVec n)) : = (tt, N.land (N.shiftr x (N.of_nat start)) (N.ones (N.of_nat len))). Proof. reflexivity. Qed. Hint Rewrite @step_bvslice using solve [eauto] : stepsimpl. - (* TODO: move *) Lemma length_N_to_bytes n bs : length (BigEndianBytes.N_to_bytes n bs) = n. Admitted. Hint Rewrite @length_N_to_bytes : push_length. (* TODO: move *) -Lemma length_bytes_to_Ns_upper_bound n bs : - length (BigEndianBytes.bytes_to_Ns n bs) * n < length bs + n. -Admitted. -(* TODO: move *) Lemma padded_message_bytes_length msg : length (SHA256.padded_msg_bytes msg) = padded_message_size msg. Admitted. @@ -428,9 +424,6 @@ Hint Rewrite @padded_message_bytes_length : push_length. Lemma padded_message_bytes_longer_than_input msg : length msg + 9 <= padded_message_size msg. Admitted. -Lemma padded_message_longer_than_input msg : - length (BigEndianBytes.bytes_to_Ns 4 msg) < length (SHA256.padded_msg msg). -Admitted. (* TODO: move *) Lemma min_padded_message_size msg : 64 <= padded_message_size msg. Admitted. @@ -438,9 +431,6 @@ Admitted. Lemma padded_message_size_modulo msg : padded_message_size msg mod 64 = 0. Admitted. (* TODO: move *) -Lemma padded_message_min_length msg : 16 <= length (SHA256.padded_msg msg). -Admitted. -(* TODO: move *) (* Adding data cannot decrease padded message size *) Lemma padded_message_size_mono msg data : padded_message_size msg <= padded_message_size (msg ++ data). @@ -559,7 +549,7 @@ Lemma expected_padder_state_cases msg (msg_complete padder_done : bool) index : \/ (expected_padder_state msg msg_complete padder_done index = padder_flushing_value /\ msg_complete = true /\ padder_done = false - /\ index < padded_message_size msg - 8) + /\ length msg < index < padded_message_size msg - 8) (* ...or we're in the writing_length state and the padder's expected output at index i part of the length *) \/ (expected_padder_state msg msg_complete padder_done index = padder_writing_length_value @@ -731,8 +721,6 @@ Proof. clear=false data_valid=true *) - pose proof length_bytes_to_Ns_upper_bound 4 msg. - pose proof padded_message_longer_than_input msg. pose proof padded_message_bytes_longer_than_input msg. pose proof min_padded_message_size msg. lazymatch goal with @@ -764,6 +752,16 @@ Proof. repeat destruct_one_match; logical_simplify; subst; push_length; try lia. } { (* if output is valid, then new index must be at least 4 *) repeat destruct_one_match; lia. } + { (* if we're in the emit_bit state, then the length of the message was 0 + mod 4 *) + cbv [expected_padder_state] in *. + destruct padder_done, out_valid, is_final; logical_simplify; subst; + boolsimpl; cbn [N.eqb Pos.eqb padder_waiting_value padder_flushing_value + padder_emit_bit_value padder_writing_length_value + negb andb orb]. + all:repeat first [ discriminate | destruct_one_match | destruct_one_match_hyp ]. + all:try tauto. + all:push_length; prove_by_zify. } { cbv [expected_padder_state] in *. destruct padder_done, out_valid, is_final; logical_simplify; subst. all:repeat (destruct_one_match_hyp; try discriminate). @@ -808,8 +806,6 @@ Proof. clear=false data_valid=false *) - pose proof length_bytes_to_Ns_upper_bound 4 msg. - pose proof padded_message_longer_than_input msg. pose proof padded_message_bytes_longer_than_input msg. pose proof min_padded_message_size msg. ssplit. @@ -856,6 +852,26 @@ Proof. repeat destruct_one_match; logical_simplify; subst; try lia; cbv [expected_padder_state] in *; repeat destruct_one_match_hyp; congruence. } + { (* if state is emit_bit, length of message is 0 mod 4 *) + destruct padder_done; logical_simplify; subst; + rewrite ?N.eqb_refl in *; + cbn [N.eqb Pos.eqb padder_waiting_value padder_flushing_value + padder_emit_bit_value padder_writing_length_value + negb andb orb]; [ tauto | ]. + pose proof + expected_padder_state_cases msg msg_complete false index + ltac:(eauto) ltac:(eauto) ltac:(eauto) as padder_state_cases. + let H := fresh in + destruct padder_state_cases as [H|[H|[H|H]]]; + logical_simplify; subst; + lazymatch goal with H : expected_padder_state _ _ _ _ = _ |- _ => + rewrite H in * end. + all:cbn [N.eqb Pos.eqb padder_waiting_value padder_flushing_value + padder_emit_bit_value padder_writing_length_value + negb andb orb]. + all:repeat + first [ discriminate | tauto + | destruct_one_match | destruct_one_match_hyp ]. } { (* entire clause for what happens if we're done or not done *) destruct padder_done; logical_simplify; subst; rewrite ?N.eqb_refl; @@ -930,7 +946,7 @@ Definition sha256_padder_spec (msg : list Byte.byte) (msg_complete padder_done : bool) (index : nat) : denote_type (Bit ** sha_word ** Bit) := let '(data_valid, (data, (is_final, (final_length, (consumer_ready, (clear,_)))))) := input in - let '(done, (out, (out_valid, (state, (len, current_offset))))) := state in + let '(_, (out, (out_valid, (_, (_, _))))) := state in (* expected result as words *) let expected_words := SHA256.padded_msg @@ -941,24 +957,34 @@ Definition sha256_padder_spec then false else if consumer_ready then if padder_done - then data_valid (* we were previously done and got new valid data *) - else true (* we're partway through processing message *) + then data_valid (* valid only if we got new valid data *) + else if msg_complete + then true (* message is done, produce valid output always *) + else data_valid (* we're partway through processing message *) else out_valid (* repeat previous output if consumer is not ready *) in let out := if clear then 0%N else if consumer_ready - then nth word_index expected_words 0%N (* output matches expected value *) + then if padder_done + then if data_valid + then nth word_index expected_words 0%N (* start processing new message *) + else 0%N (* done; default output *) + else if msg_complete + then nth word_index expected_words 0%N (* output matches expected value *) + else if data_valid + then nth word_index expected_words 0%N (* output matches expected value *) + else 0%N (* default output; wait for next word of message *) else out (* repeat previous output if consumer is not ready *) in let done := if clear then true else if consumer_ready - then - (* either this was the last word, or we were previously done and didn't - start a new message this step *) - ((word_index =? length expected_words - 1) - || (padder_done && negb data_valid))%bool + then if padder_done + then negb data_valid + else if msg_complete + then (word_index =? length expected_words - 1) (* done iff this was the last word *) + else false else padder_done (* stay in same state if consumer not ready *) in (out_valid, (out, done)). @@ -1014,6 +1040,51 @@ Proof. Qed. Hint Rewrite nth_padding_succ : push_nth. +Lemma nth_padding_nonzero msg i : + 0 < i -> nth i (SHA256.padding msg) x00 = x00. +Proof. + destruct i; [ lia | ]. intros. + apply nth_padding_succ. +Qed. + +Lemma mul_div_exact_r a b : + b <> 0 -> a mod b = 0 -> a / b * b = a. +Proof. intros. prove_by_zify. Qed. +Lemma add_sub_cancel a b : a + b - a = b. +Proof. lia. Qed. +Hint Rewrite add_sub_cancel : natsimpl. + +Lemma N_to_byte_equiv x y : + (x mod 256 = y mod 256)%N -> N_to_byte x = N_to_byte y. +Proof. + intro Heq. cbv [N_to_byte]. compute_expr (2 ^ 8)%N. + rewrite Heq. reflexivity. +Qed. + +Lemma nth_N_to_bytes i n x : + i < n -> + nth i (BigEndianBytes.N_to_bytes n x) x00 + = N_to_byte (N.shiftr (N.land x (N.ones (8 * N.of_nat n))) + (N.of_nat (n - 1 - i) * 8)). +Proof. + intros. cbv [BigEndianBytes.N_to_bytes]. + push_nth; natsimpl. apply N_to_byte_equiv. + change 256%N with (2 ^ 8)%N. + apply N.bits_inj; intro j. + push_Ntestbit; boolsimpl. + destr (j N.testbit x n = false. + Proof. + intros. destr (x =? 0)%N; subst; [ push_Ntestbit; reflexivity | ]. + apply N.bits_above_log2. apply N.log2_lt_pow2; lia. + Qed. +End N. + Local Ltac testbit_crush := repeat lazymatch goal with | |- context [N.eqb ?x ?y] => destr (N.eqb x y); try lia; subst @@ -1021,6 +1092,7 @@ Local Ltac testbit_crush := | _ => first [ progress (push_Ntestbit; boolsimpl) | reflexivity ] end. + Lemma step_sha256_padder input state msg msg_complete padder_done index : sha256_padder_pre input msg msg_complete padder_done index -> sha256_padder_invariant state msg msg_complete padder_done index -> @@ -1062,8 +1134,6 @@ Proof. cbn [N.eqb Pos.eqb padder_waiting_value padder_flushing_value padder_emit_bit_value padder_writing_length_value]. { (* data_valid=true *) - pose proof length_bytes_to_Ns_upper_bound 4 msg. - pose proof padded_message_longer_than_input msg. pose proof padded_message_bytes_longer_than_input msg. pose proof min_padded_message_size msg. lazymatch goal with @@ -1072,8 +1142,6 @@ Proof. pose proof padded_message_size_mono msg data; pose proof min_padded_message_size (firstn n data); pose proof min_padded_message_size data; - pose proof padded_message_longer_than_input (firstn n data); - pose proof padded_message_longer_than_input data; pose proof padded_message_bytes_longer_than_input (msg ++ firstn n data); pose proof padded_message_size_modulo (msg ++ firstn n data) end. @@ -1111,9 +1179,7 @@ Proof. change 8388608%N with (2 ^ 23)%N. change 128%N with (2 ^ 7)%N. (* solve the boolean parts of the tuple *) - do 2 f_equal; - [ | symmetry; apply Nat.eqb_neq; autorewrite with push_length in *; - prove_by_zify ]. + do 2 f_equal; [ ]. repeat destruct_one_match; subst; try lia. all: try lazymatch goal with @@ -1175,23 +1241,20 @@ Proof. { (* padder_done=false *) destruct is_final. { (* padder_done=false, is_final=true *) + cbn [Nat.add]. + compute_expr (N.of_nat 8). compute_expr (N.of_nat 16). - compute_expr (N.of_nat (16 + 16)). + compute_expr (N.of_nat 24). + compute_expr (N.of_nat 32). change 32768%N with (2 ^ 15)%N. change 8388608%N with (2 ^ 23)%N. change 128%N with (2 ^ 7)%N. (* solve the boolean parts of the tuple *) - do 2 f_equal; - [ | symmetry; apply Nat.eqb_neq; autorewrite with push_length in *; - prove_by_zify ]. + do 2 f_equal; [ ]. rewrite nth_padded_msg. - replace ((length msg / 4) * 4) with (length msg) - by (etransitivity; [ apply Nat.div_mod with (y:=4); lia | ]; lia). + rewrite mul_div_exact_r by lia. cbv [SHA256.padded_msg_bytes]. rewrite !app_assoc_reverse. push_nth. natsimpl. - repeat lazymatch goal with |- context [ ?x + ?y - ?x ] => - replace (x + y - x) with y by lia - end. autorewrite with push_length in *. repeat destruct_one_match; subst; try lia. all: @@ -1234,17 +1297,11 @@ Proof. all:destr (i - replace (x + y - x) with y by lia - end. lazymatch goal with | H : (?data < 2 ^ ?n)%N |- context [?data] => replace data with (N.land data (N.ones n)) @@ -1266,5 +1323,93 @@ Proof. destr (i + pose proof + expected_padder_state_cases msg mc pd i + ltac:(lia) ltac:(eauto) ltac:(eauto) + as padder_state_cases + end. + let H := fresh in + destruct padder_state_cases as [H|[H|[H|H]]]; + logical_simplify; subst; + lazymatch goal with H : expected_padder_state _ _ _ _ = _ |- _ => + rewrite H in * end; + rewrite ?N.eqb_refl in *; + cbn [N.eqb Pos.eqb padder_waiting_value padder_flushing_value + padder_emit_bit_value padder_writing_length_value + negb andb orb] in *. + { (* state=padder_waiting *) + destruct msg_complete; logical_simplify; subst; [ discriminate | ]. + reflexivity. } + { (* state=padder_emit_bit *) + do 2 f_equal; + [ | symmetry; apply Nat.eqb_neq; push_length; prove_by_zify ]. + rewrite nth_padded_msg. rewrite mul_div_exact_r by lia. + cbv [SHA256.padded_msg_bytes]. push_nth. natsimpl. + rewrite !app_nth1 by (push_length; prove_by_zify). + push_nth. reflexivity. } + { (* state=padder_flushing *) + do 2 f_equal; + [ | symmetry; apply Nat.eqb_neq; push_length; prove_by_zify ]. + rewrite nth_padded_msg. rewrite mul_div_exact_r by lia. + cbv [SHA256.padded_msg_bytes]. push_nth. natsimpl. + rewrite !app_nth1 by (push_length; prove_by_zify). + rewrite !nth_padding_nonzero by lia. + reflexivity. } + { (* state=writing_length *) + push_length. + do 2 f_equal; + [ | + repeat lazymatch goal with + | |- context [Nat.eqb ?x ?y] => destr (Nat.eqb x y); try lia + | |- context [N.eqb ?x ?y] => destr (N.eqb x y); try lia + | H : context [Nat.eqb ?x ?y] |- _ => destr (Nat.eqb x y); try lia + | H : context [N.eqb ?x ?y] |- _ => destr (N.eqb x y); try lia + end; + try discriminate; prove_by_zify ]. + rewrite nth_padded_msg. rewrite mul_div_exact_r by lia. + cbv [SHA256.padded_msg_bytes]. push_nth. natsimpl. + rewrite !nth_N_to_bytes by (push_length; prove_by_zify). + replace (SHA256.l msg) with (N.shiftl (N.of_nat (length msg)) 3) + by apply N.shiftl_mul_pow2. + (* helpful assertion for length truncation *) + assert (2 ^ 61 * 8 = 2 ^ 64)%N by reflexivity. + rewrite !N.land_ones with (n:=64%N). + rewrite (N.mod_small (N.of_nat (length msg)) (2^64)%N) by lia. + rewrite (N.mod_small (N.shiftl (N.of_nat (length msg)) _) (2^64)%N) + by (rewrite N.shiftl_mul_pow2; change (2 ^ N.of_nat 3)%N with 8%N; + lia). + rewrite <-!N.land_ones. + assert (if (current_offset =? 15)%N + then index = padded_message_size msg - 4 + else current_offset = 14%N /\ index = padded_message_size msg - 8) + by (destruct_one_match; prove_by_zify). + destr (current_offset =? 15)%N; + logical_simplify; subst; + cbn [N.eqb Pos.eqb padder_waiting_value padder_flushing_value + padder_emit_bit_value padder_writing_length_value + negb andb orb]. + all:cbv [BigEndianBytes.concat_bytes]; cbn [fold_left]. + all:rewrite !N_to_byte_to_N; cbn [Byte.to_N]. + all:rewrite <-!N.land_ones with (n:=8%N). + all:apply N.bits_inj; intro i. + all:push_Ntestbit; boolsimpl. + all:push_length. + all:change (N.of_nat 0) with 0%N; rewrite ?N.add_0_r. + all:change (N.of_nat 3) with 3%N. + all:destr (i