Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: permutation argument optimizations #10960

Merged
merged 29 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
64317bc
basic computations skipping model working in GP comp
ledwards2225 Dec 23, 2024
0ab2078
WiP things working before switch to active idxs loops
ledwards2225 Dec 23, 2024
35d1c8d
using active idxs for GP step 1
ledwards2225 Dec 23, 2024
418f7bc
fix
ledwards2225 Dec 23, 2024
09ef41a
dont do any computation for num/denom in inactive regions
ledwards2225 Dec 23, 2024
a258127
Merge branch 'master' into lde/perm_opt
ledwards2225 Dec 24, 2024
72b215f
fix build
ledwards2225 Dec 24, 2024
c36feaa
fix PG tests
ledwards2225 Dec 24, 2024
bd8a511
clean debug code to fix some tests
ledwards2225 Dec 24, 2024
dd50b47
Merge branch 'master' into lde/perm_opt
ledwards2225 Jan 2, 2025
d9b43bb
test revision of gp method
ledwards2225 Jan 2, 2025
deb1d65
remove debug code for ci
ledwards2225 Jan 2, 2025
339a5c9
optimized version seems to be working, cleanup needed
ledwards2225 Jan 3, 2025
5471504
some fixes, see what fails
ledwards2225 Jan 3, 2025
4d7ea03
fix for client ivc
ledwards2225 Jan 3, 2025
7f348aa
correct ivc structure
ledwards2225 Jan 3, 2025
9736925
Merge branch 'master' into lde/perm_opt
ledwards2225 Jan 5, 2025
8911c2f
some cleanup
ledwards2225 Jan 6, 2025
46e378c
clean and regularize
ledwards2225 Jan 6, 2025
d0ea21a
fix index error in thread method
ledwards2225 Jan 6, 2025
45156f7
clarify and clean
ledwards2225 Jan 6, 2025
0ebd607
active regions model working
ledwards2225 Jan 6, 2025
b2539a7
Merge branch 'master' into lde/perm_opt
ledwards2225 Jan 6, 2025
bf1394d
more cleanup
ledwards2225 Jan 7, 2025
b104bc6
clean and remove debug code
ledwards2225 Jan 7, 2025
e702d00
remove problematic assert
ledwards2225 Jan 7, 2025
d34fe9c
clean
ledwards2225 Jan 8, 2025
cd6e7b6
Merge branch 'master' into lde/perm_opt
ledwards2225 Jan 8, 2025
e442a51
make active region class more robust and constify some things
ledwards2225 Jan 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions barretenberg/cpp/src/barretenberg/common/thread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,22 @@ void parallel_for_heuristic(size_t num_points,
});
};

MultithreadData calculate_thread_data(size_t num_iterations, size_t min_iterations_per_thread)
{
size_t num_threads = calculate_num_threads(num_iterations, min_iterations_per_thread);
const size_t thread_size = num_iterations / num_threads;

// Cumpute the index bounds for each thread
std::vector<size_t> start(num_threads);
std::vector<size_t> end(num_threads);
for (size_t thread_idx = 0; thread_idx < num_threads; ++thread_idx) {
start[thread_idx] = thread_idx * thread_size;
end[thread_idx] = (thread_idx == num_threads - 1) ? num_iterations : (thread_idx + 1) * thread_size;
}

return MultithreadData{ num_threads, start, end };
}

/**
* @brief calculates number of threads to create based on minimum iterations per thread
* @details Finds the number of cpus with get_num_cpus(), and calculates `desired_num_threads`
Expand Down
18 changes: 18 additions & 0 deletions barretenberg/cpp/src/barretenberg/common/thread.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,24 @@ std::vector<Accum> parallel_for_heuristic(size_t num_points,

const size_t DEFAULT_MIN_ITERS_PER_THREAD = 1 << 4;

struct MultithreadData {
size_t num_threads;
// index bounds for each thread
std::vector<size_t> start;
std::vector<size_t> end;
};

/**
* @brief Calculates number of threads and index bounds for each thread
* @details Finds the number of cpus with calculate_num_threads() then divides domain evenly amongst threads
*
* @param num_iterations
* @param min_iterations_per_thread
* @return size_t
*/
MultithreadData calculate_thread_data(size_t num_iterations,
size_t min_iterations_per_thread = DEFAULT_MIN_ITERS_PER_THREAD);

