Skip to content

Commit

Permalink
refactor: type conformity
Browse files Browse the repository at this point in the history
  • Loading branch information
AshGw committed May 25, 2024
1 parent 1b1901b commit d90e626
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 56 deletions.
6 changes: 4 additions & 2 deletions src/constants.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
pub const BLOCK_SIZE: usize = 16;
pub const NONCE_SIZE: usize = 12;
pub const TAG_SIZE: usize = 16;
pub const NONCE_SIZE: usize = BLOCK_SIZE - 4;
pub const TAG_SIZE: usize = BLOCK_SIZE;
pub const C_SIZE: usize = TAG_SIZE; // C for common size, ion mean none else
pub const ZEROED_BLOCK: [u8; C_SIZE] = [0u8; C_SIZE];
42 changes: 21 additions & 21 deletions src/gcm.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::constants::{NONCE_SIZE, TAG_SIZE};
use crate::constants::{C_SIZE, NONCE_SIZE, ZEROED_BLOCK};
use crate::ctr::Aes256Ctr32;
use crate::error::Error;
use crate::types::{Bytes, Key, Nonce, Result};
use crate::types::{BlockBytes, Bytes, Key, Nonce, Result};
use aes::cipher::generic_array::GenericArray;
use aes::cipher::{BlockEncrypt, KeyInit};
use aes::Aes256;
Expand All @@ -11,17 +11,17 @@ use ghash::GHash;
#[derive(Clone)]
pub struct GcmGhash {
ghash: GHash,
ghash_padding: [u8; TAG_SIZE],
msg_buffer: [u8; TAG_SIZE],
ghash_padding: BlockBytes,
msg_buffer: BlockBytes,
msg_buffer_offset: usize,
ad_len: usize,
msg_len: usize,
}

