Skip to content

Commit 1cd08ef

Browse files
authored
feat(py, core, llvm): add is_borrowed op for BorrowArray (#2610)
Closes #2569
1 parent 4379a0f commit 1cd08ef

File tree

4 files changed

+371
-17
lines changed

4 files changed

+371
-17
lines changed

hugr-core/src/std_extensions/collections/borrow_array.rs

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ pub const BORROW_ARRAY_VALUENAME: TypeName = TypeName::new_inline("borrow_array"
4545
/// Reported unique name of the extension
4646
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("collections.borrow_arr");
4747
/// Extension version.
48-
pub const VERSION: semver::Version = semver::Version::new(0, 1, 1);
48+
pub const VERSION: semver::Version = semver::Version::new(0, 1, 2);
4949

5050
/// A linear, unsafe, fixed-length collection of values.
5151
///
@@ -123,6 +123,8 @@ pub enum BArrayUnsafeOpDef {
123123
discard_all_borrowed,
124124
/// `new_all_borrowed<size, elem_ty>: () -> borrow_array<size, elem_ty>`
125125
new_all_borrowed,
126+
/// is_borrowed<N, T>: borrow_array<N, T>, usize -> bool, borrow_array<N, T>
127+
is_borrowed,
126128
}
127129

128130
impl BArrayUnsafeOpDef {
@@ -166,6 +168,13 @@ impl BArrayUnsafeOpDef {
166168
Self::new_all_borrowed => {
167169
PolyFuncTypeRV::new(params, FuncValueType::new(type_row![], vec![array_ty]))
168170
}
171+
Self::is_borrowed => PolyFuncTypeRV::new(
172+
params,
173+
FuncValueType::new(
174+
vec![array_ty.clone(), usize_t],
175+
vec![crate::extension::prelude::bool_t(), array_ty],
176+
),
177+
),
169178
}
170179
.into()
171180
}
@@ -210,6 +219,7 @@ impl MakeOpDef for BArrayUnsafeOpDef {
210219
"Discard a borrow array where all elements have been borrowed"
211220
}
212221
Self::new_all_borrowed => "Create a new borrow array that contains no elements",
222+
Self::is_borrowed => "Test whether an element in a borrow array has been borrowed",
213223
}
214224
.into()
215225
}
@@ -719,6 +729,38 @@ pub trait BArrayOpBuilder: GenericArrayOpBuilder {
719729
.outputs_arr();
720730
Ok(arr)
721731
}
732+
733+
/// Adds an operation to test whether an element in a borrow array has been borrowed.
734+
///
735+
/// # Arguments
736+
///
737+
/// * `elem_ty` - The type of the elements in the array.
738+
/// * `size` - The size of the array.
739+
/// * `input` - The wire representing the array.
740+
/// * `index` - The wire representing the index to test.
741+
///
742+
/// # Errors
743+
///
744+
/// Returns an error if building the operation fails.
745+
///
746+
/// # Returns
747+
///
748+
/// A tuple containing:
749+
/// * The wire representing the boolean result (true if borrowed).
750+
/// * The wire representing the updated array.
751+
fn add_is_borrowed(
752+
&mut self,
753+
elem_ty: Type,
754+
size: u64,
755+
input: Wire,
756+
index: Wire,
757+
) -> Result<(Wire, Wire), BuildError> {
758+
let op = BArrayUnsafeOpDef::is_borrowed.instantiate(&[size.into(), elem_ty.into()])?;
759+
let [is_borrowed, arr] = self
760+
.add_dataflow_op(op.to_extension_op().unwrap(), vec![input, index])?
761+
.outputs_arr();
762+
Ok((is_borrowed, arr))
763+
}
722764
}
723765

