Skip to content

Commit

Permalink
chore: add unit tests for "Expression*"
Browse files Browse the repository at this point in the history
  • Loading branch information
guorong009 committed Aug 7, 2024
1 parent d665a5e commit d685565
Show file tree
Hide file tree
Showing 3 changed files with 216 additions and 2 deletions.
114 changes: 114 additions & 0 deletions halo2_backend/src/plonk/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,3 +366,117 @@ fn shuffle_argument_required_degree<F: Field, V: Variable>(arg: &shuffle::Argume
// (1 - (l_last + l_blind)) (z(\omega X) (s(X) + \gamma) - z(X) (a(X) + \gamma))
std::cmp::max(2 + shuffle_degree, 2 + input_degree)
}

#[cfg(test)]
mod tests {
use super::{Any, ExpressionBack, QueryBack, VarBack};

use halo2_middleware::poly::Rotation;
use halo2curves::bn256::Fr;

#[test]
fn expressionback_iter_sum() {
let exprs: Vec<ExpressionBack<Fr>> = vec![
ExpressionBack::Constant(1.into()),
ExpressionBack::Constant(2.into()),
ExpressionBack::Constant(3.into()),
];
let happened: ExpressionBack<Fr> = exprs.into_iter().sum();
let expected: ExpressionBack<Fr> = ExpressionBack::Sum(
Box::new(ExpressionBack::Sum(
Box::new(ExpressionBack::Constant(1.into())),
Box::new(ExpressionBack::Constant(2.into())),
)),
Box::new(ExpressionBack::Constant(3.into())),
);

assert_eq!(happened, expected);
}

#[test]
fn expressionback_iter_product() {
let exprs: Vec<ExpressionBack<Fr>> = vec![
ExpressionBack::Constant(1.into()),
ExpressionBack::Constant(2.into()),
ExpressionBack::Constant(3.into()),
];
let happened: ExpressionBack<Fr> = exprs.into_iter().product();
let expected: ExpressionBack<Fr> = ExpressionBack::Product(
Box::new(ExpressionBack::Product(
Box::new(ExpressionBack::Constant(1.into())),
Box::new(ExpressionBack::Constant(2.into())),
)),
Box::new(ExpressionBack::Constant(3.into())),
);

assert_eq!(happened, expected);
}

#[test]
fn expressionback_identifier() {
let sum_expr: ExpressionBack<Fr> = ExpressionBack::Sum(
Box::new(ExpressionBack::Constant(1.into())),
Box::new(ExpressionBack::Constant(2.into())),
);
assert_eq!(sum_expr.identifier(), "(0x0000000000000000000000000000000000000000000000000000000000000001+0x0000000000000000000000000000000000000000000000000000000000000002)");

let prod_expr: ExpressionBack<Fr> = ExpressionBack::Product(
Box::new(ExpressionBack::Constant(1.into())),
Box::new(ExpressionBack::Constant(2.into())),
);
assert_eq!(prod_expr.identifier(), "(0x0000000000000000000000000000000000000000000000000000000000000001*0x0000000000000000000000000000000000000000000000000000000000000002)");

// simulate the expressios being used in a circuit
let l: ExpressionBack<Fr> = ExpressionBack::Var(VarBack::Query(QueryBack {
index: 0,
column_index: 0,
column_type: Any::Advice,
rotation: Rotation::cur(),
}));
let r: ExpressionBack<Fr> = ExpressionBack::Var(VarBack::Query(QueryBack {
index: 1,
column_index: 1,
column_type: Any::Advice,
rotation: Rotation::cur(),
}));
let o: ExpressionBack<Fr> = ExpressionBack::Var(VarBack::Query(QueryBack {
index: 2,
column_index: 2,
column_type: Any::Advice,
rotation: Rotation::cur(),
}));
let c: ExpressionBack<Fr> = ExpressionBack::Var(VarBack::Query(QueryBack {
index: 3,
column_index: 0,
column_type: Any::Fixed,
rotation: Rotation::cur(),
}));
let sl: ExpressionBack<Fr> = ExpressionBack::Var(VarBack::Query(QueryBack {
index: 4,
column_index: 1,
column_type: Any::Fixed,
rotation: Rotation::cur(),
}));
let sr: ExpressionBack<Fr> = ExpressionBack::Var(VarBack::Query(QueryBack {
index: 5,
column_index: 2,
column_type: Any::Fixed,
rotation: Rotation::cur(),
}));
let sm: ExpressionBack<Fr> = ExpressionBack::Var(VarBack::Query(QueryBack {
index: 6,
column_index: 3,
column_type: Any::Fixed,
rotation: Rotation::cur(),
}));
let so: ExpressionBack<Fr> = ExpressionBack::Var(VarBack::Query(QueryBack {
index: 7,
column_index: 4,
column_type: Any::Fixed,
rotation: Rotation::cur(),
}));

let simple_plonk_expr = sl * l.clone() + sr * r.clone() + sm * (l * r) - so * o + c;
assert_eq!(simple_plonk_expr.identifier(), "(((((Query(QueryBack { index: 4, column_index: 1, column_type: Fixed, rotation: Rotation(0) })*Query(QueryBack { index: 0, column_index: 0, column_type: Advice, rotation: Rotation(0) }))+(Query(QueryBack { index: 5, column_index: 2, column_type: Fixed, rotation: Rotation(0) })*Query(QueryBack { index: 1, column_index: 1, column_type: Advice, rotation: Rotation(0) })))+(Query(QueryBack { index: 6, column_index: 3, column_type: Fixed, rotation: Rotation(0) })*(Query(QueryBack { index: 0, column_index: 0, column_type: Advice, rotation: Rotation(0) })*Query(QueryBack { index: 1, column_index: 1, column_type: Advice, rotation: Rotation(0) }))))+(-(Query(QueryBack { index: 7, column_index: 4, column_type: Fixed, rotation: Rotation(0) })*Query(QueryBack { index: 2, column_index: 2, column_type: Advice, rotation: Rotation(0) }))))+Query(QueryBack { index: 3, column_index: 0, column_type: Fixed, rotation: Rotation(0) }))");
}
}
8 changes: 6 additions & 2 deletions halo2_frontend/src/plonk/circuit/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1179,8 +1179,12 @@ mod tests {
let sr: Expression<Fr> = Expression::Selector(Selector(1, false));
let sm: Expression<Fr> = Expression::Selector(Selector(2, false));
let so: Expression<Fr> = Expression::Selector(Selector(3, false));
let c: Expression<Fr> = Expression::Fixed(FixedQuery { index: None, column_index: 0, rotation: Rotation::cur()});

let c: Expression<Fr> = Expression::Fixed(FixedQuery {
index: None,
column_index: 0,
rotation: Rotation::cur(),
});

let simple_plonk_expr = sl * l.clone() + sr * r.clone() + sm * (l * r) - so * o + c;
assert_eq!(simple_plonk_expr.identifier(), "(((((selector[0]*advice[0][0])+(selector[1]*advice[1][0]))+(selector[2]*(advice[0][0]*advice[1][0])))+(-(selector[3]*advice[2][0])))+fixed[0][0])");
}
Expand Down
96 changes: 96 additions & 0 deletions halo2_middleware/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,99 @@ impl<F: Field, V: Variable> Product<Self> for Expression<F, V> {
.unwrap_or(Expression::Constant(F::ONE))
}
}

