From d6855653166014582642ef61a701884942c7d8a4 Mon Sep 17 00:00:00 2001 From: guorong009 Date: Wed, 7 Aug 2024 22:30:22 +0800 Subject: [PATCH] chore: add unit tests for "Expression*" --- halo2_backend/src/plonk/circuit.rs | 114 ++++++++++++++++++ .../src/plonk/circuit/expression.rs | 8 +- halo2_middleware/src/expression.rs | 96 +++++++++++++++ 3 files changed, 216 insertions(+), 2 deletions(-) diff --git a/halo2_backend/src/plonk/circuit.rs b/halo2_backend/src/plonk/circuit.rs index a754bf4bf..cd7c6c663 100644 --- a/halo2_backend/src/plonk/circuit.rs +++ b/halo2_backend/src/plonk/circuit.rs @@ -366,3 +366,117 @@ fn shuffle_argument_required_degree(arg: &shuffle::Argume // (1 - (l_last + l_blind)) (z(\omega X) (s(X) + \gamma) - z(X) (a(X) + \gamma)) std::cmp::max(2 + shuffle_degree, 2 + input_degree) } + +#[cfg(test)] +mod tests { + use super::{Any, ExpressionBack, QueryBack, VarBack}; + + use halo2_middleware::poly::Rotation; + use halo2curves::bn256::Fr; + + #[test] + fn expressionback_iter_sum() { + let exprs: Vec> = vec![ + ExpressionBack::Constant(1.into()), + ExpressionBack::Constant(2.into()), + ExpressionBack::Constant(3.into()), + ]; + let happened: ExpressionBack = exprs.into_iter().sum(); + let expected: ExpressionBack = ExpressionBack::Sum( + Box::new(ExpressionBack::Sum( + Box::new(ExpressionBack::Constant(1.into())), + Box::new(ExpressionBack::Constant(2.into())), + )), + Box::new(ExpressionBack::Constant(3.into())), + ); + + assert_eq!(happened, expected); + } + + #[test] + fn expressionback_iter_product() { + let exprs: Vec> = vec![ + ExpressionBack::Constant(1.into()), + ExpressionBack::Constant(2.into()), + ExpressionBack::Constant(3.into()), + ]; + let happened: ExpressionBack = exprs.into_iter().product(); + let expected: ExpressionBack = ExpressionBack::Product( + Box::new(ExpressionBack::Product( + Box::new(ExpressionBack::Constant(1.into())), + Box::new(ExpressionBack::Constant(2.into())), + )), + Box::new(ExpressionBack::Constant(3.into())), + ); + + assert_eq!(happened, expected); + } + + #[test] + fn expressionback_identifier() { + let sum_expr: ExpressionBack = ExpressionBack::Sum( + Box::new(ExpressionBack::Constant(1.into())), + Box::new(ExpressionBack::Constant(2.into())), + ); + assert_eq!(sum_expr.identifier(), "(0x0000000000000000000000000000000000000000000000000000000000000001+0x0000000000000000000000000000000000000000000000000000000000000002)"); + + let prod_expr: ExpressionBack = ExpressionBack::Product( + Box::new(ExpressionBack::Constant(1.into())), + Box::new(ExpressionBack::Constant(2.into())), + ); + assert_eq!(prod_expr.identifier(), "(0x0000000000000000000000000000000000000000000000000000000000000001*0x0000000000000000000000000000000000000000000000000000000000000002)"); + + // simulate the expressios being used in a circuit + let l: ExpressionBack = ExpressionBack::Var(VarBack::Query(QueryBack { + index: 0, + column_index: 0, + column_type: Any::Advice, + rotation: Rotation::cur(), + })); + let r: ExpressionBack = ExpressionBack::Var(VarBack::Query(QueryBack { + index: 1, + column_index: 1, + column_type: Any::Advice, + rotation: Rotation::cur(), + })); + let o: ExpressionBack = ExpressionBack::Var(VarBack::Query(QueryBack { + index: 2, + column_index: 2, + column_type: Any::Advice, + rotation: Rotation::cur(), + })); + let c: ExpressionBack = ExpressionBack::Var(VarBack::Query(QueryBack { + index: 3, + column_index: 0, + column_type: Any::Fixed, + rotation: Rotation::cur(), + })); + let sl: ExpressionBack = ExpressionBack::Var(VarBack::Query(QueryBack { + index: 4, + column_index: 1, + column_type: Any::Fixed, + rotation: Rotation::cur(), + })); + let sr: ExpressionBack = ExpressionBack::Var(VarBack::Query(QueryBack { + index: 5, + column_index: 2, + column_type: Any::Fixed, + rotation: Rotation::cur(), + })); + let sm: ExpressionBack = ExpressionBack::Var(VarBack::Query(QueryBack { + index: 6, + column_index: 3, + column_type: Any::Fixed, + rotation: Rotation::cur(), + })); + let so: ExpressionBack = ExpressionBack::Var(VarBack::Query(QueryBack { + index: 7, + column_index: 4, + column_type: Any::Fixed, + rotation: Rotation::cur(), + })); + + let simple_plonk_expr = sl * l.clone() + sr * r.clone() + sm * (l * r) - so * o + c; + assert_eq!(simple_plonk_expr.identifier(), "(((((Query(QueryBack { index: 4, column_index: 1, column_type: Fixed, rotation: Rotation(0) })*Query(QueryBack { index: 0, column_index: 0, column_type: Advice, rotation: Rotation(0) }))+(Query(QueryBack { index: 5, column_index: 2, column_type: Fixed, rotation: Rotation(0) })*Query(QueryBack { index: 1, column_index: 1, column_type: Advice, rotation: Rotation(0) })))+(Query(QueryBack { index: 6, column_index: 3, column_type: Fixed, rotation: Rotation(0) })*(Query(QueryBack { index: 0, column_index: 0, column_type: Advice, rotation: Rotation(0) })*Query(QueryBack { index: 1, column_index: 1, column_type: Advice, rotation: Rotation(0) }))))+(-(Query(QueryBack { index: 7, column_index: 4, column_type: Fixed, rotation: Rotation(0) })*Query(QueryBack { index: 2, column_index: 2, column_type: Advice, rotation: Rotation(0) }))))+Query(QueryBack { index: 3, column_index: 0, column_type: Fixed, rotation: Rotation(0) }))"); + } +} diff --git a/halo2_frontend/src/plonk/circuit/expression.rs b/halo2_frontend/src/plonk/circuit/expression.rs index a6560063f..31d53bea3 100644 --- a/halo2_frontend/src/plonk/circuit/expression.rs +++ b/halo2_frontend/src/plonk/circuit/expression.rs @@ -1179,8 +1179,12 @@ mod tests { let sr: Expression = Expression::Selector(Selector(1, false)); let sm: Expression = Expression::Selector(Selector(2, false)); let so: Expression = Expression::Selector(Selector(3, false)); - let c: Expression = Expression::Fixed(FixedQuery { index: None, column_index: 0, rotation: Rotation::cur()}); - + let c: Expression = Expression::Fixed(FixedQuery { + index: None, + column_index: 0, + rotation: Rotation::cur(), + }); + let simple_plonk_expr = sl * l.clone() + sr * r.clone() + sm * (l * r) - so * o + c; assert_eq!(simple_plonk_expr.identifier(), "(((((selector[0]*advice[0][0])+(selector[1]*advice[1][0]))+(selector[2]*(advice[0][0]*advice[1][0])))+(-(selector[3]*advice[2][0])))+fixed[0][0])"); } diff --git a/halo2_middleware/src/expression.rs b/halo2_middleware/src/expression.rs index 91d0feb30..e6240d60e 100644 --- a/halo2_middleware/src/expression.rs +++ b/halo2_middleware/src/expression.rs @@ -172,3 +172,99 @@ impl Product for Expression { .unwrap_or(Expression::Constant(F::ONE)) } } + +#[cfg(test)] +mod tests { + + use crate::{ + circuit::{Any, ExpressionMid, QueryMid, VarMid}, + poly::Rotation, + }; + use halo2curves::bn256::Fr; + + #[test] + fn iter_sum() { + let exprs: Vec> = vec![ + ExpressionMid::Constant(1.into()), + ExpressionMid::Constant(2.into()), + ExpressionMid::Constant(3.into()), + ]; + let happened: ExpressionMid = exprs.into_iter().sum(); + let expected: ExpressionMid = ExpressionMid::Sum( + Box::new(ExpressionMid::Sum( + Box::new(ExpressionMid::Constant(1.into())), + Box::new(ExpressionMid::Constant(2.into())), + )), + Box::new(ExpressionMid::Constant(3.into())), + ); + + assert_eq!(happened, expected); + } + + #[test] + fn iter_product() { + let exprs: Vec> = vec![ + ExpressionMid::Constant(1.into()), + ExpressionMid::Constant(2.into()), + ExpressionMid::Constant(3.into()), + ]; + let happened: ExpressionMid = exprs.into_iter().product(); + let expected: ExpressionMid = ExpressionMid::Product( + Box::new(ExpressionMid::Product( + Box::new(ExpressionMid::Constant(1.into())), + Box::new(ExpressionMid::Constant(2.into())), + )), + Box::new(ExpressionMid::Constant(3.into())), + ); + + assert_eq!(happened, expected); + } + + #[test] + fn identifier() { + let sum_expr: ExpressionMid = ExpressionMid::Sum( + Box::new(ExpressionMid::Constant(1.into())), + Box::new(ExpressionMid::Constant(2.into())), + ); + assert_eq!(sum_expr.identifier(), "(0x0000000000000000000000000000000000000000000000000000000000000001+0x0000000000000000000000000000000000000000000000000000000000000002)"); + + let prod_expr: ExpressionMid = ExpressionMid::Product( + Box::new(ExpressionMid::Constant(1.into())), + Box::new(ExpressionMid::Constant(2.into())), + ); + assert_eq!(prod_expr.identifier(), "(0x0000000000000000000000000000000000000000000000000000000000000001*0x0000000000000000000000000000000000000000000000000000000000000002)"); + + // simulate the expressios being used in a circuit + let l: ExpressionMid = ExpressionMid::Var(VarMid::Query(QueryMid::new( + Any::Advice, + 0, + Rotation::cur(), + ))); + let r: ExpressionMid = ExpressionMid::Var(VarMid::Query(QueryMid::new( + Any::Advice, + 1, + Rotation::cur(), + ))); + let o: ExpressionMid = ExpressionMid::Var(VarMid::Query(QueryMid::new( + Any::Advice, + 2, + Rotation::cur(), + ))); + let c: ExpressionMid = + ExpressionMid::Var(VarMid::Query(QueryMid::new(Any::Fixed, 0, Rotation::cur()))); + let sl: ExpressionMid = + ExpressionMid::Var(VarMid::Query(QueryMid::new(Any::Fixed, 1, Rotation::cur()))); + let sr: ExpressionMid = + ExpressionMid::Var(VarMid::Query(QueryMid::new(Any::Fixed, 2, Rotation::cur()))); + let sm: ExpressionMid = + ExpressionMid::Var(VarMid::Query(QueryMid::new(Any::Fixed, 3, Rotation::cur()))); + let so: ExpressionMid = + ExpressionMid::Var(VarMid::Query(QueryMid::new(Any::Fixed, 4, Rotation::cur()))); + + let simple_plonk_expr = sl * l.clone() + sr * r.clone() + sm * (l * r) - so * o + c; + assert_eq!( + simple_plonk_expr.identifier(), + "(((((f1*a0)+(f2*a1))+(f3*(a0*a1)))+(-(f4*a2)))+f0)" + ); + } +}