/**
* @brief calculates number of threads to create based on minimum iterations per thread
* @details Finds the number of cpus with get_num_cpus(), and calculates `desired_num_threads`
Expand Down
16 changes: 14 additions & 2 deletions barretenberg/cpp/src/barretenberg/flavor/flavor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,19 @@ class PrecomputedEntitiesBase {
uint64_t log_circuit_size;
uint64_t num_public_inputs;
};
// Specifies the regions of the execution trace containing non-trivial wire values
struct ActiveRegionData {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

don't these ranges need to be non-overlapping and in increasing order? maybe there should be a comment of some sort of check

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

its a good point. I could add a check on add_range that the input has start >= the previous end. To be safe I suppose I'd also want to make the members private and add getters

std::vector<std::pair<size_t, size_t>> ranges; // active ranges [start_i, end_i) of the execution trace
std::vector<size_t> idxs; // full set of poly indices corresposponding to active ranges

void add_range(const size_t start, const size_t end)
{
ranges.emplace_back(start, end);
for (size_t i = start; i < end; ++i) {
idxs.push_back(i);
}
}
};

/**
* @brief Base proving key class.
Expand All @@ -123,8 +136,7 @@ template <typename FF, typename CommitmentKey_> class ProvingKey_ {
// folded element by element.
std::vector<FF> public_inputs;

// Ranges of the form [start, end) where witnesses have non-zero values (hence the execution trace is "active")
std::vector<std::pair<size_t, size_t>> active_block_ranges;
ActiveRegionData active_region_data; // specifies active regions of execution trace

ProvingKey_() = default;
ProvingKey_(const size_t dyadic_circuit_size,
Expand Down
1 change: 0 additions & 1 deletion barretenberg/cpp/src/barretenberg/goblin/mock_circuits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ class GoblinMockCircuits {
static void construct_simple_circuit(MegaBuilder& builder)
{
PROFILE_THIS();

add_some_ecc_op_gates(builder);
MockCircuits::construct_arithmetic_circuit(builder);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -224,33 +224,43 @@ void compute_honk_style_permutation_lagrange_polynomials_from_mapping(
using FF = typename Flavor::FF;
const size_t num_gates = proving_key->circuit_size;

size_t domain_size = proving_key->active_region_data.idxs.size();

const MultithreadData thread_data = calculate_thread_data(domain_size);

size_t wire_idx = 0;
for (auto& current_permutation_poly : permutation_polynomials) {
ITERATE_OVER_DOMAIN_START(proving_key->evaluation_domain);
auto idx = static_cast<ptrdiff_t>(i);
const auto& current_row_idx = permutation_mappings[wire_idx].row_idx[idx];
const auto& current_col_idx = permutation_mappings[wire_idx].col_idx[idx];
const auto& current_is_tag = permutation_mappings[wire_idx].is_tag[idx];
const auto& current_is_public_input = permutation_mappings[wire_idx].is_public_input[idx];
if (current_is_public_input) {
// We intentionally want to break the cycles of the public input variables.
// During the witness generation, the left and right wire polynomials at idx i contain the i-th public
// input. The CyclicPermutation created for these variables always start with (i) -> (n+i), followed by
// the indices of the variables in the "real" gates. We make i point to -(i+1), so that the only way of
// repairing the cycle is add the mapping
// -(i+1) -> (n+i)
// These indices are chosen so they can easily be computed by the verifier. They can expect the running
// product to be equal to the "public input delta" that is computed in <honk/utils/grand_product_delta.hpp>
current_permutation_poly.at(i) = -FF(current_row_idx + 1 + num_gates * current_col_idx);
} else if (current_is_tag) {
// Set evaluations to (arbitrary) values disjoint from non-tag values
current_permutation_poly.at(i) = num_gates * Flavor::NUM_WIRES + current_row_idx;
} else {
// For the regular permutation we simply point to the next location by setting the evaluation to its
// idx
current_permutation_poly.at(i) = FF(current_row_idx + num_gates * current_col_idx);
}
ITERATE_OVER_DOMAIN_END;
parallel_for(thread_data.num_threads, [&](size_t j) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This loop now iterates over only the active domain instead of the entire poly domain. Prior to this change, the sigma/id polynomials took non-zero values across the entire domain. Now, they are non-zero only in the active regions of the trace and 0 elsewhere (previously we had sigma_i == id_i in these regions). These values don't contribute to the computation of the grand product anyway so there's no reason to compute them.

const size_t start = thread_data.start[j];
const size_t end = thread_data.end[j];
for (size_t i = start; i < end; ++i) {
size_t poly_idx = proving_key->active_region_data.idxs[i];
auto idx = static_cast<ptrdiff_t>(poly_idx);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this cast needed?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also this can be a const and the one above too

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The cast is needed since row_idx has type std::shared_ptr<uint32_t[]> which can only be indexed with a ptrdiff_t. (Note this isn't a change introduced in this PR)

const auto& current_row_idx = permutation_mappings[wire_idx].row_idx[idx];
const auto& current_col_idx = permutation_mappings[wire_idx].col_idx[idx];
const auto& current_is_tag = permutation_mappings[wire_idx].is_tag[idx];
const auto& current_is_public_input = permutation_mappings[wire_idx].is_public_input[idx];
if (current_is_public_input) {
// We intentionally want to break the cycles of the public input variables.
// During the witness generation, the left and right wire polynomials at idx i contain the i-th
// public input. The CyclicPermutation created for these variables always start with (i) -> (n+i),
// followed by the indices of the variables in the "real" gates. We make i point to
// -(i+1), so that the only way of repairing the cycle is add the mapping
// -(i+1) -> (n+i)
// These indices are chosen so they can easily be computed by the verifier. They can expect
// the running product to be equal to the "public input delta" that is computed
// in <honk/utils/grand_product_delta.hpp>
current_permutation_poly.at(poly_idx) = -FF(current_row_idx + 1 + num_gates * current_col_idx);
} else if (current_is_tag) {
// Set evaluations to (arbitrary) values disjoint from non-tag values
current_permutation_poly.at(poly_idx) = num_gates * Flavor::NUM_WIRES + current_row_idx;
} else {
// For the regular permutation we simply point to the next location by setting the
// evaluation to its idx
current_permutation_poly.at(poly_idx) = FF(current_row_idx + num_gates * current_col_idx);
}
}
});
wire_idx++;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
#include "barretenberg/common/debug_log.hpp"
#include "barretenberg/common/thread.hpp"
#include "barretenberg/common/zip_view.hpp"
#include "barretenberg/flavor/flavor.hpp"
#include "barretenberg/plonk/proof_system/proving_key/proving_key.hpp"
#include "barretenberg/relations/relation_parameters.hpp"
#include "barretenberg/trace_to_polynomials/trace_to_polynomials.hpp"
#include <typeinfo>

namespace bb {
Expand Down Expand Up @@ -47,74 +49,68 @@ namespace bb {
*
* Note: Step (3) utilizes Montgomery batch inversion to replace n-many inversions with
*
* @note This method makes use of the fact that there are at most as many unique entries in the grand product as active
* rows in the execution trace to efficiently compute the grand product when a structured trace is in use. I.e. the
* computation peformed herein is proportional to the number of active rows in the trace and the constant values in the
* inactive regions are simply populated from known values on the last step.
*
* @tparam Flavor
* @tparam GrandProdRelation
* @param full_polynomials
* @param relation_parameters
* @param size_override optional size of the domain; otherwise based on dyadic polynomial domain
* @param active_region_data optional specification of active region of execution trace
*/
template <typename Flavor, typename GrandProdRelation>
void compute_grand_product(typename Flavor::ProverPolynomials& full_polynomials,
bb::RelationParameters<typename Flavor::FF>& relation_parameters,
size_t size_override = 0,
std::vector<std::pair<size_t, size_t>> active_block_ranges = {})
const ActiveRegionData& active_region_data = ActiveRegionData{})
{
PROFILE_THIS_NAME("compute_grand_product");

using FF = typename Flavor::FF;
using Polynomial = typename Flavor::Polynomial;
using Accumulator = std::tuple_element_t<0, typename GrandProdRelation::SumcheckArrayOfValuesOverSubrelations>;

const bool active_region_specified = !active_region_data.ranges.empty();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

has_active_regions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hah I started with that but thought it was misleading because if false it implies that there are NO active regions when really its just that they are implicit and haven't been specified. You're probably right tho that has_active_regions is more clear


// Set the domain over which the grand product must be computed. This may be less than the dyadic circuit size, e.g
// the permutation grand product does not need to be computed beyond the index of the last active wire
size_t domain_size = size_override == 0 ? full_polynomials.get_polynomial_size() : size_override;

const size_t num_threads = domain_size >= get_num_cpus_pow2() ? get_num_cpus_pow2() : 1;
const size_t block_size = domain_size / num_threads;
const size_t final_idx = domain_size - 1;

// Cumpute the index bounds for each thread for reuse in the computations below
std::vector<std::pair<size_t, size_t>> idx_bounds;
idx_bounds.reserve(num_threads);
for (size_t thread_idx = 0; thread_idx < num_threads; ++thread_idx) {
const size_t start = thread_idx * block_size;
const size_t end = (thread_idx == num_threads - 1) ? final_idx : (thread_idx + 1) * block_size;
idx_bounds.push_back(std::make_pair(start, end));
}
// Returns the ith active index if specified, otherwise acts as the identity map on the input
auto get_active_range_poly_idx = [&](size_t i) { return active_region_specified ? active_region_data.idxs[i] : i; };

size_t active_domain_size = active_region_specified ? active_region_data.idxs.size() : domain_size;

// The size of the iteration domain is one less than the number of active rows since the final value of the
// grand product is constructed only in the relation and not explicitly in the polynomial
const MultithreadData active_range_thread_data = calculate_thread_data(active_domain_size - 1);

// Allocate numerator/denominator polynomials that will serve as scratch space
// TODO(zac) we can re-use the permutation polynomial as the numerator polynomial. Reduces readability
Polynomial numerator{ domain_size, domain_size };
Polynomial denominator{ domain_size, domain_size };

auto check_is_active = [&](size_t idx) {
if (active_block_ranges.empty()) {
return true;
}
return std::any_of(active_block_ranges.begin(), active_block_ranges.end(), [idx](const auto& range) {
return idx >= range.first && idx < range.second;
});
};
Polynomial numerator{ active_domain_size };
Polynomial denominator{ active_domain_size };

// Step (1)
// Populate `numerator` and `denominator` with the algebra described by Relation
FF gamma_fourth = relation_parameters.gamma.pow(4);
parallel_for(num_threads, [&](size_t thread_idx) {
parallel_for(active_range_thread_data.num_threads, [&](size_t thread_idx) {
const size_t start = active_range_thread_data.start[thread_idx];
const size_t end = active_range_thread_data.end[thread_idx];
typename Flavor::AllValues row;
const size_t start = idx_bounds[thread_idx].first;
const size_t end = idx_bounds[thread_idx].second;
for (size_t i = start; i < end; ++i) {
if (check_is_active(i)) {
// TODO(https://github.com/AztecProtocol/barretenberg/issues/940):consider avoiding get_row if possible.
row = full_polynomials.get_row(i);
numerator.at(i) =
GrandProdRelation::template compute_grand_product_numerator<Accumulator>(row, relation_parameters);
denominator.at(i) = GrandProdRelation::template compute_grand_product_denominator<Accumulator>(
row, relation_parameters);
// TODO(https://github.com/AztecProtocol/barretenberg/issues/940):consider avoiding get_row if possible.
auto row_idx = get_active_range_poly_idx(i);
if constexpr (IsUltraFlavor<Flavor>) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if constexpr (!IsPlonkFlavor<Flavor>) would make this more readable I think

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's not quite the same thing though because this code is also used by the ECCVM/Translator which need to be excluded. I think this just comes down to the fact that we need better concepts. Probably isUltraOrMegaHonk

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose I could just add methods to the ECCVM/Trans flavors that just call get_row from get_row_for_permutation_arg. Not sure what's better

row = full_polynomials.get_row_for_permutation_arg(row_idx);
} else {
numerator.at(i) = gamma_fourth;
denominator.at(i) = gamma_fourth;
row = full_polynomials.get_row(row_idx);
}
numerator.at(i) =
GrandProdRelation::template compute_grand_product_numerator<Accumulator>(row, relation_parameters);
denominator.at(i) =
GrandProdRelation::template compute_grand_product_denominator<Accumulator>(row, relation_parameters);
}
});

Expand All @@ -133,12 +129,12 @@ void compute_grand_product(typename Flavor::ProverPolynomials& full_polynomials,
// (ii) Take partial products P = { 1, a0a1, a2a3, a4a5 }
// (iii) Each thread j computes N[i][j]*P[j]=
// {{a0,a0a1},{a0a1a2,a0a1a2a3},{a0a1a2a3a4,a0a1a2a3a4a5},{a0a1a2a3a4a5a6,a0a1a2a3a4a5a6a7}}
std::vector<FF> partial_numerators(num_threads);
std::vector<FF> partial_denominators(num_threads);
std::vector<FF> partial_numerators(active_range_thread_data.num_threads);
std::vector<FF> partial_denominators(active_range_thread_data.num_threads);

parallel_for(num_threads, [&](size_t thread_idx) {
const size_t start = idx_bounds[thread_idx].first;
const size_t end = idx_bounds[thread_idx].second;
parallel_for(active_range_thread_data.num_threads, [&](size_t thread_idx) {
const size_t start = active_range_thread_data.start[thread_idx];
const size_t end = active_range_thread_data.end[thread_idx];
for (size_t i = start; i < end - 1; ++i) {
numerator.at(i + 1) *= numerator[i];
denominator.at(i + 1) *= denominator[i];
Expand All @@ -150,9 +146,9 @@ void compute_grand_product(typename Flavor::ProverPolynomials& full_polynomials,
DEBUG_LOG_ALL(partial_numerators);
DEBUG_LOG_ALL(partial_denominators);

parallel_for(num_threads, [&](size_t thread_idx) {
const size_t start = idx_bounds[thread_idx].first;
const size_t end = idx_bounds[thread_idx].second;
parallel_for(active_range_thread_data.num_threads, [&](size_t thread_idx) {
const size_t start = active_range_thread_data.start[thread_idx];
const size_t end = active_range_thread_data.end[thread_idx];
if (thread_idx > 0) {
FF numerator_scaling = 1;
FF denominator_scaling = 1;
Expand All @@ -179,14 +175,44 @@ void compute_grand_product(typename Flavor::ProverPolynomials& full_polynomials,
// We have a 'virtual' 0 at the start (as this is a to-be-shifted polynomial)
ASSERT(grand_product_polynomial.start_index() == 1);

parallel_for(num_threads, [&](size_t thread_idx) {
const size_t start = idx_bounds[thread_idx].first;
const size_t end = idx_bounds[thread_idx].second;
// For Ultra/Mega, the first row is an inactive zero row thus the grand prod takes value 1 at both i = 0 and i = 1
if constexpr (IsUltraFlavor<Flavor>) {
grand_product_polynomial.at(1) = 1;
}

// Compute grand product values corresponding only to the active regions of the trace
parallel_for(active_range_thread_data.num_threads, [&](size_t thread_idx) {
const size_t start = active_range_thread_data.start[thread_idx];
const size_t end = active_range_thread_data.end[thread_idx];
for (size_t i = start; i < end; ++i) {
grand_product_polynomial.at(i + 1) = numerator[i] * denominator[i];
auto poly_idx = get_active_range_poly_idx(i + 1);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const size_t

grand_product_polynomial.at(poly_idx) = numerator[i] * denominator[i];
}
});

// Final step: If active/inactive regions have been specified, the value of the grand product in the inactive
// regions have not yet been set. The polynomial takes an already computed constant value across each inactive
// region (since no copy constraints are present there) equal to the value of the grand product at the first index
// of the subsequent active region.
if (active_region_specified) {
MultithreadData full_domain_thread_data = calculate_thread_data(domain_size);
parallel_for(full_domain_thread_data.num_threads, [&](size_t thread_idx) {
const size_t start = full_domain_thread_data.start[thread_idx];
const size_t end = full_domain_thread_data.end[thread_idx];
for (size_t i = start; i < end; ++i) {
for (size_t j = 0; j < active_region_data.ranges.size() - 1; ++j) {
size_t previous_range_end = active_region_data.ranges[j].second;
size_t next_range_start = active_region_data.ranges[j + 1].first;
// If the index falls in an inactive region, set its value
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment seems incomplete

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

haha its not but I do see what you mean. I reordered to make it sound more natural

if (i >= previous_range_end && i < next_range_start) {
grand_product_polynomial.at(i) = grand_product_polynomial[next_range_start];
break;
}
}
}
});
}

DEBUG_LOG_ALL(grand_product_polynomial.coeffs());
}

Expand Down
Loading
Loading