Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add async constructors for GitOid. #110

Merged
merged 1 commit into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions gitoid/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ paste = "1.0.14"
sha1 = { version = "0.10.6", default-features = false, features = ["std"] }
sha1collisiondetection = "0.3.3"
sha2 = { version = "0.10.8", default-features = false }
tokio = { version = "1.36.0", features = ["io-util"] }
url = "2.4.1"
131 changes: 121 additions & 10 deletions gitoid/src/gitoid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ use std::io::BufReader;
use std::io::Read;
use std::io::Seek;
use std::io::SeekFrom;
use tokio::io::AsyncBufReadExt as _;
use tokio::io::AsyncRead;
use tokio::io::AsyncSeek;
use tokio::io::AsyncSeekExt as _;
use tokio::io::BufReader as AsyncBufReader;
use url::Url;

/// A struct that computes [gitoids][g] based on the selected algorithm
Expand All @@ -49,6 +54,14 @@ where
H: HashAlgorithm,
O: ObjectType,
{
/// Helper function to construct GitOid from raw hash.
fn from_hash(arr: GenericArray<u8, H::OutputSize>) -> GitOid<H, O> {
GitOid {
_phantom: PhantomData,
value: H::array_from_generic(arr),
}
}

/// Create a new `GitOid` based on a slice of bytes.
pub fn from_bytes<B: AsRef<[u8]>>(content: B) -> GitOid<H, O> {
fn inner<H, O>(content: &[u8]) -> GitOid<H, O>
Expand Down Expand Up @@ -80,17 +93,33 @@ where
/// Create a `GitOid` from a reader.
pub fn from_reader<R: Read + Seek>(mut reader: R) -> Result<GitOid<H, O>> {
let expected_length = stream_len(&mut reader)? as usize;
GitOid::from_reader_with_expected_length(reader, expected_length)
GitOid::from_reader_with_length(reader, expected_length)
}

/// Generate a `GitOid` from a reader, providing an expected length in bytes.
pub fn from_reader_with_expected_length<R: Read>(
pub fn from_reader_with_length<R: Read>(
reader: R,
expected_length: usize,
) -> Result<GitOid<H, O>> {
gitoid_from_buffer(H::new(), reader, expected_length)
}

/// Generate a `GitOid` from an asynchronous reader.
pub async fn from_async_reader<R: AsyncRead + AsyncSeek + Unpin>(
mut reader: R,
) -> Result<GitOid<H, O>> {
let expected_length = async_stream_len(&mut reader).await? as usize;
GitOid::from_async_reader_with_length(reader, expected_length).await
}

/// Generate a `GitOid` from an asynchronous reader, providing an expected length in bytes.
pub async fn from_async_reader_with_length<R: AsyncRead + Unpin>(
reader: R,
expected_length: usize,
) -> Result<GitOid<H, O>> {
gitoid_from_async_buffer(H::new(), reader, expected_length).await
}

/// Construct a new `GitOid` from a `Url`.
pub fn from_url(url: Url) -> Result<GitOid<H, O>> {
GitOid::try_from(url)
Expand Down Expand Up @@ -275,10 +304,7 @@ where
.and_then(|_| self.validate_object_type())
.and_then(|_| self.validate_hash_algorithm())
.and_then(|_| self.parse_hash())
.map(|arr| GitOid {
_phantom: PhantomData,
value: H::array_from_generic(arr),
})
.map(GitOid::from_hash)
}

fn validate_url_scheme(&self) -> Result<()> {
Expand Down Expand Up @@ -395,10 +421,7 @@ where
});
}

Ok(GitOid {
_phantom: PhantomData,
value: H::array_from_generic(hash),
})
Ok(GitOid::from_hash(hash))
}

// Helper extension trait to give a convenient way to iterate over
Expand Down Expand Up @@ -455,6 +478,77 @@ where
Ok((hash, amount_read))
}

/// Async version of `gitoid_from_buffer`.
async fn gitoid_from_async_buffer<H, O, R>(
digester: H,
reader: R,
expected_read_length: usize,
) -> Result<GitOid<H, O>>
where
H: HashAlgorithm,
O: ObjectType,
R: AsyncRead + Unpin,
{
let expected_hash_length = <H as OutputSizeUser>::output_size();
let (hash, amount_read) =
hash_from_async_buffer::<H, O, R>(digester, reader, expected_read_length).await?;

if amount_read != expected_read_length {
return Err(Error::UnexpectedHashLength {
expected: expected_read_length,
observed: amount_read,
});
}

if hash.len() != expected_hash_length {
return Err(Error::UnexpectedHashLength {
expected: expected_hash_length,
observed: hash.len(),
});
}

Ok(GitOid::from_hash(hash))
}

/// Async version of `hash_from_buffer`.
async fn hash_from_async_buffer<H, O, R>(
mut digester: H,
reader: R,
expected_read_length: usize,
) -> Result<(GenericArray<u8, H::OutputSize>, usize)>
where
H: HashAlgorithm,
O: ObjectType,
R: AsyncRead + Unpin,
{
digester.update(format_bytes!(
b"{} {}\0",
O::NAME.as_bytes(),
expected_read_length
));

let mut reader = AsyncBufReader::new(reader);

let mut total_read = 0;

loop {
let buffer = reader.fill_buf().await?;
let amount_read = buffer.len();

if amount_read == 0 {
break;
}

digester.update(buffer);

reader.consume(amount_read);
total_read += amount_read;
}

let hash = digester.finalize();
Ok((hash, total_read))
}

// Adapted from the Rust standard library's unstable implementation
// of `Seek::stream_len`.
//
Expand Down Expand Up @@ -500,3 +594,20 @@ where

Ok(len)
}

/// An async equivalent of `stream_len`.
async fn async_stream_len<R>(mut stream: R) -> Result<u64>
where
R: AsyncSeek + Unpin,
{
let old_pos = stream.stream_position().await?;
let len = stream.seek(SeekFrom::End(0)).await?;

// Avoid seeking a third time when we were already at the end of the
// stream. The branch is usually way cheaper than a seek operation.
if old_pos != len {
stream.seek(SeekFrom::Start(old_pos)).await?;
}

Ok(len)
}