#[cfg(test)]
mod tests {

use crate::{
circuit::{Any, ExpressionMid, QueryMid, VarMid},
poly::Rotation,
};
use halo2curves::bn256::Fr;

#[test]
fn iter_sum() {
let exprs: Vec<ExpressionMid<Fr>> = vec![
ExpressionMid::Constant(1.into()),
ExpressionMid::Constant(2.into()),
ExpressionMid::Constant(3.into()),
];
let happened: ExpressionMid<Fr> = exprs.into_iter().sum();
let expected: ExpressionMid<Fr> = ExpressionMid::Sum(
Box::new(ExpressionMid::Sum(
Box::new(ExpressionMid::Constant(1.into())),
Box::new(ExpressionMid::Constant(2.into())),
)),
Box::new(ExpressionMid::Constant(3.into())),
);

assert_eq!(happened, expected);
}

#[test]
fn iter_product() {
let exprs: Vec<ExpressionMid<Fr>> = vec![
ExpressionMid::Constant(1.into()),
ExpressionMid::Constant(2.into()),
ExpressionMid::Constant(3.into()),
];
let happened: ExpressionMid<Fr> = exprs.into_iter().product();
let expected: ExpressionMid<Fr> = ExpressionMid::Product(
Box::new(ExpressionMid::Product(
Box::new(ExpressionMid::Constant(1.into())),
Box::new(ExpressionMid::Constant(2.into())),
)),
Box::new(ExpressionMid::Constant(3.into())),
);

assert_eq!(happened, expected);
}

#[test]
fn identifier() {
let sum_expr: ExpressionMid<Fr> = ExpressionMid::Sum(
Box::new(ExpressionMid::Constant(1.into())),
Box::new(ExpressionMid::Constant(2.into())),
);
assert_eq!(sum_expr.identifier(), "(0x0000000000000000000000000000000000000000000000000000000000000001+0x0000000000000000000000000000000000000000000000000000000000000002)");

let prod_expr: ExpressionMid<Fr> = ExpressionMid::Product(
Box::new(ExpressionMid::Constant(1.into())),
Box::new(ExpressionMid::Constant(2.into())),
);
assert_eq!(prod_expr.identifier(), "(0x0000000000000000000000000000000000000000000000000000000000000001*0x0000000000000000000000000000000000000000000000000000000000000002)");

// simulate the expressios being used in a circuit
let l: ExpressionMid<Fr> = ExpressionMid::Var(VarMid::Query(QueryMid::new(
Any::Advice,
0,
Rotation::cur(),
)));
let r: ExpressionMid<Fr> = ExpressionMid::Var(VarMid::Query(QueryMid::new(
Any::Advice,
1,
Rotation::cur(),
)));
let o: ExpressionMid<Fr> = ExpressionMid::Var(VarMid::Query(QueryMid::new(
Any::Advice,
2,
Rotation::cur(),
)));
let c: ExpressionMid<Fr> =
ExpressionMid::Var(VarMid::Query(QueryMid::new(Any::Fixed, 0, Rotation::cur())));
let sl: ExpressionMid<Fr> =
ExpressionMid::Var(VarMid::Query(QueryMid::new(Any::Fixed, 1, Rotation::cur())));
let sr: ExpressionMid<Fr> =
ExpressionMid::Var(VarMid::Query(QueryMid::new(Any::Fixed, 2, Rotation::cur())));
let sm: ExpressionMid<Fr> =
ExpressionMid::Var(VarMid::Query(QueryMid::new(Any::Fixed, 3, Rotation::cur())));
let so: ExpressionMid<Fr> =
ExpressionMid::Var(VarMid::Query(QueryMid::new(Any::Fixed, 4, Rotation::cur())));

let simple_plonk_expr = sl * l.clone() + sr * r.clone() + sm * (l * r) - so * o + c;
assert_eq!(
simple_plonk_expr.identifier(),
"(((((f1*a0)+(f2*a1))+(f3*(a0*a1)))+(-(f4*a2)))+f0)"
);
}
}

0 comments on commit d685565

Please sign in to comment.