724766
impl<D: Dataflow> BArrayOpBuilder for D {}
@@ -804,4 +846,25 @@ mod test {
804846
builder.finish_hugr_with_outputs([arr_with_put]).unwrap()
805847
};
806848
}
849+
#[test]
850+
fn test_is_borrowed() {
851+
let size = 4;
852+
let elem_ty = qb_t();
853+
let arr_ty = borrow_array_type(size, elem_ty.clone());
854+
855+
let mut builder =
856+
DFGBuilder::new(Signature::new(vec![arr_ty.clone()], vec![qb_t(), arr_ty])).unwrap();
857+
let idx = builder.add_load_value(ConstUsize::new(2));
858+
let [arr] = builder.input_wires_arr();
859+
// Borrow the element at index 2
860+
let (qb, arr_with_borrowed) = builder
861+
.add_borrow_array_borrow(elem_ty.clone(), size, arr, idx)
862+
.unwrap();
863+
let (_is_borrowed, arr_after_check) = builder
864+
.add_is_borrowed(elem_ty.clone(), size, arr_with_borrowed, idx)
865+
.unwrap();
866+
builder
867+
.finish_hugr_with_outputs([qb, arr_after_check])
868+
.unwrap();
869+
}
807870
}

hugr-llvm/src/extension/collections/borrow_array.rs

