Skip to content

Commit

Permalink
Add missing null checks in ManageMemory impl
Browse files Browse the repository at this point in the history
Move some tests into documentation

Signed-off-by: Tin Švagelj <[email protected]>
  • Loading branch information
Caellian committed Mar 22, 2024
1 parent 11cc569 commit d964574
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 92 deletions.
3 changes: 2 additions & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ use crate::{range::ByteRange, reference::BorrowState};
/// allocator failure.
#[derive(Debug, Clone, Copy)]
pub enum MemoryError {
/// Tried allocating container capacity larger than `isize::MAX`
/// Tried allocating memory chunk larger than [`isize::MAX`] or what is
/// currently available.
TooLarge,
/// Allocation failure caused by either resource exhaustion or invalid
/// arguments being provided to an allocator.
Expand Down
11 changes: 3 additions & 8 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -463,16 +463,11 @@ impl<Impl: ImplDetails<A>, A: ManageMemory> ContiguousMemory<Impl, A> {
/// use contiguous_mem::ContiguousMemory;
///
/// let mut s: ContiguousMemory = ContiguousMemory::new();
///
///
/// assert!(s.try_grow_to(1024).is_ok());
///
/// let el_count: usize = 42;
/// let el_size: usize = 288230376151711744; // bad read?
///
/// let mut required_size: usize = 1024;
/// for i in 0..el_count {
/// required_size += el_size;
/// }
/// let required_size: usize = usize::MAX; // bad read?
/// // can't allocate all addressable memory
/// assert!(s.try_grow_to(required_size).is_err());
/// ```
pub fn try_grow_to(&mut self, new_capacity: usize) -> Result<Option<MemoryBase>, MemoryError> {
Expand Down
142 changes: 78 additions & 64 deletions src/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,18 @@ pub struct SegmentTracker {
}

impl SegmentTracker {
/// Constructs a new `SegmentTracker` of the provided `size`.
/// Constructs a new empty `SegmentTracker` of the provided `size`.
///
/// # Examples
/// ```
/// # use contiguous_mem::memory::SegmentTracker;
/// # use contiguous_mem::range::ByteRange;
/// let tracker = SegmentTracker::new(1024);
///
/// assert!(!tracker.is_full());
/// assert_eq!(tracker.size(), 1024);
/// assert_eq!(tracker.whole_range(), ByteRange(0, 1024));
/// ```
pub fn new(size: usize) -> Self {
SegmentTracker {
size,
Expand Down Expand Up @@ -220,6 +231,19 @@ impl SegmentTracker {
///
/// It returns a [`ByteRange`] of the memory region that was marked as used
/// if successful, otherwise `None`
///
/// # Examples
///
/// ```
/// # use contiguous_mem::range::ByteRange;
/// # use contiguous_mem::memory::{alloc::Layout, SegmentTracker};
/// let mut tracker = SegmentTracker::new(1024);
///
/// let layout = Layout::from_size_align(128, 8).unwrap();
/// let range = tracker.take_next(8, layout).unwrap();
///
/// assert_eq!(range, ByteRange(0, 128));
/// ```
#[inline]
pub fn take_next(&mut self, base_pos: usize, layout: impl HasLayout) -> Option<ByteRange> {
let mut location = self.peek_next(base_pos, layout)?;
Expand All @@ -235,6 +259,21 @@ impl SegmentTracker {
/// * the provided region falls outside of the memory tracked by the
/// `SegmentTracker`, or
/// * the provided region is in part or whole already marked as free.
///
/// # Examples
/// ```
/// # use contiguous_mem::range::ByteRange;
/// # use contiguous_mem::memory::{alloc::Layout, SegmentTracker};
/// let mut tracker = SegmentTracker::new(1024);
///
/// let range = tracker
/// .take_next(8, Layout::from_size_align(32, 8).unwrap())
/// .unwrap();
/// assert_eq!(range, ByteRange(0, 32));
///
/// tracker.release(range);
/// assert!(!tracker.is_full());
/// ```
pub fn release(&mut self, region: ByteRange) {
if region.is_empty() {
return;
Expand Down Expand Up @@ -283,41 +322,6 @@ impl core::fmt::Debug for SegmentTracker {
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn new_allocation_tracker() {
let tracker = SegmentTracker::new(1024);
assert_eq!(tracker.size(), 1024);
assert!(!tracker.is_full());
assert_eq!(tracker.whole_range(), ByteRange(0, 1024));
}

#[test]
fn take_and_release_allocation_tracker() {
let mut tracker = SegmentTracker::new(1024);

let range = tracker
.take_next(8, Layout::from_size_align(32, 8).unwrap())
.unwrap();
assert_eq!(range, ByteRange(0, 32));

tracker.release(range);
assert!(!tracker.is_full());
}

#[test]
fn take_next_allocation_tracker() {
let mut tracker = SegmentTracker::new(1024);

let layout = Layout::from_size_align(128, 8).unwrap();
let range = tracker.take_next(8, layout).unwrap();
assert_eq!(range, ByteRange(0, 128));
}
}

/// A result of [`SegmentTracker::peek_next`] which contains information about
/// available allocation slot and wherein a certain [`Layout`] could be placed.
///
Expand Down Expand Up @@ -456,22 +460,27 @@ pub trait ManageMemory {
unsafe fn grow(&self, base: MemoryBase, new_size: usize) -> Result<BaseAddress, MemoryError>;
}

unsafe fn some_non_null_slice(data: *const u8, len: usize) -> Option<NonNull<[u8]>> {
Some(NonNull::from(core::slice::from_raw_parts(data, len)))
}

/// Default [memory manager](ManageMemory) that uses the methods exposed by
/// [`alloc`] module.
#[derive(Clone, Copy)]
pub struct DefaultMemoryManager;
impl ManageMemory for DefaultMemoryManager {
fn allocate(&self, layout: Layout) -> Result<BaseAddress, MemoryError> {
if layout.size() == 0 {
Ok(None)
Ok(if layout.size() == 0 {
None
} else {
unsafe {
Ok(Some(NonNull::from(core::slice::from_raw_parts(
alloc::alloc(layout),
layout.size(),
))))
let data = alloc::alloc(layout);
if data.is_null() {
return Err(MemoryError::TooLarge);
}
some_non_null_slice(data, layout.size())
}
}
})
}

unsafe fn deallocate(&self, base: MemoryBase) {
Expand All @@ -484,40 +493,45 @@ impl ManageMemory for DefaultMemoryManager {
}

unsafe fn shrink(&self, base: MemoryBase, new_size: usize) -> Result<BaseAddress, MemoryError> {
match base.address {
Some(it) => Ok({
Ok(match base.address {
Some(it) => {
if new_size > 0 {
Some(NonNull::from(core::slice::from_raw_parts(
alloc::realloc(it.as_ptr() as *mut u8, base.layout(), new_size),
new_size,
)))
let data = alloc::realloc(it.as_ptr() as *mut u8, base.layout(), new_size);
if data.is_null() {
return Err(MemoryError::TooLarge);
}
some_non_null_slice(data, new_size)
} else {
alloc::dealloc(it.as_ptr() as *mut u8, base.layout());
None
}
}),
None => Ok(None),
}
}
None => None,
})
}

unsafe fn grow(&self, base: MemoryBase, new_size: usize) -> Result<BaseAddress, MemoryError> {
match base.address {
Some(it) => Ok(Some(NonNull::from(core::slice::from_raw_parts(
alloc::realloc(it.as_ptr() as *mut u8, base.layout(), new_size),
new_size,
)))),
None => Ok({
Ok(match base.address {
Some(it) => {
let data = alloc::realloc(it.as_ptr() as *mut u8, base.layout(), new_size);
if data.is_null() {
return Err(MemoryError::TooLarge);
}
some_non_null_slice(data, new_size)
}
None => {
if new_size == 0 {
None
} else {
let new_layout = Layout::from_size_align(new_size, base.alignment())?;
Some(NonNull::from(core::slice::from_raw_parts(
alloc::alloc(new_layout),
new_size,
)))
let data = alloc::alloc(new_layout);
if data.is_null() {
return Err(MemoryError::TooLarge);
}
some_non_null_slice(data, new_size)
}
}),
}
}
})
}
}

Expand Down
36 changes: 17 additions & 19 deletions src/range.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,23 @@ impl ByteRange {

/// Merges this byte range with `other` and returns a byte range that
/// contains both.
///
/// # Example
///
/// ```
/// # use contiguous_mem::range::ByteRange;
/// let a = ByteRange::new_unchecked(0, 10);
/// let b = ByteRange::new_unchecked(10, 20);
///
/// let added_seq = a.union_unchecked(b);
/// assert_eq!(added_seq.0, 0);
/// assert_eq!(added_seq.1, 20);
///
/// // range union is symmetrical
/// let added_seq_rev = b.union_unchecked(a);
/// assert_eq!(added_seq_rev.0, 0);
/// assert_eq!(added_seq_rev.1, 20);
/// ```
pub fn union_unchecked(&self, other: Self) -> Self {
ByteRange(self.0.min(other.0), self.1.max(other.1))
}
Expand Down Expand Up @@ -130,22 +147,3 @@ impl Display for ByteRange {
write!(f, "[{:x}, {:x})", self.0, self.1)
}
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn byterange_merging_works() {
let a = ByteRange::new_unchecked(0, 10);
let b = ByteRange::new_unchecked(10, 20);

let added_seq = a.union_unchecked(b);
assert_eq!(added_seq.0, 0);
assert_eq!(added_seq.1, 20);

let added_seq_rev = b.union_unchecked(a);
assert_eq!(added_seq_rev.0, 0);
assert_eq!(added_seq_rev.1, 20);
}
}

0 comments on commit d964574

Please sign in to comment.