impl GcmGhash {
fn new(
h: &[u8; TAG_SIZE],
ghash_padding: [u8; TAG_SIZE],
h: &BlockBytes,
ghash_padding: BlockBytes,
associated_data: &Bytes,
) -> Result<Self> {
let mut ghash = GHash::new(h.into());
Expand All @@ -31,28 +31,28 @@ impl GcmGhash {
Ok(Self {
ghash,
ghash_padding,
msg_buffer: [0u8; TAG_SIZE],
msg_buffer: ZEROED_BLOCK,
msg_buffer_offset: 0,
ad_len: associated_data.len(),
msg_len: 0,
})
}

pub fn update(&mut self, msg: &[u8]) {
pub fn update(&mut self, msg: &Bytes) {
if self.msg_buffer_offset > 0 {
let taking = std::cmp::min(
msg.len(),
TAG_SIZE - self.msg_buffer_offset,
C_SIZE - self.msg_buffer_offset,
);
self.msg_buffer[self.msg_buffer_offset
..self.msg_buffer_offset + taking]
.copy_from_slice(&msg[..taking]);
self.msg_buffer_offset += taking;
assert!(self.msg_buffer_offset <= TAG_SIZE);
assert!(self.msg_buffer_offset <= C_SIZE);

self.msg_len += taking;

if self.msg_buffer_offset == TAG_SIZE {
if self.msg_buffer_offset == C_SIZE {
self.ghash.update(std::slice::from_ref(
ghash::Block::from_slice(&self.msg_buffer),
));
Expand All @@ -66,13 +66,13 @@ impl GcmGhash {
self.msg_len += msg.len();

assert_eq!(self.msg_buffer_offset, 0);
let full_blocks = msg.len() / 16;
let leftover = msg.len() - 16 * full_blocks;
assert!(leftover < TAG_SIZE);
let full_blocks = msg.len() / C_SIZE;
let leftover = msg.len() - C_SIZE * full_blocks;
assert!(leftover < C_SIZE);
if full_blocks > 0 {
let blocks = unsafe {
std::slice::from_raw_parts(
msg[..16 * full_blocks].as_ptr().cast(),
msg[..C_SIZE * full_blocks].as_ptr().cast(),
full_blocks,
)
};
Expand All @@ -84,19 +84,19 @@ impl GcmGhash {
}

self.msg_buffer[0..leftover]
.copy_from_slice(&msg[full_blocks * 16..]);
.copy_from_slice(&msg[full_blocks * C_SIZE..]);
self.msg_buffer_offset = leftover;
assert!(self.msg_buffer_offset < TAG_SIZE);
assert!(self.msg_buffer_offset < C_SIZE);
}

pub fn finalize(mut self) -> [u8; TAG_SIZE] {
pub fn finalize(mut self) -> BlockBytes {
if self.msg_buffer_offset > 0 {
self.ghash.update_padded(
&self.msg_buffer[..self.msg_buffer_offset],
);
}

let mut final_block = [0u8; 16];
let mut final_block = ZEROED_BLOCK;
final_block[..8]
.copy_from_slice(&(8 * self.ad_len as u64).to_be_bytes());
final_block[8..].copy_from_slice(
Expand Down Expand Up @@ -127,12 +127,12 @@ pub fn setup(

let aes256: Aes256 = Aes256::new_from_slice(key)
.map_err(|_| Error::InvalidKeySize)?;
let mut h = [0u8; TAG_SIZE];
let mut h = ZEROED_BLOCK;
aes256.encrypt_block(GenericArray::from_mut_slice(&mut h));

let mut ctr = Aes256Ctr32::new(aes256, nonce, 1)?;

let mut ghash_padding = [0u8; 16];
let mut ghash_padding = ZEROED_BLOCK;
ctr.xor(&mut ghash_padding);

let ghash = GcmGhash::new(&h, ghash_padding, associated_data)?;
Expand Down
69 changes: 36 additions & 33 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ use gcm::{setup as setup_gcm, GcmGhash};
use subtle::ConstantTimeEq;
use types::{Bytes, Key, Nonce, Result};

pub struct Aes256GcmEncryption {
pub struct Aes256Gcm {
ctr: Aes256Ctr32,
ghash: GcmGhash,
}

impl Aes256GcmEncryption {
impl Aes256Gcm {
pub fn new(
key: &Key,
nonce: &Nonce,
Expand All @@ -26,43 +26,44 @@ impl Aes256GcmEncryption {
Ok(Self { ctr, ghash })
}

pub fn encrypt(&mut self, buf: &mut Bytes) {
self.ctr.xor(buf);
self.ghash.update(buf);
}

pub fn compute_tag(self) -> [u8; TAG_SIZE] {
pub fn finalize(self) -> [u8; TAG_SIZE] {
self.ghash.finalize()
}
}

pub struct Aes256GcmDecryption {
ctr: Aes256Ctr32,
ghash: GcmGhash,
pub trait Encrypt {
fn encrypt(&mut self, buf: &mut Bytes);
fn compute_tag(self) -> [u8; TAG_SIZE];
}

impl Aes256GcmDecryption {
pub fn new(
key: &[u8],
nonce: &[u8],
associated_data: &[u8],
) -> Result<Self> {
let (ctr, ghash) = setup_gcm(key, nonce, associated_data)?;
Ok(Self { ctr, ghash })
pub trait Decrypt {
fn decrypt(&mut self, buf: &mut Bytes);
fn verify_tag(self, tag: &Bytes) -> Result<()>;
}

impl Encrypt for Aes256Gcm {
fn encrypt(&mut self, buf: &mut Bytes) {
self.ctr.xor(buf);
self.ghash.update(buf);
}

fn compute_tag(self) -> [u8; TAG_SIZE] {
self.finalize()
}
}

pub fn decrypt(&mut self, buf: &mut [u8]) {
impl Decrypt for Aes256Gcm {
fn decrypt(&mut self, buf: &mut Bytes) {
self.ghash.update(buf);
self.ctr.xor(buf);
}

pub fn verify_tag(self, tag: &[u8]) -> Result<()> {
fn verify_tag(self, tag: &Bytes) -> Result<()> {
if tag.len() != TAG_SIZE {
return Err(Error::InvalidTag);
}

let computed_tag: [u8; 16] = self.ghash.finalize();

let computed_tag = self.finalize();
let tag_ok: subtle::Choice = tag.ct_eq(&computed_tag);

if !bool::from(tag_ok) {
Expand All @@ -84,18 +85,20 @@ mod tests {
let associated_data = b"associated_data";
let plaintext = b"plaintext";

let mut encryption =
Aes256GcmEncryption::new(&key, &nonce, associated_data)
.unwrap();
let mut gcm =
Aes256Gcm::new(&key, &nonce, associated_data).unwrap();

let mut ciphertext = plaintext.to_vec();
encryption.encrypt(&mut ciphertext);
let mut decryption =
Aes256GcmDecryption::new(&key, &nonce, associated_data)
.unwrap();
decryption.decrypt(&mut ciphertext);
gcm.encrypt(&mut ciphertext);

let tag = gcm.compute_tag();

let mut gcm_decrypt =
Aes256Gcm::new(&key, &nonce, associated_data).unwrap();
gcm_decrypt.decrypt(&mut ciphertext);

assert_eq!(&ciphertext, plaintext);
let tag = encryption.compute_tag();
assert!(decryption.verify_tag(&tag).is_ok());
assert!(gcm_decrypt.verify_tag(&tag).is_ok());
}

#[test]
Expand Down
2 changes: 2 additions & 0 deletions src/types.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use crate::constants::C_SIZE;
use crate::error::Error;
use anyhow::Result as AnyRes;

pub type Result<T> = AnyRes<T, Error>;

pub type Bytes = [u8];
pub type BlockBytes = [u8; C_SIZE];
pub type Nonce = Bytes;
pub type Key = Bytes;
pub type CTRInitializer = u32;

0 comments on commit d90e626

Please sign in to comment.