Lines changed: 145 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -599,20 +599,16 @@ impl MaskCheck {
599599
None,
600600
|ctx, [mask_ptr, idx]| {
601601
// Compute mask bitarray block index via `idx // BLOCK_SIZE`
602-
let mask_ptr = mask_ptr.into_pointer_value();
603-
let idx = idx.into_int_value();
604602
let usize_t = usize_ty(&ctx.typing_session());
605-
let block_size = usize_t.const_int(usize_t.get_bit_width() as u64, false);
606-
let builder = ctx.builder();
607-
let block_idx = builder.build_int_unsigned_div(idx, block_size, "")?;
608-
let block_ptr = unsafe { builder.build_in_bounds_gep(mask_ptr, &[block_idx], "")? };
609-
let block = builder.build_load(block_ptr, "")?.into_int_value();
610-
611-
// Extract bit from the block at position `idx % BLOCK_SIZE`
612-
let idx_in_block = builder.build_int_unsigned_rem(idx, block_size, "")?;
613-
let block_shifted = builder.build_right_shift(block, idx_in_block, false, "")?;
614-
let bit =
615-
builder.build_int_truncate(block_shifted, ctx.iw_context().bool_type(), "")?;
603+
let (
604+
BlockData {
605+
block_ptr,
606+
block,
607+
idx_in_block,
608+
},
609+
bit,
610+
) = inspect_mask_idx_bit(ctx, mask_ptr, idx)?;
611+
616612
let panic_bb = ctx.build_positioned_new_block("panic", None, |ctx, panic_bb| {
617613
let err: &ConstError = match self {
618614
MaskCheck::CheckNotBorrowed | MaskCheck::Borrow => &ERR_ALREADY_BORROWED,
@@ -651,6 +647,38 @@ impl MaskCheck {
651647
}
652648
}
653649

650+
struct BlockData<'c> {
651+
block_ptr: PointerValue<'c>,
652+
block: IntValue<'c>,
653+
idx_in_block: IntValue<'c>,
654+
}
655+
656+
fn inspect_mask_idx_bit<'c, H: HugrView<Node = Node>>(
657+
ctx: &mut EmitFuncContext<'c, '_, H>,
658+
mask_ptr: BasicValueEnum<'c>,
659+
idx: BasicValueEnum<'c>,
660+
) -> Result<(BlockData<'c>, IntValue<'c>)> {
661+
let usize_t = usize_ty(&ctx.typing_session());
662+
let mask_ptr = mask_ptr.into_pointer_value();
663+
let idx = idx.into_int_value();
664+
let block_size = usize_t.const_int(usize_t.get_bit_width() as u64, false);
665+
let builder = ctx.builder();
666+
let block_idx = builder.build_int_unsigned_div(idx, block_size, "")?;
667+
let block_ptr = unsafe { builder.build_in_bounds_gep(mask_ptr, &[block_idx], "")? };
668+
let block = builder.build_load(block_ptr, "")?.into_int_value();
669+
let idx_in_block = builder.build_int_unsigned_rem(idx, block_size, "")?;
670+
let block_shifted = builder.build_right_shift(block, idx_in_block, false, "")?;
671+
let bit = builder.build_int_truncate(block_shifted, ctx.iw_context().bool_type(), "")?;
672+
Ok((
673+
BlockData {
674+
block_ptr,
675+
block,
676+
idx_in_block,
677+
},
678+
bit,
679+
))
680+
}
681+
654682
struct MaskInfo<'a> {
655683
mask_ptr: PointerValue<'a>,
656684
offset: IntValue<'a>,
@@ -787,6 +815,27 @@ fn build_mask_padding1d<'c, H: HugrView<Node = Node>>(
787815
Ok(())
788816
}
789817

818+
/// Emits a check that returns whether a specific array element is borrowed (true) or not (false).
819+
pub fn build_is_borrowed_bit<'c, H: HugrView<Node = Node>>(
820+
ctx: &mut EmitFuncContext<'c, '_, H>,
821+
mask_ptr: PointerValue<'c>,
822+
idx: IntValue<'c>,
823+
) -> Result<inkwell::values::IntValue<'c>> {
824+
// Wrap the check into a function instead of inlining
825+
const FUNC_NAME: &str = "__barray_is_borrowed";
826+
get_or_make_function(
827+
ctx,
828+
FUNC_NAME,
829+
[mask_ptr.into(), idx.into()],
830+
Some(ctx.iw_context().bool_type().into()),
831+
|ctx, [mask_ptr, idx]| {
832+
let (_, bit) = inspect_mask_idx_bit(ctx, mask_ptr, idx)?;
833+
Ok(Some(bit.into()))
834+
},
835+
)
836+
.map(|v| v.expect("i1 return value").into_int_value())
837+
}
838+
790839
/// Emits a check that no array elements have been borrowed.
791840
pub fn build_none_borrowed_check<'c, H: HugrView<Node = Node>>(
792841
ccg: &impl BorrowArrayCodegen,
@@ -1570,6 +1619,20 @@ pub fn emit_barray_unsafe_op<'c, H: HugrView<Node = Node>>(
15701619
let (_, array_v) = build_barray_alloc(ctx, ccg, elem_ty, size, true)?;
15711620
outputs.finish(ctx.builder(), [array_v.into()])
15721621
}
1622+
BArrayUnsafeOpDef::is_borrowed => {
1623+
let [array_v, index_v] = inputs
1624+
.try_into()
1625+
.map_err(|_| anyhow!("BArrayUnsafeOpDef::is_borrowed expects two arguments"))?;
1626+
let BArrayFatPtrComponents {
1627+
mask_ptr, offset, ..
1628+
} = decompose_barray_fat_pointer(builder, array_v)?;
1629+
let index_v = index_v.into_int_value();
1630+
build_bounds_check(ccg, ctx, size, index_v)?;
1631+
let offset_index_v = ctx.builder().build_int_add(index_v, offset, "")?;
1632+
// let bit = build_is_borrowed_check(ctx, mask_ptr, offset_index_v)?;
1633+
let bit = build_is_borrowed_bit(ctx, mask_ptr, offset_index_v)?;
1634+
outputs.finish(ctx.builder(), [bit.into(), array_v])
1635+
}
15731636
_ => todo!(),
15741637
}
15751638
}
@@ -1627,6 +1690,8 @@ mod test {
16271690
use hugr_core::extension::prelude::either_type;
16281691
use hugr_core::ops::Tag;
16291692
use hugr_core::std_extensions::STD_REG;
1693+
use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef;
1694+
use hugr_core::std_extensions::arithmetic::int_ops::IntOpDef;
16301695
use hugr_core::std_extensions::collections::array::ArrayOpBuilder;
16311696
use hugr_core::std_extensions::collections::array::op_builder::build_all_borrow_array_ops;
16321697
use hugr_core::std_extensions::collections::borrow_array::{
@@ -2634,7 +2699,7 @@ mod test {
26342699
// - Pops specified numbers from the left to introduce an offset
26352700
// - Converts it into a regular array
26362701
// - Converts it back into a borrow array
2637-
// - Borrows alls elements, sums them up, and returns the sum
2702+
// - Borrows all elements, sums them up, and returns the sum
26382703

26392704
let int_ty = int_type(6);
26402705
let hugr = SimpleHugrConfig::new()
@@ -2908,4 +2973,70 @@ mod test {
29082973
let msg = "Some array elements have been borrowed";
29092974
assert_eq!(&exec_ctx.exec_hugr_panicking(hugr, "main"), msg);
29102975
}
2976+
2977+
#[rstest]
2978+
fn exec_is_borrowed_basic(mut exec_ctx: TestContext) {
2979+
// We build a HUGR that:
2980+
// - Creates a borrow array [1,2,3]
2981+
// - Borrows index 1
2982+
// - Checks is_borrowed for indices 0, 1
2983+
// - Returns 1 if [false, true], else 0
2984+
let int_ty = int_type(6);
2985+
let size = 3;
2986+
let hugr = SimpleHugrConfig::new()
2987+
.with_outs(int_ty.clone())
2988+
.with_extensions(exec_registry())
2989+
.finish(|mut builder| {
2990+
let barray = borrow_array::BArrayValue::new(
2991+
int_ty.clone(),
2992+
(1..=3)
2993+
.map(|i| ConstInt::new_u(6, i).unwrap().into())
2994+
.collect_vec(),
2995+
);
2996+
let barray = builder.add_load_value(barray);
2997+
let idx1 = builder.add_load_value(ConstUsize::new(1));
2998+
let (_, barray) = builder
2999+
.add_borrow_array_borrow(int_ty.clone(), size, barray, idx1)
3000+
.unwrap();
3001+
3002+
let idx0 = builder.add_load_value(ConstUsize::new(0));
3003+
let (arr, b0_bools) =
3004+
[idx0, idx1]
3005+
.iter()
3006+
.fold((barray, Vec::new()), |(arr, mut bools), idx| {
3007+
let (b, arr) = builder
3008+
.add_is_borrowed(int_ty.clone(), size, arr, *idx)
3009+
.unwrap();
3010+
bools.push(b);
3011+
(arr, bools)
3012+
});
3013+
let [b0, b1] = b0_bools.try_into().unwrap();
3014+
3015+
let b0 = builder.add_not(b0).unwrap(); // flip b0 to true
3016+
let and01 = builder.add_and(b0, b1).unwrap();
3017+
// convert bool to i1
3018+
let i1 = builder
3019+
.add_dataflow_op(ConvertOpDef::ifrombool.without_log_width(), [and01])
3020+
.unwrap()
3021+
.out_wire(0);
3022+
// widen i1 to i64
3023+
let i_64 = builder
3024+
.add_dataflow_op(IntOpDef::iwiden_u.with_two_log_widths(0, 6), [i1])
3025+
.unwrap()
3026+
.out_wire(0);
3027+
builder
3028+
.add_borrow_array_discard(int_ty.clone(), size, arr)
3029+
.unwrap();
3030+
builder.finish_hugr_with_outputs([i_64]).unwrap()
3031+
});
3032+
3033+
exec_ctx.add_extensions(|cge| {
3034+
cge.add_default_prelude_extensions()
3035+
.add_logic_extensions()
3036+
.add_conversion_extensions()
3037+
.add_default_borrow_array_extensions(DefaultPreludeCodegen)
3038+
.add_default_int_extensions()
3039+
});
3040+
assert_eq!(1, exec_ctx.exec_hugr_u64(hugr, "main"));
3041+
}
29113042
}

0 commit comments

Comments
 (0)