From ba299d83397a60be07bd492004523b62c2d30cbd Mon Sep 17 00:00:00 2001 From: Hannes Mehnert Date: Sat, 29 Jun 2024 07:00:07 +0200 Subject: [PATCH] provide ciphers with {de,en}crypt_into functionality (#231) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Mirage_crypto.Block.ECB with {de,en}crypt_into Also provide unsafe_{en,de}crypt_into for further performance. * Mirage_crypto.Block.CBC now has {de,en}crypt_into functionality This may avoid buffer allocations. There are as well unsafe functions for those feeling bounds checks are unnecessary. * counters: add an offset parameter * Mirage_crypto.Block.CTR with {de,en}crypt_into * GCM and ChaCha have {de,en}crypt_into now * CCM16 with {de,en}crypt_into * minor adjustments to speed * Apply suggestions from code review Co-authored-by: Reynir Björnsson * revise bounds checks (cc @reynir @palainp), also check off >= 0 * revise block_size check * update documentation, esp off < 0 * poly1305: mac_into appropriate bounds checks, also unsafe_mac_into * ccm: remove maclen argument, and ensure tag_size = block_size * add tailcall annotations, remove an argument from ccm's loop --------- Co-authored-by: Reynir Björnsson --- bench/speed.ml | 116 +++++++++-- src/aead.ml | 12 ++ src/ccm.ml | 71 ++++--- src/chacha20.ml | 116 +++++++---- src/cipher_block.ml | 375 ++++++++++++++++++++++++++---------- src/cipher_stream.ml | 4 +- src/mirage_crypto.mli | 266 ++++++++++++++++++++++--- src/native.ml | 12 +- src/native/ghash_ctmul.c | 4 +- src/native/ghash_generic.c | 4 +- src/native/ghash_pclmul.c | 6 +- src/native/mirage_crypto.h | 4 +- src/native/misc.c | 4 +- src/native/misc_sse.c | 8 +- src/native/poly1305-donna.c | 8 +- src/poly1305.ml | 30 ++- 16 files changed, 787 insertions(+), 253 deletions(-) diff --git a/bench/speed.ml b/bench/speed.ml index 90d44425..3b1c90e5 100644 --- a/bench/speed.ml +++ b/bench/speed.ml @@ -45,6 +45,15 @@ let throughput title f = Printf.printf " % 5d: %04f MB/s (%d iters in %.03f s)\n%!" size (bw /. mb) iters time +let throughput_into ?(add = 0) title f = + Printf.printf "\n* [%s]\n%!" title ; + sizes |> List.iter @@ fun size -> + Gc.full_major () ; + let dst = Bytes.create (size + add) in + let (iters, time, bw) = burn (f dst) size in + Printf.printf " % 5d: %04f MB/s (%d iters in %.03f s)\n%!" + size (bw /. mb) iters time + let count_period = 10. let count f n = @@ -347,55 +356,128 @@ let benchmarks = [ fst ecdh_shares); bm "chacha20-poly1305" (fun name -> - let key = Mirage_crypto.Chacha20.of_secret (Mirage_crypto_rng.generate 32) + let key = Chacha20.of_secret (Mirage_crypto_rng.generate 32) and nonce = Mirage_crypto_rng.generate 8 in - throughput name (Mirage_crypto.Chacha20.authenticate_encrypt ~key ~nonce)) ; + throughput_into ~add:Chacha20.tag_size name + (fun dst cs -> Chacha20.authenticate_encrypt_into ~key ~nonce cs ~src_off:0 dst ~dst_off:0 ~tag_off:(String.length cs) (String.length cs))) ; + + bm "chacha20-poly1305-unsafe" (fun name -> + let key = Chacha20.of_secret (Mirage_crypto_rng.generate 32) + and nonce = Mirage_crypto_rng.generate 8 in + throughput_into ~add:Chacha20.tag_size name + (fun dst cs -> Chacha20.unsafe_authenticate_encrypt_into ~key ~nonce cs ~src_off:0 dst ~dst_off:0 ~tag_off:(String.length cs) (String.length cs))) ; bm "aes-128-ecb" (fun name -> let key = AES.ECB.of_secret (Mirage_crypto_rng.generate 16) in - throughput name (fun cs -> AES.ECB.encrypt ~key cs)) ; + throughput_into name + (fun dst cs -> AES.ECB.encrypt_into ~key cs ~src_off:0 dst ~dst_off:0 (String.length cs))) ; + + bm "aes-192-ecb" (fun name -> + let key = AES.ECB.of_secret (Mirage_crypto_rng.generate 24) in + throughput_into name (fun dst cs -> AES.ECB.encrypt_into ~key cs ~src_off:0 dst ~dst_off:0 (String.length cs))) ; + + bm "aes-192-ecb-unsafe" (fun name -> + let key = AES.ECB.of_secret (Mirage_crypto_rng.generate 24) in + throughput_into name (fun dst cs -> AES.ECB.unsafe_encrypt_into ~key cs ~src_off:0 dst ~dst_off:0 (String.length cs))) ; + + bm "aes-256-ecb" (fun name -> + let key = AES.ECB.of_secret (Mirage_crypto_rng.generate 32) in + throughput_into name (fun dst cs -> AES.ECB.encrypt_into ~key cs ~src_off:0 dst ~dst_off:0 (String.length cs))) ; + + bm "aes-256-ecb-unsafe" (fun name -> + let key = AES.ECB.of_secret (Mirage_crypto_rng.generate 32) in + throughput_into name (fun dst cs -> AES.ECB.unsafe_encrypt_into ~key cs ~src_off:0 dst ~dst_off:0 (String.length cs))) ; + + bm "aes-128-ecb-unsafe" (fun name -> + let key = AES.ECB.of_secret (Mirage_crypto_rng.generate 16) in + throughput_into name + (fun dst cs -> AES.ECB.unsafe_encrypt_into ~key cs ~src_off:0 dst ~dst_off:0 (String.length cs))) ; bm "aes-128-cbc-e" (fun name -> let key = AES.CBC.of_secret (Mirage_crypto_rng.generate 16) and iv = Mirage_crypto_rng.generate 16 in - throughput name (fun cs -> AES.CBC.encrypt ~key ~iv cs)) ; + throughput_into name + (fun dst cs -> AES.CBC.encrypt_into ~key ~iv cs ~src_off:0 dst ~dst_off:0 (String.length cs))) ; + + bm "aes-128-cbc-e-unsafe" (fun name -> + let key = AES.CBC.of_secret (Mirage_crypto_rng.generate 16) + and iv = Mirage_crypto_rng.generate 16 in + throughput_into name + (fun dst cs -> AES.CBC.unsafe_encrypt_into ~key ~iv cs ~src_off:0 dst ~dst_off:0 (String.length cs))) ; + + bm "aes-128-cbc-e-unsafe-inplace" (fun name -> + let key = AES.CBC.of_secret (Mirage_crypto_rng.generate 16) + and iv = Mirage_crypto_rng.generate 16 in + throughput name + (fun cs -> + let b = Bytes.unsafe_of_string cs in + AES.CBC.unsafe_encrypt_into_inplace ~key ~iv b ~dst_off:0 (String.length cs))) ; bm "aes-128-cbc-d" (fun name -> let key = AES.CBC.of_secret (Mirage_crypto_rng.generate 16) and iv = Mirage_crypto_rng.generate 16 in - throughput name (fun cs -> AES.CBC.decrypt ~key ~iv cs)) ; + throughput_into name + (fun dst cs -> AES.CBC.decrypt_into ~key ~iv cs ~src_off:0 dst ~dst_off:0 (String.length cs))) ; + + bm "aes-128-cbc-d-unsafe" (fun name -> + let key = AES.CBC.of_secret (Mirage_crypto_rng.generate 16) + and iv = Mirage_crypto_rng.generate 16 in + throughput_into name + (fun dst cs -> AES.CBC.unsafe_decrypt_into ~key ~iv cs ~src_off:0 dst ~dst_off:0 (String.length cs))) ; bm "aes-128-ctr" (fun name -> let key = Mirage_crypto_rng.generate 16 |> AES.CTR.of_secret and ctr = Mirage_crypto_rng.generate 16 |> AES.CTR.ctr_of_octets in - throughput name (fun cs -> AES.CTR.encrypt ~key ~ctr cs)) ; + throughput_into name (fun dst cs -> AES.CTR.encrypt_into ~key ~ctr cs ~src_off:0 dst ~dst_off:0 (String.length cs))) ; + + bm "aes-128-ctr-unsafe" (fun name -> + let key = Mirage_crypto_rng.generate 16 |> AES.CTR.of_secret + and ctr = Mirage_crypto_rng.generate 16 |> AES.CTR.ctr_of_octets in + throughput_into name (fun dst cs -> AES.CTR.unsafe_encrypt_into ~key ~ctr cs ~src_off:0 dst ~dst_off:0 (String.length cs))) ; bm "aes-128-gcm" (fun name -> let key = AES.GCM.of_secret (Mirage_crypto_rng.generate 16) and nonce = Mirage_crypto_rng.generate 12 in - throughput name (fun cs -> AES.GCM.authenticate_encrypt ~key ~nonce cs)); + throughput_into ~add:AES.GCM.tag_size name + (fun dst cs -> AES.GCM.authenticate_encrypt_into ~key ~nonce cs ~src_off:0 dst ~dst_off:0 ~tag_off:(String.length cs) (String.length cs))); + + bm "aes-128-gcm-unsafe" (fun name -> + let key = AES.GCM.of_secret (Mirage_crypto_rng.generate 16) + and nonce = Mirage_crypto_rng.generate 12 in + throughput_into ~add:AES.GCM.tag_size name + (fun dst cs -> AES.GCM.unsafe_authenticate_encrypt_into ~key ~nonce cs ~src_off:0 dst ~dst_off:0 ~tag_off:(String.length cs) (String.length cs))); bm "aes-128-ghash" (fun name -> let key = AES.GCM.of_secret (Mirage_crypto_rng.generate 16) and nonce = Mirage_crypto_rng.generate 12 in - throughput name (fun cs -> AES.GCM.authenticate_encrypt ~key ~nonce ~adata:cs "")); + throughput_into ~add:AES.GCM.tag_size name + (fun dst cs -> AES.GCM.authenticate_encrypt_into ~key ~nonce ~adata:cs "" ~src_off:0 dst ~dst_off:0 ~tag_off:0 0)); + + bm "aes-128-ghash-unsafe" (fun name -> + let key = AES.GCM.of_secret (Mirage_crypto_rng.generate 16) + and nonce = Mirage_crypto_rng.generate 12 in + throughput_into ~add:AES.GCM.tag_size name + (fun dst cs -> AES.GCM.unsafe_authenticate_encrypt_into ~key ~nonce ~adata:cs "" ~src_off:0 dst ~dst_off:0 ~tag_off:0 0)); bm "aes-128-ccm" (fun name -> let key = AES.CCM16.of_secret (Mirage_crypto_rng.generate 16) and nonce = Mirage_crypto_rng.generate 10 in - throughput name (fun cs -> AES.CCM16.authenticate_encrypt ~key ~nonce cs)); - - bm "aes-192-ecb" (fun name -> - let key = AES.ECB.of_secret (Mirage_crypto_rng.generate 24) in - throughput name (fun cs -> AES.ECB.encrypt ~key cs)) ; + throughput_into ~add:AES.CCM16.tag_size name + (fun dst cs -> AES.CCM16.authenticate_encrypt_into ~key ~nonce cs ~src_off:0 dst ~dst_off:0 ~tag_off:(String.length cs) (String.length cs))); - bm "aes-256-ecb" (fun name -> - let key = AES.ECB.of_secret (Mirage_crypto_rng.generate 32) in - throughput name (fun cs -> AES.ECB.encrypt ~key cs)) ; + bm "aes-128-ccm-unsafe" (fun name -> + let key = AES.CCM16.of_secret (Mirage_crypto_rng.generate 16) + and nonce = Mirage_crypto_rng.generate 10 in + throughput_into ~add:AES.CCM16.tag_size name + (fun dst cs -> AES.CCM16.unsafe_authenticate_encrypt_into ~key ~nonce cs ~src_off:0 dst ~dst_off:0 ~tag_off:(String.length cs) (String.length cs))); bm "d3des-ecb" (fun name -> let key = DES.ECB.of_secret (Mirage_crypto_rng.generate 24) in - throughput name (fun cs -> DES.ECB.encrypt ~key cs)) ; + throughput_into name (fun dst cs -> DES.ECB.encrypt_into ~key cs ~src_off:0 dst ~dst_off:0 (String.length cs))) ; + + bm "d3des-ecb-unsafe" (fun name -> + let key = DES.ECB.of_secret (Mirage_crypto_rng.generate 24) in + throughput_into name (fun dst cs -> DES.ECB.unsafe_encrypt_into ~key cs ~src_off:0 dst ~dst_off:0 (String.length cs))) ; bm "fortuna" (fun name -> let open Mirage_crypto_rng.Fortuna in diff --git a/src/aead.ml b/src/aead.ml index a03214e1..30b716a1 100644 --- a/src/aead.ml +++ b/src/aead.ml @@ -10,4 +10,16 @@ module type AEAD = sig string -> string * string val authenticate_decrypt_tag : key:key -> nonce:string -> ?adata:string -> tag:string -> string -> string option + val authenticate_encrypt_into : key:key -> nonce:string -> + ?adata:string -> string -> src_off:int -> bytes -> dst_off:int -> + tag_off:int -> int -> unit + val authenticate_decrypt_into : key:key -> nonce:string -> + ?adata:string -> string -> src_off:int -> tag_off:int -> bytes -> + dst_off:int -> int -> bool + val unsafe_authenticate_encrypt_into : key:key -> nonce:string -> + ?adata:string -> string -> src_off:int -> bytes -> dst_off:int -> + tag_off:int -> int -> unit + val unsafe_authenticate_decrypt_into : key:key -> nonce:string -> + ?adata:string -> string -> src_off:int -> tag_off:int -> bytes -> + dst_off:int -> int -> bool end diff --git a/src/ccm.ml b/src/ccm.ml index 8b68b7f3..ecee28ec 100644 --- a/src/ccm.ml +++ b/src/ccm.ml @@ -10,7 +10,7 @@ let encode_len buf ~off size value = | 0 -> Bytes.set_uint8 buf off num | m -> Bytes.set_uint8 buf (off + m) (num land 0xff); - ass (num lsr 8) (pred m) + (ass [@tailcall]) (num lsr 8) (pred m) in ass value (pred size) @@ -74,10 +74,8 @@ let prepare_header nonce adata plen tlen = type mode = Encrypt | Decrypt -let crypto_core ~cipher ~mode ~key ~nonce ~maclen ~adata data = - let datalen = String.length data in - let cbcheader = prepare_header nonce adata datalen maclen in - let dst = Bytes.create datalen in +let crypto_core_into ~cipher ~mode ~key ~nonce ~adata src ~src_off dst ~dst_off len = + let cbcheader = prepare_header nonce adata len block_size in let small_q = 15 - String.length nonce in let ctr_flag_val = flags 0 0 (small_q - 1) in @@ -93,65 +91,62 @@ let crypto_core ~cipher ~mode ~key ~nonce ~maclen ~adata data = cipher ~key (Bytes.unsafe_to_string block) ~src_off:dst_off block ~dst_off in - let cbcprep = + let iv = let rec doit iv iv_off block block_off = match Bytes.length block - block_off with | 0 -> Bytes.sub iv iv_off block_size | _ -> cbc (Bytes.unsafe_to_string iv) iv_off block block_off; - doit block block_off block (block_off + block_size) + (doit [@tailcall]) block block_off block (block_off + block_size) in doit (Bytes.make block_size '\x00') 0 cbcheader 0 in - let rec loop iv ctr src src_off dst dst_off= + let rec loop ctr src src_off dst dst_off len = let cbcblock, cbc_off = match mode with | Encrypt -> src, src_off | Decrypt -> Bytes.unsafe_to_string dst, dst_off in - match String.length src - src_off with - | 0 -> iv - | x when x < block_size -> + if len = 0 then + () + else if len < block_size then begin let buf = Bytes.make block_size '\x00' in - Bytes.unsafe_blit dst dst_off buf 0 x; + Bytes.unsafe_blit dst dst_off buf 0 len ; ctrblock ctr buf ; - Bytes.unsafe_blit buf 0 dst dst_off x ; - unsafe_xor_into src ~src_off dst ~dst_off x ; - Bytes.unsafe_blit_string cbcblock cbc_off buf 0 x; - Bytes.unsafe_fill buf x (block_size - x) '\x00'; - cbc (Bytes.unsafe_to_string buf) cbc_off iv 0 ; - iv - | _ -> + Bytes.unsafe_blit buf 0 dst dst_off len ; + unsafe_xor_into src ~src_off dst ~dst_off len ; + Bytes.unsafe_blit_string cbcblock cbc_off buf 0 len ; + Bytes.unsafe_fill buf len (block_size - len) '\x00'; + cbc (Bytes.unsafe_to_string buf) cbc_off iv 0 + end else begin ctrblock ctr dst ; unsafe_xor_into src ~src_off dst ~dst_off block_size ; cbc cbcblock cbc_off iv 0 ; - loop iv (succ ctr) src (src_off + block_size) dst (dst_off + block_size) + (loop [@tailcall]) (succ ctr) src (src_off + block_size) dst (dst_off + block_size) (len - block_size) + end in - let last = loop cbcprep 1 data 0 dst 0 in - let t = Bytes.sub last 0 maclen in - (dst, t) + loop 1 src src_off dst dst_off len; + iv + +let crypto_core ~cipher ~mode ~key ~nonce ~adata data = + let datalen = String.length data in + let dst = Bytes.create datalen in + let t = crypto_core_into ~cipher ~mode ~key ~nonce ~adata data ~src_off:0 dst ~dst_off:0 datalen in + dst, t let crypto_t t nonce cipher key = let ctr = gen_ctr nonce 0 in cipher ~key (Bytes.unsafe_to_string ctr) ~src_off:0 ctr ~dst_off:0 ; unsafe_xor_into (Bytes.unsafe_to_string ctr) ~src_off:0 t ~dst_off:0 (Bytes.length t) -let valid_nonce nonce = - let nsize = String.length nonce in - if nsize < 7 || nsize > 13 then - invalid_arg "CCM: nonce length not between 7 and 13: %u" nsize - -let generation_encryption ~cipher ~key ~nonce ~maclen ~adata data = - valid_nonce nonce; - let cdata, t = crypto_core ~cipher ~mode:Encrypt ~key ~nonce ~maclen ~adata data in +let unsafe_generation_encryption_into ~cipher ~key ~nonce ~adata src ~src_off dst ~dst_off ~tag_off len = + let t = crypto_core_into ~cipher ~mode:Encrypt ~key ~nonce ~adata src ~src_off dst ~dst_off len in crypto_t t nonce cipher key ; - Bytes.unsafe_to_string cdata, Bytes.unsafe_to_string t + Bytes.unsafe_blit t 0 dst tag_off block_size -let decryption_verification ~cipher ~key ~nonce ~maclen ~adata ~tag data = - valid_nonce nonce; - let cdata, t = crypto_core ~cipher ~mode:Decrypt ~key ~nonce ~maclen ~adata data in +let unsafe_decryption_verification_into ~cipher ~key ~nonce ~adata src ~src_off ~tag_off dst ~dst_off len = + let tag = String.sub src tag_off block_size in + let t = crypto_core_into ~cipher ~mode:Decrypt ~key ~nonce ~adata src ~src_off dst ~dst_off len in crypto_t t nonce cipher key ; - match Eqaf.equal tag (Bytes.unsafe_to_string t) with - | true -> Some (Bytes.unsafe_to_string cdata) - | false -> None + Eqaf.equal tag (Bytes.unsafe_to_string t) diff --git a/src/chacha20.ml b/src/chacha20.ml index f0d97840..2c70251d 100644 --- a/src/chacha20.ml +++ b/src/chacha20.ml @@ -42,77 +42,121 @@ let init ctr ~key ~nonce = Bytes.unsafe_blit_string nonce 0 state nonce_off (String.length nonce) ; state, inc -let crypt ~key ~nonce ?(ctr = 0L) data = +let crypt_into ~key ~nonce ~ctr src ~src_off dst ~dst_off len = let state, inc = init ctr ~key ~nonce in - let l = String.length data in - let block_count = l // block in + let block_count = len // block in let last_len = - let last = l mod block in + let last = len mod block in if last = 0 then block else last in - let res = Bytes.create l in let rec loop i = function | 0 -> () | 1 -> if last_len = block then begin - chacha20_block state i res ; - Native.xor_into_bytes data i res i block + chacha20_block state (dst_off + i) dst ; + Native.xor_into_bytes src (src_off + i) dst (dst_off + i) block end else begin let buf = Bytes.create block in chacha20_block state 0 buf ; - Native.xor_into_bytes data i buf 0 last_len ; - Bytes.unsafe_blit buf 0 res i last_len + Native.xor_into_bytes src (src_off + i) buf 0 last_len ; + Bytes.unsafe_blit buf 0 dst (dst_off + i) last_len end | n -> - chacha20_block state i res ; - Native.xor_into_bytes data i res i block ; + chacha20_block state (dst_off + i) dst ; + Native.xor_into_bytes src (src_off + i) dst (dst_off + i) block ; inc state; - loop (i + block) (n - 1) + (loop [@tailcall]) (i + block) (n - 1) in - loop 0 block_count ; + loop 0 block_count + +let crypt ~key ~nonce ?(ctr = 0L) data = + let l = String.length data in + let res = Bytes.create l in + crypt_into ~key ~nonce ~ctr data ~src_off:0 res ~dst_off:0 l; Bytes.unsafe_to_string res module P = Poly1305.It +let tag_size = P.mac_size + let generate_poly1305_key ~key ~nonce = crypt ~key ~nonce (String.make 32 '\000') -let mac ~key ~adata ciphertext = - let pad16 b = - let len = String.length b mod 16 in +let mac_into ~key ~adata src ~src_off len dst ~dst_off = + let pad16 l = + let len = l mod 16 in if len = 0 then "" else String.make (16 - len) '\000' - and len = + and len_buf = let data = Bytes.create 16 in Bytes.set_int64_le data 0 (Int64.of_int (String.length adata)); - Bytes.set_int64_le data 8 (Int64.of_int (String.length ciphertext)); + Bytes.set_int64_le data 8 (Int64.of_int len); Bytes.unsafe_to_string data in - P.macl ~key [ adata ; pad16 adata ; ciphertext ; pad16 ciphertext ; len ] + let p1 = pad16 (String.length adata) and p2 = pad16 len in + P.unsafe_mac_into ~key [ adata, 0, String.length adata ; + p1, 0, String.length p1 ; + src, src_off, len ; + p2, 0, String.length p2 ; + len_buf, 0, String.length len_buf ] + dst ~dst_off -let authenticate_encrypt_tag ~key ~nonce ?(adata = "") data = +let unsafe_authenticate_encrypt_into ~key ~nonce ?(adata = "") src ~src_off dst ~dst_off ~tag_off len = let poly1305_key = generate_poly1305_key ~key ~nonce in - let ciphertext = crypt ~key ~nonce ~ctr:1L data in - let mac = mac ~key:poly1305_key ~adata ciphertext in - ciphertext, mac + crypt_into ~key ~nonce ~ctr:1L src ~src_off dst ~dst_off len; + mac_into ~key:poly1305_key ~adata (Bytes.unsafe_to_string dst) ~src_off:dst_off len dst ~dst_off:tag_off + +let authenticate_encrypt_into ~key ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len = + if String.length src - src_off < len then + invalid_arg "Chacha20: src length %u - src_off %u < len %u" + (String.length src) src_off len; + if Bytes.length dst - dst_off < len then + invalid_arg "Chacha20: dst length %u - dst_off %u < len %u" + (Bytes.length dst) dst_off len; + if Bytes.length dst - tag_off < tag_size then + invalid_arg "Chacha20: dst length %u - tag_off %u < tag_size %u" + (Bytes.length dst) tag_off tag_size; + unsafe_authenticate_encrypt_into ~key ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len let authenticate_encrypt ~key ~nonce ?adata data = - let cdata, ctag = authenticate_encrypt_tag ~key ~nonce ?adata data in - cdata ^ ctag + let l = String.length data in + let dst = Bytes.create (l + tag_size) in + unsafe_authenticate_encrypt_into ~key ~nonce ?adata data ~src_off:0 dst ~dst_off:0 ~tag_off:l l; + Bytes.unsafe_to_string dst -let authenticate_decrypt_tag ~key ~nonce ?(adata = "") ~tag data = +let authenticate_encrypt_tag ~key ~nonce ?adata data = + let r = authenticate_encrypt ~key ~nonce ?adata data in + String.sub r 0 (String.length data), String.sub r (String.length data) tag_size + +let unsafe_authenticate_decrypt_into ~key ~nonce ?(adata = "") src ~src_off ~tag_off dst ~dst_off len = let poly1305_key = generate_poly1305_key ~key ~nonce in - let ctag = mac ~key:poly1305_key ~adata data in - let plain = crypt ~key ~nonce ~ctr:1L data in - if Eqaf.equal tag ctag then Some plain else None + let ctag = Bytes.create tag_size in + mac_into ~key:poly1305_key ~adata src ~src_off len ctag ~dst_off:0; + crypt_into ~key ~nonce ~ctr:1L src ~src_off dst ~dst_off len; + Eqaf.equal (String.sub src tag_off tag_size) (Bytes.unsafe_to_string ctag) + +let authenticate_decrypt_into ~key ~nonce ?adata src ~src_off ~tag_off dst ~dst_off len = + if String.length src - src_off < len then + invalid_arg "Chacha20: src length %u - src_off %u < len %u" + (String.length src) src_off len; + if Bytes.length dst - dst_off < len then + invalid_arg "Chacha20: dst length %u - dst_off %u < len %u" + (Bytes.length dst) dst_off len; + if String.length src - tag_off < tag_size then + invalid_arg "Chacha20: src length %u - tag_off %u < tag_size %u" + (String.length src) tag_off tag_size; + unsafe_authenticate_decrypt_into ~key ~nonce ?adata src ~src_off ~tag_off dst ~dst_off len let authenticate_decrypt ~key ~nonce ?adata data = - if String.length data < P.mac_size then + if String.length data < tag_size then None else - let cipher, tag = - let p = String.length data - P.mac_size in - String.sub data 0 p, String.sub data p P.mac_size - in - authenticate_decrypt_tag ~key ~nonce ?adata ~tag cipher + let l = String.length data - tag_size in + let r = Bytes.create l in + if unsafe_authenticate_decrypt_into ~key ~nonce ?adata data ~src_off:0 ~tag_off:l r ~dst_off:0 l then + Some (Bytes.unsafe_to_string r) + else + None -let tag_size = P.mac_size +let authenticate_decrypt_tag ~key ~nonce ?adata ~tag data = + let cdata = data ^ tag in + authenticate_decrypt ~key ~nonce ?adata cdata diff --git a/src/cipher_block.ml b/src/cipher_block.ml index 3dfa1fcb..d430492f 100644 --- a/src/cipher_block.ml +++ b/src/cipher_block.ml @@ -28,6 +28,10 @@ module Block = struct val block_size : int val encrypt : key:key -> string -> string val decrypt : key:key -> string -> string + val encrypt_into : key:key -> string -> src_off:int -> bytes -> dst_off:int -> int -> unit + val decrypt_into : key:key -> string -> src_off:int -> bytes -> dst_off:int -> int -> unit + val unsafe_encrypt_into : key:key -> string -> src_off:int -> bytes -> dst_off:int -> int -> unit + val unsafe_decrypt_into : key:key -> string -> src_off:int -> bytes -> dst_off:int -> int -> unit end module type CBC = sig @@ -40,7 +44,19 @@ module Block = struct val encrypt : key:key -> iv:string -> string -> string val decrypt : key:key -> iv:string -> string -> string - val next_iv : iv:string -> string -> string + val next_iv : ?off:int -> string -> iv:string -> string + + val encrypt_into : key:key -> iv:string -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + val decrypt_into : key:key -> iv:string -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + + val unsafe_encrypt_into : key:key -> iv:string -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + val unsafe_decrypt_into : key:key -> iv:string -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + val unsafe_encrypt_into_inplace : key:key -> iv:string -> + bytes -> dst_off:int -> int -> unit end module type CTR = sig @@ -48,18 +64,29 @@ module Block = struct type key val of_secret : string -> key - type ctr - val key_sizes : int array val block_size : int + type ctr + val add_ctr : ctr -> int64 -> ctr + val next_ctr : ?off:int -> string -> ctr:ctr -> ctr + val ctr_of_octets : string -> ctr + val stream : key:key -> ctr:ctr -> int -> string val encrypt : key:key -> ctr:ctr -> string -> string val decrypt : key:key -> ctr:ctr -> string -> string - val add_ctr : ctr -> int64 -> ctr - val next_ctr : ctr:ctr -> string -> ctr - val ctr_of_octets : string -> ctr + val stream_into : key:key -> ctr:ctr -> bytes -> off:int -> int -> unit + val encrypt_into : key:key -> ctr:ctr -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + val decrypt_into : key:key -> ctr:ctr -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + + val unsafe_stream_into : key:key -> ctr:ctr -> bytes -> off:int -> int -> unit + val unsafe_encrypt_into : key:key -> ctr:ctr -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + val unsafe_decrypt_into : key:key -> ctr:ctr -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit end module type GCM = sig @@ -83,7 +110,7 @@ module Counters = struct val size : int val add : ctr -> int64 -> ctr val of_octets : string -> ctr - val unsafe_count_into : ctr -> bytes -> blocks:int -> unit + val unsafe_count_into : ctr -> bytes -> off:int -> blocks:int -> unit end module C64be = struct @@ -91,10 +118,10 @@ module Counters = struct let size = 8 let of_octets cs = String.get_int64_be cs 0 let add = Int64.add - let unsafe_count_into t buf ~blocks = - let tmp = Bytes.create 8 in - Bytes.set_int64_be tmp 0 t; - Native.count8be tmp buf ~blocks + let unsafe_count_into t buf ~off ~blocks = + let ctr = Bytes.create 8 in + Bytes.set_int64_be ctr 0 t; + Native.count8be ~ctr buf ~off ~blocks end module C128be = struct @@ -107,10 +134,10 @@ module Counters = struct let w0' = Int64.add w0 n in let flip = if Int64.logxor w0 w0' < 0L then w0' > w0 else w0' < w0 in ((if flip then Int64.succ w1 else w1), w0') - let unsafe_count_into (w1, w0) buf ~blocks = - let tmp = Bytes.create 16 in - Bytes.set_int64_be tmp 0 w1; Bytes.set_int64_be tmp 8 w0; - Native.count16be tmp buf ~blocks + let unsafe_count_into (w1, w0) buf ~off ~blocks = + let ctr = Bytes.create 16 in + Bytes.set_int64_be ctr 0 w1; Bytes.set_int64_be ctr 8 w0; + Native.count16be ~ctr buf ~off ~blocks end module C128be32 = struct @@ -118,13 +145,22 @@ module Counters = struct let add (w1, w0) n = let hi = 0xffffffff00000000L and lo = 0x00000000ffffffffL in (w1, Int64.(logor (logand hi w0) (add n w0 |> logand lo))) - let unsafe_count_into (w1, w0) buf ~blocks = - let tmp = Bytes.create 16 in - Bytes.set_int64_be tmp 0 w1; Bytes.set_int64_be tmp 8 w0; - Native.count16be4 tmp buf ~blocks + let unsafe_count_into (w1, w0) buf ~off ~blocks = + let ctr = Bytes.create 16 in + Bytes.set_int64_be ctr 0 w1; Bytes.set_int64_be ctr 8 w0; + Native.count16be4 ~ctr buf ~off ~blocks end end +let check_offset ~tag ~buf ~off ~len actual_len = + if off < 0 then + invalid_arg "%s: %s off %u < 0" + tag buf off; + if actual_len - off < len then + invalid_arg "%s: %s length %u - off %u < len %u" + tag buf actual_len off len +[@@inline] + module Modes = struct module ECB_of (Core : Block.Core) : Block.ECB = struct @@ -134,17 +170,39 @@ module Modes = struct let of_secret = Core.of_secret - let (encrypt, decrypt) = - let ecb xform key src = - let n = String.length src in - if n mod block_size <> 0 then invalid_arg "ECB: length %u" n; - let dst = Bytes.create n in - xform ~key ~blocks:(n / block_size) src 0 dst 0 ; - Bytes.unsafe_to_string dst - in - (fun ~key:(key, _) src -> ecb Core.encrypt key src), - (fun ~key:(_, key) src -> ecb Core.decrypt key src) + let unsafe_ecb xform key src src_off dst dst_off len = + xform ~key ~blocks:(len / block_size) src src_off dst dst_off + + let ecb xform key src src_off dst dst_off len = + if len mod block_size <> 0 then + invalid_arg "ECB: length %u not of block size" len; + check_offset ~tag:"ECB" ~buf:"src" ~off:src_off ~len (String.length src); + check_offset ~tag:"ECB" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst); + unsafe_ecb xform key src src_off dst dst_off len + + let encrypt_into ~key:(key, _) src ~src_off dst ~dst_off len = + ecb Core.encrypt key src src_off dst dst_off len + let unsafe_encrypt_into ~key:(key, _) src ~src_off dst ~dst_off len = + unsafe_ecb Core.encrypt key src src_off dst dst_off len + + let decrypt_into ~key:(_, key) src ~src_off dst ~dst_off len = + ecb Core.decrypt key src src_off dst dst_off len + + let unsafe_decrypt_into ~key:(_, key) src ~src_off dst ~dst_off len = + unsafe_ecb Core.decrypt key src src_off dst dst_off len + + let encrypt ~key src = + let len = String.length src in + let dst = Bytes.create len in + encrypt_into ~key src ~src_off:0 dst ~dst_off:0 len; + Bytes.unsafe_to_string dst + + let decrypt ~key src = + let len = String.length src in + let dst = Bytes.create len in + decrypt_into ~key src ~src_off:0 dst ~dst_off:0 len; + Bytes.unsafe_to_string dst end module CBC_of (Core : Block.Core) : Block.CBC = struct @@ -156,40 +214,64 @@ module Modes = struct let of_secret = Core.of_secret - let bounds_check ~iv cs = - if String.length iv <> block then invalid_arg "CBC: IV length %u" (String.length iv); - if String.length cs mod block <> 0 then - invalid_arg "CBC: argument length %u" (String.length cs) - - let next_iv ~iv cs = - bounds_check ~iv cs ; - if String.length cs > 0 then + let check_block_size ~iv len = + if String.length iv <> block then + invalid_arg "CBC: IV length %u not of block size" (String.length iv); + if len mod block <> 0 then + invalid_arg "CBC: argument length %u not of block size" + len + [@@inline] + + let next_iv ?(off = 0) cs ~iv = + check_block_size ~iv (String.length cs - off) ; + if String.length cs > off then String.sub cs (String.length cs - block_size) block_size else iv - let encrypt ~key:(key, _) ~iv src = - bounds_check ~iv src ; - let dst = Bytes.of_string src in + let unsafe_encrypt_into_inplace ~key:(key, _) ~iv dst ~dst_off len = let rec loop iv iv_i dst_i = function - 0 -> () - | b -> Native.xor_into_bytes iv iv_i dst dst_i block ; - Core.encrypt ~key ~blocks:1 (Bytes.unsafe_to_string dst) dst_i dst dst_i ; - loop (Bytes.unsafe_to_string dst) dst_i (dst_i + block) (b - 1) + | 0 -> () + | b -> + Native.xor_into_bytes iv iv_i dst dst_i block ; + Core.encrypt ~key ~blocks:1 (Bytes.unsafe_to_string dst) dst_i dst dst_i ; + (loop [@tailcall]) (Bytes.unsafe_to_string dst) dst_i (dst_i + block) (b - 1) in - loop iv 0 0 (Bytes.length dst / block) ; + loop iv 0 dst_off (len / block) + + let unsafe_encrypt_into ~key ~iv src ~src_off dst ~dst_off len = + Bytes.unsafe_blit_string src src_off dst dst_off len; + unsafe_encrypt_into_inplace ~key ~iv dst ~dst_off len + + let encrypt_into ~key ~iv src ~src_off dst ~dst_off len = + check_block_size ~iv len; + check_offset ~tag:"CBC" ~buf:"src" ~off:src_off ~len (String.length src); + check_offset ~tag:"CBC" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst); + unsafe_encrypt_into ~key ~iv src ~src_off dst ~dst_off len + + let encrypt ~key ~iv src = + let dst = Bytes.create (String.length src) in + encrypt_into ~key ~iv src ~src_off:0 dst ~dst_off:0 (String.length src); Bytes.unsafe_to_string dst - let decrypt ~key:(_, key) ~iv src = - bounds_check ~iv src ; - let msg = Bytes.create (String.length src) - and b = String.length src / block in + let unsafe_decrypt_into ~key:(_, key) ~iv src ~src_off dst ~dst_off len = + let b = len / block in if b > 0 then begin - Core.decrypt ~key ~blocks:b src 0 msg 0 ; - Native.xor_into_bytes iv 0 msg 0 block ; - Native.xor_into_bytes src 0 msg block ((b - 1) * block) ; - end ; + Core.decrypt ~key ~blocks:b src src_off dst dst_off ; + Native.xor_into_bytes iv 0 dst dst_off block ; + Native.xor_into_bytes src src_off dst (dst_off + block) ((b - 1) * block) ; + end + + let decrypt_into ~key ~iv src ~src_off dst ~dst_off len = + check_block_size ~iv len; + check_offset ~tag:"CBC" ~buf:"src" ~off:src_off ~len (String.length src); + check_offset ~tag:"CBC" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst); + unsafe_decrypt_into ~key ~iv src ~src_off dst ~dst_off len + + let decrypt ~key ~iv src = + let len = String.length src in + let msg = Bytes.create len in + decrypt_into ~key ~iv src ~src_off:0 msg ~dst_off:0 len; Bytes.unsafe_to_string msg - end module CTR_of (Core : Block.Core) (Ctr : Counters.S) : @@ -204,30 +286,52 @@ module Modes = struct let (key_sizes, block_size) = Core.(key, block) let of_secret = Core.e_of_secret - let stream ~key ~ctr n = - let blocks = imax 0 n / block_size in - let buf = Bytes.create n in - Ctr.unsafe_count_into ctr ~blocks buf ; - Core.encrypt ~key ~blocks (Bytes.unsafe_to_string buf) 0 buf 0 ; - let slack = imax 0 n mod block_size in + let unsafe_stream_into ~key ~ctr buf ~off len = + let blocks = imax 0 len / block_size in + Ctr.unsafe_count_into ctr buf ~off ~blocks ; + Core.encrypt ~key ~blocks (Bytes.unsafe_to_string buf) off buf off ; + let slack = imax 0 len mod block_size in if slack <> 0 then begin let buf' = Bytes.create block_size in let ctr = Ctr.add ctr (Int64.of_int blocks) in - Ctr.unsafe_count_into ctr ~blocks:1 buf' ; + Ctr.unsafe_count_into ctr buf' ~off:0 ~blocks:1 ; Core.encrypt ~key ~blocks:1 (Bytes.unsafe_to_string buf') 0 buf' 0 ; - Bytes.unsafe_blit buf' 0 buf (blocks * block_size) slack - end; + Bytes.unsafe_blit buf' 0 buf (off + blocks * block_size) slack + end + + let stream_into ~key ~ctr buf ~off len = + check_offset ~tag:"CTR" ~buf:"buf" ~off ~len (Bytes.length buf); + unsafe_stream_into ~key ~ctr buf ~off len + + let stream ~key ~ctr n = + let buf = Bytes.create n in + unsafe_stream_into ~key ~ctr buf ~off:0 n; Bytes.unsafe_to_string buf + let unsafe_encrypt_into ~key ~ctr src ~src_off dst ~dst_off len = + unsafe_stream_into ~key ~ctr dst ~off:dst_off len; + Uncommon.unsafe_xor_into src ~src_off dst ~dst_off len + + let encrypt_into ~key ~ctr src ~src_off dst ~dst_off len = + check_offset ~tag:"CTR" ~buf:"src" ~off:src_off ~len (String.length src); + check_offset ~tag:"CTR" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst); + unsafe_encrypt_into ~key ~ctr src ~src_off dst ~dst_off len + let encrypt ~key ~ctr src = - let res = Bytes.unsafe_of_string (stream ~key ~ctr (String.length src)) in - Native.xor_into_bytes src 0 res 0 (String.length src) ; - Bytes.unsafe_to_string res + let len = String.length src in + let dst = Bytes.create len in + encrypt_into ~key ~ctr src ~src_off:0 dst ~dst_off:0 len; + Bytes.unsafe_to_string dst let decrypt = encrypt + let decrypt_into = encrypt_into + + let unsafe_decrypt_into = unsafe_encrypt_into + let add_ctr = Ctr.add - let next_ctr ~ctr msg = add_ctr ctr (Int64.of_int @@ String.length msg // block_size) + let next_ctr ?(off = 0) msg ~ctr = + add_ctr ctr (Int64.of_int @@ (String.length msg - off) // block_size) let ctr_of_octets = Ctr.of_octets end @@ -235,6 +339,7 @@ module Modes = struct type key val derive : string -> key val digesti : key:key -> (string Uncommon.iter) -> string + val digesti_off_len : key:key -> (string * int * int) Uncommon.iter -> string val tagsize : int end = struct type key = string @@ -245,15 +350,20 @@ module Modes = struct let k = Bytes.create keysize in Native.GHASH.keyinit cs k; Bytes.unsafe_to_string k + let digesti_off_len ~key i = + let res = Bytes.make tagsize '\x00' in + i (fun (cs, off, len) -> Native.GHASH.ghash key res cs off len); + Bytes.unsafe_to_string res let digesti ~key i = let res = Bytes.make tagsize '\x00' in - i (fun cs -> Native.GHASH.ghash key res cs (String.length cs)); + i (fun cs -> Native.GHASH.ghash key res cs 0 (String.length cs)); Bytes.unsafe_to_string res + end module GCM_of (C : Block.Core) : Block.GCM = struct - let _ = assert (C.block = 16) + assert (C.block = 16) module CTR = CTR_of (C) (Counters.C128be32) type key = { key : C.ekey ; hkey : GHASH.key } @@ -285,43 +395,69 @@ module Modes = struct CTR.ctr_of_octets @@ GHASH.digesti ~key:hkey @@ iter2 nonce (pack64s 0L (bits64 nonce)) - let tag ~key ~hkey ~ctr ?(adata = "") cdata = - CTR.encrypt ~key ~ctr @@ - GHASH.digesti ~key:hkey @@ - iter3 adata cdata (pack64s (bits64 adata) (bits64 cdata)) + let unsafe_tag_into ~key ~hkey ~ctr ?(adata = "") cdata ~off ~len dst ~tag_off = + CTR.unsafe_encrypt_into ~key ~ctr + (GHASH.digesti_off_len ~key:hkey + (iter3 (adata, 0, String.length adata) (cdata, off, len) + (pack64s (bits64 adata) (Int64.of_int (len * 8)), 0, 16))) + ~src_off:0 dst ~dst_off:tag_off tag_size + + let unsafe_authenticate_encrypt_into ~key:{ key; hkey } ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len = + let ctr = counter ~hkey nonce in + CTR.(unsafe_encrypt_into ~key ~ctr:(add_ctr ctr 1L) src ~src_off dst ~dst_off len); + unsafe_tag_into ~key ~hkey ~ctr ?adata (Bytes.unsafe_to_string dst) ~off:dst_off ~len dst ~tag_off - let authenticate_encrypt_tag ~key:{ key; hkey } ~nonce ?adata data = - let ctr = counter ~hkey nonce in - let cdata = CTR.(encrypt ~key ~ctr:(add_ctr ctr 1L) data) in - let ctag = tag ~key ~hkey ~ctr ?adata cdata in - cdata, ctag + let authenticate_encrypt_into ~key ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len = + check_offset ~tag:"GCM" ~buf:"src" ~off:src_off ~len (String.length src); + check_offset ~tag:"GCM" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst); + check_offset ~tag:"GCM" ~buf:"dst tag" ~off:tag_off ~len:tag_size (Bytes.length dst); + unsafe_authenticate_encrypt_into ~key ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len let authenticate_encrypt ~key ~nonce ?adata data = - let cdata, ctag = authenticate_encrypt_tag ~key ~nonce ?adata data in - cdata ^ ctag + let l = String.length data in + let dst = Bytes.create (l + tag_size) in + unsafe_authenticate_encrypt_into ~key ~nonce ?adata data ~src_off:0 dst ~dst_off:0 ~tag_off:l l; + Bytes.unsafe_to_string dst + + let authenticate_encrypt_tag ~key ~nonce ?adata data = + let r = authenticate_encrypt ~key ~nonce ?adata data in + String.sub r 0 (String.length data), + String.sub r (String.length data) tag_size + + let unsafe_authenticate_decrypt_into ~key:{ key; hkey } ~nonce ?adata src ~src_off ~tag_off dst ~dst_off len = + let ctr = counter ~hkey nonce in + CTR.(unsafe_encrypt_into ~key ~ctr:(add_ctr ctr 1L) src ~src_off dst ~dst_off len); + let ctag = Bytes.create tag_size in + unsafe_tag_into ~key ~hkey ~ctr ?adata src ~off:src_off ~len ctag ~tag_off:0; + Eqaf.equal (String.sub src tag_off tag_size) (Bytes.unsafe_to_string ctag) - let authenticate_decrypt_tag ~key:{ key; hkey } ~nonce ?adata ~tag:tag_data cipher = - let ctr = counter ~hkey nonce in - let data = CTR.(encrypt ~key ~ctr:(add_ctr ctr 1L) cipher) in - let ctag = tag ~key ~hkey ~ctr ?adata cipher in - if Eqaf.equal tag_data ctag then Some data else None + let authenticate_decrypt_into ~key ~nonce ?adata src ~src_off ~tag_off dst ~dst_off len = + check_offset ~tag:"GCM" ~buf:"src" ~off:src_off ~len (String.length src); + check_offset ~tag:"GCM" ~buf:"src tag" ~off:tag_off ~len:tag_size (String.length src); + check_offset ~tag:"GCM" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst); + unsafe_authenticate_decrypt_into ~key ~nonce ?adata src ~src_off ~tag_off dst ~dst_off len let authenticate_decrypt ~key ~nonce ?adata cdata = if String.length cdata < tag_size then None else - let cipher, tag = - String.sub cdata 0 (String.length cdata - tag_size), - String.sub cdata (String.length cdata - tag_size) tag_size - in - authenticate_decrypt_tag ~key ~nonce ?adata ~tag cipher + let l = String.length cdata - tag_size in + let data = Bytes.create l in + if unsafe_authenticate_decrypt_into ~key ~nonce ?adata cdata ~src_off:0 ~tag_off:l data ~dst_off:0 l then + Some (Bytes.unsafe_to_string data) + else + None + + let authenticate_decrypt_tag ~key ~nonce ?adata ~tag:tag_data cipher = + let cdata = cipher ^ tag_data in + authenticate_decrypt ~key ~nonce ?adata cdata end module CCM16_of (C : Block.Core) : Block.CCM16 = struct - let _ = assert (C.block = 16) + assert (C.block = 16) - let tag_size = 16 + let tag_size = C.block type key = C.ekey @@ -330,29 +466,58 @@ module Modes = struct let (key_sizes, block_size) = C.(key, block) let cipher ~key src ~src_off dst ~dst_off = - if String.length src - src_off < block_size || Bytes.length dst - dst_off < block_size then - invalid_arg "src len %u, dst len %u" (String.length src - src_off) (Bytes.length dst - dst_off); C.encrypt ~key ~blocks:1 src src_off dst dst_off - let authenticate_encrypt_tag ~key ~nonce ?(adata = "") cs = - Ccm.generation_encryption ~cipher ~key ~nonce ~maclen:tag_size ~adata cs + let unsafe_authenticate_encrypt_into ~key ~nonce ?(adata = "") src ~src_off dst ~dst_off ~tag_off len = + Ccm.unsafe_generation_encryption_into ~cipher ~key ~nonce ~adata + src ~src_off dst ~dst_off ~tag_off len + + let valid_nonce nonce = + let nsize = String.length nonce in + if nsize < 7 || nsize > 13 then + invalid_arg "CCM: nonce length not between 7 and 13: %u" nsize + + let authenticate_encrypt_into ~key ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len = + check_offset ~tag:"CCM" ~buf:"src" ~off:src_off ~len (String.length src); + check_offset ~tag:"CCM" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst); + check_offset ~tag:"CCM" ~buf:"dst tag" ~off:tag_off ~len:tag_size (Bytes.length dst); + valid_nonce nonce; + unsafe_authenticate_encrypt_into ~key ~nonce ?adata src ~src_off dst ~dst_off ~tag_off len let authenticate_encrypt ~key ~nonce ?adata cs = - let cdata, ctag = authenticate_encrypt_tag ~key ~nonce ?adata cs in - cdata ^ ctag + valid_nonce nonce; + let l = String.length cs in + let dst = Bytes.create (l + tag_size) in + unsafe_authenticate_encrypt_into ~key ~nonce ?adata cs ~src_off:0 dst ~dst_off:0 ~tag_off:l l; + Bytes.unsafe_to_string dst + + let authenticate_encrypt_tag ~key ~nonce ?adata cs = + let res = authenticate_encrypt ~key ~nonce ?adata cs in + String.sub res 0 (String.length cs), String.sub res (String.length cs) tag_size + + let unsafe_authenticate_decrypt_into ~key ~nonce ?(adata = "") src ~src_off ~tag_off dst ~dst_off len = + Ccm.unsafe_decryption_verification_into ~cipher ~key ~nonce ~adata src ~src_off ~tag_off dst ~dst_off len - let authenticate_decrypt_tag ~key ~nonce ?(adata = "") ~tag cs = - Ccm.decryption_verification ~cipher ~key ~nonce ~maclen:tag_size ~adata ~tag cs + let authenticate_decrypt_into ~key ~nonce ?adata src ~src_off ~tag_off dst ~dst_off len = + check_offset ~tag:"CCM" ~buf:"src" ~off:src_off ~len (String.length src); + check_offset ~tag:"CCM" ~buf:"src tag" ~off:tag_off ~len:tag_size (String.length src); + check_offset ~tag:"CCM" ~buf:"dst" ~off:dst_off ~len (Bytes.length dst); + valid_nonce nonce; + unsafe_authenticate_decrypt_into ~key ~nonce ?adata src ~src_off ~tag_off dst ~dst_off len let authenticate_decrypt ~key ~nonce ?adata data = if String.length data < tag_size then None else - let data, tag = - String.sub data 0 (String.length data - tag_size), - String.sub data (String.length data - tag_size) tag_size - in - authenticate_decrypt_tag ~key ~nonce ?adata ~tag data + let dlen = String.length data - tag_size in + let dst = Bytes.create dlen in + if authenticate_decrypt_into ~key ~nonce ?adata data ~src_off:0 ~tag_off:dlen dst ~dst_off:0 dlen then + Some (Bytes.unsafe_to_string dst) + else + None + + let authenticate_decrypt_tag ~key ~nonce ?adata ~tag cs = + authenticate_decrypt ~key ~nonce ?adata (cs ^ tag) end end diff --git a/src/cipher_stream.ml b/src/cipher_stream.ml index 67ee0a63..69bbae3f 100644 --- a/src/cipher_stream.ml +++ b/src/cipher_stream.ml @@ -26,7 +26,7 @@ module ARC4 = struct let j = (j + si + x) land 0xff in let sj = s.(j) in s.(i) <- sj ; s.(j) <- si ; - loop j (succ i) + (loop [@tailcall]) j (succ i) in ( loop 0 0 ; (0, 0, s) ) @@ -44,7 +44,7 @@ module ARC4 = struct s.(i) <- sj ; s.(j) <- si ; let k = s.((si + sj) land 0xff) in Bytes.set_uint8 res n (k lxor String.get_uint8 buf n); - mix i j (succ n) + (mix [@tailcall]) i j (succ n) in let key' = mix i j 0 in { key = key' ; message = Bytes.unsafe_to_string res } diff --git a/src/mirage_crypto.mli b/src/mirage_crypto.mli index d33b8420..d85797e6 100644 --- a/src/mirage_crypto.mli +++ b/src/mirage_crypto.mli @@ -74,8 +74,13 @@ module Poly1305 : sig (** [maci ~key iter] is the all-in-one mac computation: [get (feedi (empty ~key) iter)]. *) - val macl : key:string -> string list -> string - (** [macl ~key datas] computes the [mac] of [datas]. *) + val mac_into : key:string -> (string * int * int) list -> bytes -> dst_off:int -> unit + (** [mac_into ~key datas dst dst_off] computes the [mac] of [datas]. *) + + (**/**) + val unsafe_mac_into : key:string -> (string * int * int) list -> bytes -> dst_off:int -> unit + (** [unsafe_mac_into ~key datas dst dst_off] is {!mac_into} without bounds checks. *) + (**/**) end (** {1 Symmetric-key cryptography} *) @@ -141,6 +146,65 @@ module type AEAD = sig returned. @raise Invalid_argument if [nonce] is not of the right size. *) + + (** {1 Authenticated encryption and decryption into existing buffers} *) + + val authenticate_encrypt_into : key:key -> nonce:string -> + ?adata:string -> string -> src_off:int -> bytes -> dst_off:int -> + tag_off:int -> int -> unit + (** [authenticate_encrypt_into ~key ~nonce ~adata msg ~src_off dst ~dst_off ~tag_off len] + encrypts [len] bytes of [msg] starting at [src_off] with [key] and [nonce]. The output + is put into [dst] at [dst_off], the tag into [dst] at [tag_off]. + + @raise Invalid_argument if [nonce] is not of the right size. + @raise Invalid_argument if [String.length msg - src_off < len]. + @raise Invalid_argument if [Bytes.length dst - dst_off < len]. + @raise Invalid_argument if [Bytes.length dst - tag_off < tag_size]. + *) + + val authenticate_decrypt_into : key:key -> nonce:string -> + ?adata:string -> string -> src_off:int -> tag_off:int -> bytes -> + dst_off:int -> int -> bool + (** [authenticate_decrypt_into ~key ~nonce ~adata msg ~src_off ~tag_off dst ~dst_off len] + computes the authentication tag using [key], [nonce], and [adata], and + decrypts the [len] bytes encrypted data from [msg] starting at [src_off] into [dst] + starting at [dst_off]. If the authentication tags match, [true] is + returned, and the decrypted data is in [dst]. + + @raise Invalid_argument if [nonce] is not of the right size. + @raise Invalid_argument if [String.length msg - src_off < len]. + @raise Invalid_argument if [Bytes.length dst - dst_off < len]. + @raise Invalid_argument if [String.length msg - tag_off < tag_size]. *) + + (**/**) + val unsafe_authenticate_encrypt_into : key:key -> nonce:string -> + ?adata:string -> string -> src_off:int -> bytes -> dst_off:int -> + tag_off:int -> int -> unit + (** [unsafe_authenticate_encrypt_into] is {!authenticate_encrypt_into}, but + without bounds checks. + + @raise Invalid_argument if [nonce] is not of the right size. + + This may cause memory issues if an invariant is violated: + {ul + {- [String.length msg - src_off >= len].} + {- [Bytes.length dst - dst_off >= len].} + {- [Bytes.length dst - tag_off >= tag_size].}} *) + + val unsafe_authenticate_decrypt_into : key:key -> nonce:string -> + ?adata:string -> string -> src_off:int -> tag_off:int -> bytes -> + dst_off:int -> int -> bool + (** [unsafe_authenticate_decrypt_into] is {!authenticate_decrypt_into}, but + without bounds checks. + + @raise Invalid_argument if [nonce] is not of the right size. + + This may cause memory issues if an invariant is violated: + {ul + {- [String.length msg - src_off >= len].} + {- [Bytes.length dst - dst_off >= len].} + {- [String.length msg - tag_off >= tag_size].}} *) + (**/**) end (** Block ciphers. @@ -157,12 +221,68 @@ module Block : sig module type ECB = sig type key + val of_secret : string -> key + (** Construct the encryption key corresponding to [secret]. + + @raise Invalid_argument if the length of [secret] is not in + {{!key_sizes}[key_sizes]}. *) val key_sizes : int array + (** Key sizes allowed with this cipher. *) + val block_size : int + (** The size of a single block. *) + val encrypt : key:key -> string -> string + (** [encrypt ~key src] encrypts [src] into a freshly allocated buffer of the + same size using [key]. + + @raise Invalid_argument if the length of [src] is not a multiple of + {!block_size}. *) + val decrypt : key:key -> string -> string + (** [decrypt ~key src] decrypts [src] into a freshly allocated buffer of the + same size using [key]. + + @raise Invalid_argument if the length of [src] is not a multiple of + {!block_size}. *) + + val encrypt_into : key:key -> string -> src_off:int -> bytes -> dst_off:int -> int -> unit + (** [encrypt_into ~key src ~src_off dst dst_off len] encrypts [len] octets + from [src] starting at [src_off] into [dst] starting at [dst_off]. + + @raise Invalid_argument if [len] is not a multiple of {!block_size}. + @raise Invalid_argument if [src_off < 0 || String.length src - src_off < len]. + @raise Invalid_argument if [dst_off < 0 || Bytes.length dst - dst_off < len]. *) + + val decrypt_into : key:key -> string -> src_off:int -> bytes -> dst_off:int -> int -> unit + (** [decrypt_into ~key src ~src_off dst dst_off len] decrypts [len] octets + from [src] starting at [src_off] into [dst] starting at [dst_off]. + + @raise Invalid_argument if [len] is not a multiple of {!block_size}. + @raise Invalid_argument if [src_off < 0 || String.length src - src_off < len]. + @raise Invalid_argument if [dst_off < 0 || Bytes.length dst - dst_off < len]. *) + + (**/**) + val unsafe_encrypt_into : key:key -> string -> src_off:int -> bytes -> dst_off:int -> int -> unit + (** [unsafe_encrypt_into] is {!encrypt_into}, but without bounds checks. + + This may cause memory issues if an invariant is violated: + {ul + {- [len] must be a multiple of {!block_size},} + {- [src_off >= 0 && String.length src - src_off >= len],} + {- [dst_off >= 0 && Bytes.length dst - dst_off >= len].}} *) + + val unsafe_decrypt_into : key:key -> string -> src_off:int -> bytes -> dst_off:int -> int -> unit + (** [unsafe_decrypt_into] is {!decrypt_into}, but without bounds checks. + + This may cause memory issues if an invariant is violated: + {ul + {- [len] must be a multiple of {!block_size},} + {- [src_off >= 0 && String.length src - src_off >= len],} + {- [dst_off >= 0 && Bytes.length dst - dst_off >= len].}} *) + (**/**) end (** {e Cipher-block chaining} mode. *) @@ -195,22 +315,79 @@ module Block : sig @raise Invalid_argument if [iv] is not [block_size], or [msg] is not [k * block_size] long. *) - val next_iv : iv:string -> string -> string - (** [next_iv ~iv ciphertext] is the first [iv] {e following} the + val next_iv : ?off:int -> string -> iv:string -> string + (** [next_iv ~iv ciphertext ~off] is the first [iv] {e following} the encryption that used [iv] to produce [ciphertext]. For protocols which perform inter-message chaining, this is the [iv] for the next message. - It is either [iv], when [len ciphertext = 0], or the last block of - [ciphertext]. Note that + It is either [iv], when [String.length ciphertext - off = 0], or the + last block of [ciphertext]. Note that {[encrypt ~iv msg1 || encrypt ~iv:(next_iv ~iv (encrypt ~iv msg1)) msg2 == encrypt ~iv (msg1 || msg2)]} - @raise Invalid_argument if the length of [iv] is not [block_size], or - the length of [ciphertext] is not [k * block_size] for some [k]. *) - end + @raise Invalid_argument if the length of [iv] is not [block_size]. + @raise Invalid_argument if the length of [ciphertext] is not a multiple + of [block_size]. *) + + val encrypt_into : key:key -> iv:string -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + (** [encrypt_into ~key ~iv src ~src_off dst dst_off len] encrypts [len] + octets from [src] starting at [src_off] into [dst] starting at [dst_off]. + + @raise Invalid_argument if the length of [iv] is not {!block_size}. + @raise Invalid_argument if [len] is not a multiple of {!block_size}. + @raise Invalid_argument if [src_off < 0 || String.length src - src_off < len]. + @raise Invalid_argument if [dst_off < 0 || Bytes.length dst - dst_off < len]. *) + + val decrypt_into : key:key -> iv:string -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + (** [decrypt_into ~key ~iv src ~src_off dst dst_off len] decrypts [len] + octets from [src] starting at [src_off] into [dst] starting at [dst_off]. + + @raise Invalid_argument if the length of [iv] is not {!block_size}. + @raise Invalid_argument if [len] is not a multiple of {!block_size}. + @raise Invalid_argument if [src_off < 0 || String.length src - src_off < len]. + @raise Invalid_argument if [dst_off < 0 || Bytes.length dst - dst_off < len]. *) + + (**/**) + val unsafe_encrypt_into : key:key -> iv:string -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + (** [unsafe_encrypt_into] is {!encrypt_into}, but without bounds checks. + + This may casue memory issues if an invariant is violated: + {ul + {- the length of [iv] must be {!block_size},} + {- [len] must be a multiple of {!block_size},} + {- [src_off >= 0 && String.length src - src_off >= len],} + {- [dst_off >= 0 && Bytes.length dst - dst_off >= len].}} *) + + val unsafe_decrypt_into : key:key -> iv:string -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + (** [unsafe_decrypt_into] is {!decrypt_into}, but without bounds checks. + + This may casue memory issues if an invariant is violated: + {ul + {- the length of [iv] must be {!block_size},} + {- [len] must be a multiple of {!block_size},} + {- [src_off >= 0 && String.length src - src_off >= len],} + {- [dst_off >= 0 && Bytes.length dst - dst_off >= len].}} *) + + val unsafe_encrypt_into_inplace : key:key -> iv:string -> + bytes -> dst_off:int -> int -> unit + (** [unsafe_encrypt_into_inplace] is {!unsafe_encrypt_into}, but assumes + that [dst] already contains the mesage to be encrypted. + + This may casue memory issues if an invariant is violated: + {ul + {- the length of [iv] must be {!block_size},} + {- [len] must be a multiple of {!block_size},} + {- [src_off >= 0 && String.length src - src_off >= len],} + {- [dst_off >= 0 && Bytes.length dst - dst_off >= len].}} *) + (**/**) +end (** {e Counter} mode. *) module type CTR = sig @@ -231,6 +408,27 @@ module Block : sig type ctr + val add_ctr : ctr -> int64 -> ctr + (** [add_ctr ctr n] adds [n] to [ctr]. *) + + val next_ctr : ?off:int -> string -> ctr:ctr -> ctr + (** [next_ctr ~off msg ~ctr] is the state of the counter after encrypting or + decrypting [msg] at offset [off] with the counter [ctr]. + + For protocols which perform inter-message chaining, this is the + counter for the next message. + + It is computed as [C.add ctr (ceil (len msg / block_size))]. Note that + if [len msg1 = k * block_size], + +{[encrypt ~ctr msg1 || encrypt ~ctr:(next_ctr ~ctr msg1) msg2 + == encrypt ~ctr (msg1 || msg2)]} + + *) + + val ctr_of_octets : string -> ctr + (** [ctr_of_octets buf] converts the value of [buf] into a counter. *) + val stream : key:key -> ctr:ctr -> int -> string (** [stream ~key ~ctr n] is the raw keystream. @@ -249,31 +447,51 @@ module Block : sig val encrypt : key:key -> ctr:ctr -> string -> string (** [encrypt ~key ~ctr msg] is - [stream ~key ~ctr ~off (len msg) lxor msg]. *) + [stream ~key ~ctr (len msg) lxor msg]. *) val decrypt : key:key -> ctr:ctr -> string -> string (** [decrypt] is [encrypt]. *) - val add_ctr : ctr -> int64 -> ctr - (** [add_ctr ctr n] adds [n] to [ctr]. *) + val stream_into : key:key -> ctr:ctr -> bytes -> off:int -> int -> unit + (** [stream_into ~key ~ctr dst ~off len] is the raw key stream put into + [dst] starting at [off]. - val next_ctr : ctr:ctr -> string -> ctr - (** [next_ctr ~ctr msg] is the state of the counter after encrypting or - decrypting [msg] with the counter [ctr]. + @raise Invalid_argument if [Bytes.length dst - off < len]. *) - For protocols which perform inter-message chaining, this is the - counter for the next message. + val encrypt_into : key:key -> ctr:ctr -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + (** [encrypt_into ~key ~ctr src ~src_off dst ~dst_off len] produces the + key stream into [dst] at [dst_off], and then xors it with [src] at + [src_off]. - It is computed as [C.add ctr (ceil (len msg / block_size))]. Note that - if [len msg1 = k * block_size], + @raise Invalid_argument if [dst_off < 0 || Bytes.length dst - dst_off < len]. + @raise Invalid_argument if [src_off < 0 || String.length src - src_off < len]. *) -{[encrypt ~ctr msg1 || encrypt ~ctr:(next_ctr ~ctr msg1) msg2 - == encrypt ~ctr (msg1 || msg2)]} + val decrypt_into : key:key -> ctr:ctr -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + (** [decrypt_into] is {!encrypt_into}. *) - *) + (**/**) + val unsafe_stream_into : key:key -> ctr:ctr -> bytes -> off:int -> int -> unit + (** [unsafe_stream_into] is {!stream_into}, but without bounds checks. - val ctr_of_octets : string -> ctr - (** [ctr_of_octets buf] converts the value of [buf] into a counter. *) + This may cause memory issues if the invariant is violated: + {ul + {- [off >= 0 && Bytes.length buf - off >= len].}} *) + + val unsafe_encrypt_into : key:key -> ctr:ctr -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + (** [unsafe_encrypt_into] is {!encrypt_into}, but without bounds checks. + + This may cause memory issues if an invariant is violated: + {ul + {- [dst_off >= 0 && Bytes.length dst - dst_off >= len],} + {- [src_off >= 0 && String.length src - src_off >= len].}} *) + + val unsafe_decrypt_into : key:key -> ctr:ctr -> string -> src_off:int -> + bytes -> dst_off:int -> int -> unit + (** [unsafe_decrypt_into] is {!unsafe_encrypt_into}. *) + (**/**) end (** {e Galois/Counter Mode}. *) diff --git a/src/native.ml b/src/native.ml index 55437a70..f6d59da3 100644 --- a/src/native.ml +++ b/src/native.ml @@ -20,8 +20,8 @@ end module Poly1305 = struct external init : bytes -> string -> unit = "mc_poly1305_init" [@@noalloc] - external update : bytes -> string -> int -> unit = "mc_poly1305_update" [@@noalloc] - external finalize : bytes -> bytes -> unit = "mc_poly1305_finalize" [@@noalloc] + external update : bytes -> string -> int -> int -> unit = "mc_poly1305_update" [@@noalloc] + external finalize : bytes -> bytes -> int -> unit = "mc_poly1305_finalize" [@@noalloc] external ctx_size : unit -> int = "mc_poly1305_ctx_size" [@@noalloc] external mac_size : unit -> int = "mc_poly1305_mac_size" [@@noalloc] end @@ -29,7 +29,7 @@ end module GHASH = struct external keysize : unit -> int = "mc_ghash_key_size" [@@noalloc] external keyinit : string -> bytes -> unit = "mc_ghash_init_key" [@@noalloc] - external ghash : string -> bytes -> string -> int -> unit = "mc_ghash" [@@noalloc] + external ghash : string -> bytes -> string -> int -> int -> unit = "mc_ghash" [@@noalloc] external mode : unit -> int = "mc_ghash_mode" [@@noalloc] end @@ -37,9 +37,9 @@ end * Unsolved: bounds-checked XORs are slowing things down considerably... *) external xor_into_bytes : string -> int -> bytes -> int -> int -> unit = "mc_xor_into_bytes" [@@noalloc] -external count8be : bytes -> bytes -> blocks:int -> unit = "mc_count_8_be" [@@noalloc] -external count16be : bytes -> bytes -> blocks:int -> unit = "mc_count_16_be" [@@noalloc] -external count16be4 : bytes -> bytes -> blocks:int -> unit = "mc_count_16_be_4" [@@noalloc] +external count8be : ctr:bytes -> bytes -> off:int -> blocks:int -> unit = "mc_count_8_be" [@@noalloc] +external count16be : ctr:bytes -> bytes -> off:int -> blocks:int -> unit = "mc_count_16_be" [@@noalloc] +external count16be4 : ctr:bytes -> bytes -> off:int -> blocks:int -> unit = "mc_count_16_be_4" [@@noalloc] external misc_mode : unit -> int = "mc_misc_mode" [@@noalloc] diff --git a/src/native/ghash_ctmul.c b/src/native/ghash_ctmul.c index 7788fd05..bb1a2b05 100644 --- a/src/native/ghash_ctmul.c +++ b/src/native/ghash_ctmul.c @@ -290,8 +290,8 @@ CAMLprim value mc_ghash_init_key_generic (value key, value m) { return Val_unit; } -CAMLprim value mc_ghash_generic (value m, value hash, value src, value len) { - br_ghash_ctmul(Bp_val(hash), Bp_val(m), _st_uint8(src), Int_val(len)); +CAMLprim value mc_ghash_generic (value m, value hash, value src, value off, value len) { + br_ghash_ctmul(Bp_val(hash), Bp_val(m), _st_uint8_off(src, off), Int_val(len)); return Val_unit; } diff --git a/src/native/ghash_generic.c b/src/native/ghash_generic.c index 2cc49532..68768c85 100644 --- a/src/native/ghash_generic.c +++ b/src/native/ghash_generic.c @@ -101,9 +101,9 @@ CAMLprim value mc_ghash_init_key_generic (value key, value m) { } CAMLprim value -mc_ghash_generic (value m, value hash, value src, value len) { +mc_ghash_generic (value m, value hash, value src, value off, value len) { __ghash ((__uint128_t *) Bp_val (m), (uint64_t *) Bp_val (hash), - _st_uint8 (src), Int_val (len) ); + _st_uint8_off (src, off), Int_val (len) ); return Val_unit; } diff --git a/src/native/ghash_pclmul.c b/src/native/ghash_pclmul.c index 58ca02ea..7c7ea95b 100644 --- a/src/native/ghash_pclmul.c +++ b/src/native/ghash_pclmul.c @@ -204,11 +204,11 @@ CAMLprim value mc_ghash_init_key (value key, value m) { } CAMLprim value -mc_ghash (value k, value hash, value src, value len) { +mc_ghash (value k, value hash, value src, value off, value len) { _mc_switch_accel(pclmul, - mc_ghash_generic(k, hash, src, len), + mc_ghash_generic(k, hash, src, off, len), __ghash ( (__m128i *) Bp_val (k), (__m128i *) Bp_val (hash), - (__m128i *) _st_uint8 (src), Int_val (len) )) + (__m128i *) _st_uint8_off (src, off), Int_val (len) )) return Val_unit; } diff --git a/src/native/mirage_crypto.h b/src/native/mirage_crypto.h index 0542db2f..5496a965 100644 --- a/src/native/mirage_crypto.h +++ b/src/native/mirage_crypto.h @@ -105,7 +105,7 @@ CAMLprim value mc_ghash_key_size_generic (__unit ()); CAMLprim value mc_ghash_init_key_generic (value key, value m); CAMLprim value -mc_ghash_generic (value m, value hash, value src, value len); +mc_ghash_generic (value m, value hash, value src, value off, value len); CAMLprim value mc_xor_into_generic (value b1, value off1, value b2, value off2, value n); @@ -114,6 +114,6 @@ CAMLprim value mc_xor_into_bytes_generic (value b1, value off1, value b2, value off2, value n); CAMLprim value -mc_count_16_be_4_generic (value ctr, value dst, value blocks); +mc_count_16_be_4_generic (value ctr, value dst, value off, value blocks); #endif /* H__MIRAGE_CRYPTO */ diff --git a/src/native/misc.c b/src/native/misc.c index ba9590f8..dea76e18 100644 --- a/src/native/misc.c +++ b/src/native/misc.c @@ -60,9 +60,9 @@ mc_xor_into_bytes_generic (value b1, value off1, value b2, value off2, value n) } #define __export_counter(name, f) \ - CAMLprim value name (value ctr, value dst, value blocks) { \ + CAMLprim value name (value ctr, value dst, value off, value blocks) { \ f ( (uint64_t*) Bp_val (ctr), \ - (uint64_t*) _bp_uint8 (dst), Long_val (blocks) ); \ + (uint64_t*) _bp_uint8_off (dst, off), Long_val (blocks) ); \ return Val_unit; \ } diff --git a/src/native/misc_sse.c b/src/native/misc_sse.c index 1f2265da..c155d468 100644 --- a/src/native/misc_sse.c +++ b/src/native/misc_sse.c @@ -48,11 +48,11 @@ mc_xor_into_bytes (value b1, value off1, value b2, value off2, value n) { } #define __export_counter(name, f) \ - CAMLprim value name (value ctr, value dst, value blocks) { \ - _mc_switch_accel(ssse3, \ - name##_generic (ctr, dst, blocks), \ + CAMLprim value name (value ctr, value dst, value off, value blocks) { \ + _mc_switch_accel(ssse3, \ + name##_generic (ctr, dst, off, blocks), \ f ( (uint64_t*) Bp_val (ctr), \ - (uint64_t*) _bp_uint8 (dst), Long_val (blocks) )) \ + (uint64_t*) _bp_uint8_off (dst, off), Long_val (blocks) )) \ return Val_unit; \ } diff --git a/src/native/poly1305-donna.c b/src/native/poly1305-donna.c index 567649ab..46991dc2 100644 --- a/src/native/poly1305-donna.c +++ b/src/native/poly1305-donna.c @@ -59,13 +59,13 @@ CAMLprim value mc_poly1305_init (value ctx, value key) { return Val_unit; } -CAMLprim value mc_poly1305_update (value ctx, value buf, value len) { - poly1305_update ((poly1305_context *) Bytes_val(ctx), _st_uint8(buf), Int_val(len)); +CAMLprim value mc_poly1305_update (value ctx, value buf, value off, value len) { + poly1305_update ((poly1305_context *) Bytes_val(ctx), _st_uint8_off(buf, off), Int_val(len)); return Val_unit; } -CAMLprim value mc_poly1305_finalize (value ctx, value mac) { - poly1305_finish ((poly1305_context *) Bytes_val(ctx), Bytes_val(mac)); +CAMLprim value mc_poly1305_finalize (value ctx, value mac, value off) { + poly1305_finish ((poly1305_context *) Bytes_val(ctx), _bp_uint8_off(mac, off)); return Val_unit; } diff --git a/src/poly1305.ml b/src/poly1305.ml index eb571b82..0a2cb72d 100644 --- a/src/poly1305.ml +++ b/src/poly1305.ml @@ -11,7 +11,8 @@ module type S = sig val mac : key:string -> string -> string val maci : key:string -> string iter -> string - val macl : key:string -> string list -> string + val mac_into : key:string -> (string * int * int) list -> bytes -> dst_off:int -> unit + val unsafe_mac_into : key:string -> (string * int * int) list -> bytes -> dst_off:int -> unit end module It : S = struct @@ -31,7 +32,7 @@ module It : S = struct ctx let update ctx data = - P.update ctx data (String.length data) + P.update ctx data 0 (String.length data) let feed ctx cs = let t = dup ctx in @@ -45,7 +46,7 @@ module It : S = struct let final ctx = let res = Bytes.create mac_size in - P.finalize ctx res; + P.finalize ctx res 0; Bytes.unsafe_to_string res let get ctx = final (dup ctx) @@ -54,8 +55,25 @@ module It : S = struct let maci ~key iter = feedi (empty ~key) iter |> final - let macl ~key datas = + let unsafe_mac_into ~key datas dst ~dst_off = let ctx = empty ~key in - List.iter (update ctx) datas; - final ctx + List.iter (fun (d, off, len) -> P.update ctx d off len) datas; + P.finalize ctx dst dst_off + + let mac_into ~key datas dst ~dst_off = + if Bytes.length dst - dst_off < mac_size then + Uncommon.invalid_arg "Poly1305: dst length %u - off %u < len %u" + (Bytes.length dst) dst_off mac_size; + if dst_off < 0 then + Uncommon.invalid_arg "Poly1305: dst_off %u < 0" dst_off; + let ctx = empty ~key in + List.iter (fun (d, off, len) -> + if off < 0 then + Uncommon.invalid_arg "Poly1305: d off %u < 0" off; + if String.length d - off < len then + Uncommon.invalid_arg "Poly1305: d length %u - off %u < len %u" + (String.length d) off len; + P.update ctx d off len) + datas; + P.finalize ctx dst dst_off end