diff --git a/barretenberg/cpp/src/barretenberg/bb/main.cpp b/barretenberg/cpp/src/barretenberg/bb/main.cpp index 5b8bc5041fd..be73ea4e548 100644 --- a/barretenberg/cpp/src/barretenberg/bb/main.cpp +++ b/barretenberg/cpp/src/barretenberg/bb/main.cpp @@ -588,27 +588,27 @@ void vk_as_fields(const std::string& vk_path, const std::string& output_path) * - Filesystem: The proof and vk are written to the paths output_path/proof and output_path/{vk, vk_fields.json} * * @param bytecode_path Path to the file containing the serialised bytecode - * @param calldata_path Path to the file containing the serialised calldata (could be empty) * @param public_inputs_path Path to the file containing the serialised avm public inputs * @param hints_path Path to the file containing the serialised avm circuit hints * @param output_path Path (directory) to write the output proof and verification keys */ -void avm_prove(const std::filesystem::path& calldata_path, - const std::filesystem::path& public_inputs_path, +void avm_prove(const std::filesystem::path& public_inputs_path, const std::filesystem::path& hints_path, const std::filesystem::path& output_path) { - std::vector const calldata = many_from_buffer(read_file(calldata_path)); - auto const avm_new_public_inputs = AvmPublicInputs::from(read_file(public_inputs_path)); + + auto const avm_public_inputs = AvmPublicInputs::from(read_file(public_inputs_path)); auto const avm_hints = bb::avm_trace::ExecutionHints::from(read_file(hints_path)); // Using [0] is fine now for the top-level call, but we might need to index by address in future vinfo("bytecode size: ", avm_hints.all_contract_bytecode[0].bytecode.size()); - vinfo("calldata size: ", calldata.size()); - vinfo("hints.storage_value_hints size: ", avm_hints.storage_value_hints.size()); - vinfo("hints.note_hash_exists_hints size: ", avm_hints.note_hash_exists_hints.size()); - vinfo("hints.nullifier_exists_hints size: ", avm_hints.nullifier_exists_hints.size()); - vinfo("hints.l1_to_l2_message_exists_hints size: ", avm_hints.l1_to_l2_message_exists_hints.size()); + vinfo("hints.storage_read_hints size: ", avm_hints.storage_read_hints.size()); + vinfo("hints.storage_write_hints size: ", avm_hints.storage_write_hints.size()); + vinfo("hints.nullifier_read_hints size: ", avm_hints.nullifier_read_hints.size()); + vinfo("hints.nullifier_write_hints size: ", avm_hints.nullifier_write_hints.size()); + vinfo("hints.note_hash_read_hints size: ", avm_hints.note_hash_read_hints.size()); + vinfo("hints.note_hash_write_hints size: ", avm_hints.note_hash_write_hints.size()); + vinfo("hints.l1_to_l2_message_read_hints size: ", avm_hints.l1_to_l2_message_read_hints.size()); vinfo("hints.externalcall_hints size: ", avm_hints.externalcall_hints.size()); vinfo("hints.contract_instance_hints size: ", avm_hints.contract_instance_hints.size()); vinfo("hints.contract_bytecode_hints size: ", avm_hints.all_contract_bytecode.size()); @@ -618,7 +618,7 @@ void avm_prove(const std::filesystem::path& calldata_path, // Prove execution and return vk auto const [verification_key, proof] = - AVM_TRACK_TIME_V("prove/all", avm_trace::Execution::prove(calldata, avm_new_public_inputs, avm_hints)); + AVM_TRACK_TIME_V("prove/all", avm_trace::Execution::prove(avm_public_inputs, avm_hints)); std::vector vk_as_fields = verification_key.to_field_elements(); @@ -1243,7 +1243,6 @@ int main(int argc, char* argv[]) write_recursion_inputs_honk(bytecode_path, witness_path, output_path, recursive); #ifndef DISABLE_AZTEC_VM } else if (command == "avm_prove") { - std::filesystem::path avm_calldata_path = get_option(args, "--avm-calldata", "./target/avm_calldata.bin"); std::filesystem::path avm_public_inputs_path = get_option(args, "--avm-public-inputs", "./target/avm_public_inputs.bin"); std::filesystem::path avm_hints_path = get_option(args, "--avm-hints", "./target/avm_hints.bin"); @@ -1251,7 +1250,7 @@ int main(int argc, char* argv[]) std::filesystem::path output_path = get_option(args, "-o", "./proofs"); extern std::filesystem::path avm_dump_trace_path; avm_dump_trace_path = get_option(args, "--avm-dump-trace", ""); - avm_prove(avm_calldata_path, avm_public_inputs_path, avm_hints_path, output_path); + avm_prove(avm_public_inputs_path, avm_hints_path, output_path); } else if (command == "avm_verify") { return avm_verify(proof_path, vk_path) ? 0 : 1; #endif diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/tests/execution.test.cpp b/barretenberg/cpp/src/barretenberg/vm/avm/tests/execution.test.cpp index a5d6a718246..433fafbc7b6 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/tests/execution.test.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/tests/execution.test.cpp @@ -74,7 +74,7 @@ class AvmExecutionTests : public ::testing::Test { * @param bytecode * @return The trace as a vector of Row. */ - std::vector gen_trace_from_bytecode(const std::vector& bytecode) + std::vector gen_trace_from_bytecode(const std::vector& bytecode) const { std::vector calldata{}; std::vector returndata{}; @@ -89,9 +89,9 @@ class AvmExecutionTests : public ::testing::Test { static std::vector gen_trace(const std::vector& bytecode, const std::vector& calldata, - AvmPublicInputs& public_inputs, + AvmPublicInputs public_inputs, std::vector& returndata, - ExecutionHints& execution_hints) + ExecutionHints execution_hints) { auto [contract_class_id, contract_instance] = gen_test_contract_hint(bytecode); execution_hints.with_avm_contract_bytecode( @@ -99,7 +99,11 @@ class AvmExecutionTests : public ::testing::Test { // These are magic values because of how some tests work! Don't change them public_inputs.public_app_logic_call_requests[0].contract_address = contract_instance.address; - return Execution::gen_trace(calldata, public_inputs, returndata, execution_hints); + execution_hints.enqueued_call_hints.push_back({ + .contract_address = contract_instance.address, + .calldata = calldata, + }); + return Execution::gen_trace(public_inputs, returndata, execution_hints); } static std::tuple gen_test_contract_hint( diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/errors.hpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/errors.hpp index e31d486e502..ca121ebefa2 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/errors.hpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/errors.hpp @@ -6,6 +6,7 @@ namespace bb::avm_trace { enum class AvmError : uint32_t { NO_ERROR, + REVERT_OPCODE, INVALID_PROGRAM_COUNTER, INVALID_OPCODE, INVALID_TAG_VALUE, @@ -18,6 +19,7 @@ enum class AvmError : uint32_t { CONTRACT_INST_MEM_UNKNOWN, RADIX_OUT_OF_BOUNDS, DUPLICATE_NULLIFIER, + SIDE_EFFECT_LIMIT_REACHED, }; } // namespace bb::avm_trace diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/execution.cpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/execution.cpp index cab6b49ddb5..edb50a4e9cf 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/execution.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/execution.cpp @@ -38,8 +38,39 @@ using namespace bb; std::filesystem::path avm_dump_trace_path; namespace bb::avm_trace { + +std::string to_name(TxExecutionPhase phase) +{ + switch (phase) { + case TxExecutionPhase::SETUP: + return "SETUP"; + case TxExecutionPhase::APP_LOGIC: + return "APP_LOGIC"; + case TxExecutionPhase::TEARDOWN: + return "TEARDOWN"; + default: + throw std::runtime_error("Invalid tx phase"); + break; + } +} + +/************************************************************************************************** + * HELPERS IN ANONYMOUS NAMESPACE + **************************************************************************************************/ namespace { +template +std::vector non_empty_call_requests(std::array call_requests_array) +{ + std::vector call_requests_vec; + for (const auto& call_request : call_requests_array) { + if (!call_request.is_empty()) { + call_requests_vec.push_back(call_request); + } + } + return call_requests_vec; +} + // The SRS needs to be able to accommodate the circuit subgroup size. // Note: The *2 is due to how init_bn254_crs works, look there. static_assert(Execution::SRS_SIZE >= AvmCircuitBuilder::CIRCUIT_SUBGROUP_SIZE * 2); @@ -147,13 +178,16 @@ void show_trace_info(const auto& trace) } // namespace +/************************************************************************************************** + * Execution + **************************************************************************************************/ + // Needed for dependency injection in tests. Execution::TraceBuilderConstructor Execution::trace_builder_constructor = [](AvmPublicInputs public_inputs, ExecutionHints execution_hints, uint32_t side_effect_counter, std::vector calldata) { - return AvmTraceBuilder( - std::move(public_inputs), std::move(execution_hints), side_effect_counter, std::move(calldata)); + return AvmTraceBuilder(public_inputs, std::move(execution_hints), side_effect_counter, std::move(calldata)); }; /** @@ -173,17 +207,19 @@ std::vector Execution::getDefaultPublicInputs() * of the execution of the supplied bytecode. * * @param bytecode A vector of bytes representing the bytecode to execute. - * @param calldata expressed as a vector of finite field elements. * @throws runtime_error exception when the bytecode is invalid. * @return The verifier key and zk proof of the execution. */ -std::tuple Execution::prove(std::vector const& calldata, - AvmPublicInputs const& public_inputs, +std::tuple Execution::prove(AvmPublicInputs const& public_inputs, ExecutionHints const& execution_hints) { std::vector returndata; - std::vector trace = - AVM_TRACK_TIME_V("prove/gen_trace", gen_trace(calldata, public_inputs, returndata, execution_hints)); + std::vector calldata; + for (const auto& enqueued_call_hints : execution_hints.enqueued_call_hints) { + calldata.insert(calldata.end(), enqueued_call_hints.calldata.begin(), enqueued_call_hints.calldata.end()); + } + std::vector trace = AVM_TRACK_TIME_V( + "prove/gen_trace", gen_trace(public_inputs, returndata, execution_hints, /*apply_end_gas_assertions=*/true)); if (!avm_dump_trace_path.empty()) { info("Dumping trace as CSV to: " + avm_dump_trace_path.string()); dump_trace_as_csv(trace, avm_dump_trace_path); @@ -258,585 +294,637 @@ bool Execution::verify(AvmFlavor::VerificationKey vk, HonkProof const& proof) /** * @brief Generate the execution trace pertaining to the supplied instructions returns the return data. * - * @param instructions A vector of the instructions to be executed. - * @param calldata expressed as a vector of finite field elements. - * @param public_inputs expressed as a vector of finite field elements. + * @param public_inputs - to constrain execution inputs & results against + * @param returndata - to add to for each enqueued call + * @param execution_hints - to inform execution + * @param apply_end_gas_assertions - should we apply assertions that public input's end gas is right? * @return The trace as a vector of Row. */ -std::vector Execution::gen_trace(std::vector const& calldata, - AvmPublicInputs const& public_inputs, +std::vector Execution::gen_trace(AvmPublicInputs const& public_inputs, std::vector& returndata, - ExecutionHints const& execution_hints) + ExecutionHints const& execution_hints, + bool apply_end_gas_assertions) { vinfo("------- GENERATING TRACE -------"); // TODO(https://github.com/AztecProtocol/aztec-packages/issues/6718): construction of the public input columns // should be done in the kernel - this is stubbed and underconstrained // VmPublicInputs public_inputs = avm_trace::convert_public_inputs(public_inputs_vec); - uint32_t start_side_effect_counter = - 0; // What to do here??? - // !public_inputs_vec.empty() ? - // static_cast(public_inputs_vec[START_SIDE_EFFECT_COUNTER_PCPI_OFFSET]) - // : 0; - // + uint32_t start_side_effect_counter = 0; + // Temporary until we get proper nested call handling + std::vector calldata; + for (const auto& enqueued_call_hints : execution_hints.enqueued_call_hints) { + calldata.insert(calldata.end(), enqueued_call_hints.calldata.begin(), enqueued_call_hints.calldata.end()); + } AvmTraceBuilder trace_builder = Execution::trace_builder_constructor(public_inputs, execution_hints, start_side_effect_counter, calldata); - std::vector public_call_requests; - for (const auto& setup_requests : public_inputs.public_setup_call_requests) { - if (setup_requests.contract_address != 0) { - public_call_requests.push_back(setup_requests); - } - } - for (const auto& app_requests : public_inputs.public_app_logic_call_requests) { - if (app_requests.contract_address != 0) { - public_call_requests.push_back(app_requests); - } + const auto setup_call_requests = non_empty_call_requests(public_inputs.public_setup_call_requests); + const auto app_logic_call_requests = non_empty_call_requests(public_inputs.public_app_logic_call_requests); + std::vector teardown_call_requests; + if (!public_inputs.public_teardown_call_request.is_empty()) { + // teardown is always one call request + teardown_call_requests.push_back(public_inputs.public_teardown_call_request); } - // We should not need to guard teardown, but while we are testing with handcrafted txs we do - if (public_inputs.public_teardown_call_request.contract_address != 0) { - public_call_requests.push_back(public_inputs.public_teardown_call_request); - } - - // We should use the public input address, but for now we just take the first element in the list - // const std::vector& bytecode = execution_hints.all_contract_bytecode.at(0).bytecode; // Loop over all the public call requests uint8_t call_ctx = 0; - for (const auto& public_call_request : public_call_requests) { - trace_builder.set_public_call_request(public_call_request); - trace_builder.set_call_ptr(call_ctx++); - - // Find the bytecode based on contract address of the public call request - const std::vector& bytecode = - std::ranges::find_if(execution_hints.all_contract_bytecode, [public_call_request](const auto& contract) { - return contract.contract_instance.address == public_call_request.contract_address; - })->bytecode; - info("Found bytecode for contract address: ", public_call_request.contract_address); - - // Set this also on nested call - - // Copied version of pc maintained in trace builder. The value of pc is evolving based - // on opcode logic and therefore is not maintained here. However, the next opcode in the execution - // is determined by this value which require read access to the code below. - uint32_t pc = 0; - uint32_t counter = 0; - AvmError error = AvmError::NO_ERROR; - while (is_ok(error) && (pc = trace_builder.get_pc()) < bytecode.size()) { - auto [inst, parse_error] = Deserialization::parse(bytecode, pc); - error = parse_error; - - if (!is_ok(error)) { + const auto phases = { TxExecutionPhase::SETUP, TxExecutionPhase::APP_LOGIC, TxExecutionPhase::TEARDOWN }; + for (auto phase : phases) { + const auto public_call_requests = phase == TxExecutionPhase::SETUP ? setup_call_requests + : phase == TxExecutionPhase::APP_LOGIC ? app_logic_call_requests + : teardown_call_requests; + + // When we get this, it means we have done our non-revertible setup phase + if (phase == TxExecutionPhase::SETUP) { + vinfo("Inserting non-revertible side effects from private before SETUP phase. Checkpointing trees."); + // Temporary spot for private non-revertible insertion + std::vector siloed_nullifiers; + siloed_nullifiers.insert( + siloed_nullifiers.end(), + public_inputs.previous_non_revertible_accumulated_data.nullifiers.begin(), + public_inputs.previous_non_revertible_accumulated_data.nullifiers.begin() + + public_inputs.previous_non_revertible_accumulated_data_array_lengths.nullifiers); + trace_builder.insert_private_state(siloed_nullifiers, {}); + trace_builder.checkpoint_non_revertible_state(); + } else if (phase == TxExecutionPhase::APP_LOGIC) { + vinfo("Inserting revertible side effects from private before APP_LOGIC phase"); + // Temporary spot for private revertible insertion + std::vector siloed_nullifiers; + siloed_nullifiers.insert(siloed_nullifiers.end(), + public_inputs.previous_revertible_accumulated_data.nullifiers.begin(), + public_inputs.previous_revertible_accumulated_data.nullifiers.begin() + + public_inputs.previous_revertible_accumulated_data_array_lengths.nullifiers); + trace_builder.insert_private_state(siloed_nullifiers, {}); + } + + vinfo("Beginning execution of phase ", to_name(phase), " (", public_call_requests.size(), " enqueued calls)."); + AvmError phase_error = AvmError::NO_ERROR; + for (auto public_call_request : public_call_requests) { + trace_builder.set_public_call_request(public_call_request); + trace_builder.set_call_ptr(call_ctx++); + // Execute! + phase_error = Execution::execute_enqueued_call(trace_builder, public_call_request, returndata); + + if (!is_ok(phase_error)) { + info("Phase ", to_name(phase), " reverted."); + // otherwise, reverting in a revertible phase rolls back state + vinfo("Rolling back tree roots to non-revertible checkpoint"); + trace_builder.rollback_to_non_revertible_checkpoint(); break; } + } - debug("[PC:" + std::to_string(pc) + "] [IC:" + std::to_string(counter++) + "] " + inst.to_string() + - " (gasLeft l2=" + std::to_string(trace_builder.get_l2_gas_left()) + ")"); - - switch (inst.op_code) { - // Compute - // Compute - Arithmetic - case OpCode::ADD_8: - error = trace_builder.op_add(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::ADD_8); - break; - case OpCode::ADD_16: - error = trace_builder.op_add(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::ADD_16); - break; - case OpCode::SUB_8: - error = trace_builder.op_sub(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::SUB_8); - break; - case OpCode::SUB_16: - error = trace_builder.op_sub(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::SUB_16); - break; - case OpCode::MUL_8: - error = trace_builder.op_mul(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::MUL_8); - break; - case OpCode::MUL_16: - error = trace_builder.op_mul(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::MUL_16); - break; - case OpCode::DIV_8: - error = trace_builder.op_div(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::DIV_8); - break; - case OpCode::DIV_16: - error = trace_builder.op_div(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::DIV_16); - break; - case OpCode::FDIV_8: - error = trace_builder.op_fdiv(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::FDIV_8); - break; - case OpCode::FDIV_16: - error = trace_builder.op_fdiv(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::FDIV_16); - break; - case OpCode::EQ_8: - error = trace_builder.op_eq(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::EQ_8); - break; - case OpCode::EQ_16: - error = trace_builder.op_eq(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::EQ_16); - break; - case OpCode::LT_8: - error = trace_builder.op_lt(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::LT_8); - break; - case OpCode::LT_16: - error = trace_builder.op_lt(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::LT_16); - break; - case OpCode::LTE_8: - error = trace_builder.op_lte(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::LTE_8); - break; - case OpCode::LTE_16: - error = trace_builder.op_lte(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::LTE_16); - break; - case OpCode::AND_8: - error = trace_builder.op_and(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::AND_8); - break; - case OpCode::AND_16: - error = trace_builder.op_and(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::AND_16); - break; - case OpCode::OR_8: - error = trace_builder.op_or(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::OR_8); - break; - case OpCode::OR_16: - error = trace_builder.op_or(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::OR_16); - break; - case OpCode::XOR_8: - error = trace_builder.op_xor(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::XOR_8); - break; - case OpCode::XOR_16: - error = trace_builder.op_xor(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::XOR_16); - break; - case OpCode::NOT_8: - error = trace_builder.op_not(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - OpCode::NOT_8); - break; - case OpCode::NOT_16: - error = trace_builder.op_not(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - OpCode::NOT_16); - break; - case OpCode::SHL_8: - error = trace_builder.op_shl(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::SHL_8); - break; - case OpCode::SHL_16: - error = trace_builder.op_shl(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::SHL_16); - break; - case OpCode::SHR_8: - error = trace_builder.op_shr(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::SHR_8); - break; - case OpCode::SHR_16: - error = trace_builder.op_shr(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::SHR_16); - break; + if (!is_ok(phase_error) && phase == TxExecutionPhase::SETUP) { + // Stop processing phases. Halt TX. + info("A revert during SETUP phase halts the entire TX"); + break; + } + } + auto trace = trace_builder.finalize(apply_end_gas_assertions); - // Compute - Type Conversions - case OpCode::CAST_8: - error = trace_builder.op_cast(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::CAST_8); - break; - case OpCode::CAST_16: - error = trace_builder.op_cast(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - OpCode::CAST_16); - break; + show_trace_info(trace); + return trace; +} - // Execution Environment - // TODO(https://github.com/AztecProtocol/aztec-packages/issues/6284): support indirect for below - case OpCode::GETENVVAR_16: - error = trace_builder.op_get_env_var(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2))); - break; +/** + * @brief Execute one enqueued call, adding its results to the trace. + * + * @param trace_builder - the trace builder to add rows to + * @param public_call_request - the enqueued call to execute + * @param returndata - to add to for each enqueued call + * @returns the error/result of the enqueued call + * + */ +AvmError Execution::execute_enqueued_call(AvmTraceBuilder& trace_builder, + PublicCallRequest& public_call_request, + std::vector& returndata) +{ + AvmError error = AvmError::NO_ERROR; + // Find the bytecode based on contract address of the public call request + std::vector bytecode = trace_builder.get_bytecode(public_call_request.contract_address); - // Execution Environment - Calldata - case OpCode::CALLDATACOPY: - error = trace_builder.op_calldata_copy(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3))); - break; + // Set this also on nested call - case OpCode::RETURNDATASIZE: - error = trace_builder.op_returndata_size(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1))); - break; + // Copied version of pc maintained in trace builder. The value of pc is evolving based + // on opcode logic and therefore is not maintained here. However, the next opcode in the execution + // is determined by this value which require read access to the code below. + uint32_t pc = 0; + uint32_t counter = 0; + while (is_ok(error) && (pc = trace_builder.get_pc()) < bytecode.size()) { + auto [inst, parse_error] = Deserialization::parse(bytecode, pc); + error = parse_error; + + if (!is_ok(error)) { + break; + } - case OpCode::RETURNDATACOPY: - error = trace_builder.op_returndata_copy(std::get(inst.operands.at(0)), + debug("[PC:" + std::to_string(pc) + "] [IC:" + std::to_string(counter++) + "] " + inst.to_string() + + " (gasLeft l2=" + std::to_string(trace_builder.get_l2_gas_left()) + ")"); + + switch (inst.op_code) { + // Compute + // Compute - Arithmetic + case OpCode::ADD_8: + error = trace_builder.op_add(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::ADD_8); + break; + case OpCode::ADD_16: + error = trace_builder.op_add(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::ADD_16); + break; + case OpCode::SUB_8: + error = trace_builder.op_sub(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::SUB_8); + break; + case OpCode::SUB_16: + error = trace_builder.op_sub(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::SUB_16); + break; + case OpCode::MUL_8: + error = trace_builder.op_mul(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::MUL_8); + break; + case OpCode::MUL_16: + error = trace_builder.op_mul(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::MUL_16); + break; + case OpCode::DIV_8: + error = trace_builder.op_div(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::DIV_8); + break; + case OpCode::DIV_16: + error = trace_builder.op_div(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::DIV_16); + break; + case OpCode::FDIV_8: + error = trace_builder.op_fdiv(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::FDIV_8); + break; + case OpCode::FDIV_16: + error = trace_builder.op_fdiv(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::FDIV_16); + break; + case OpCode::EQ_8: + error = trace_builder.op_eq(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::EQ_8); + break; + case OpCode::EQ_16: + error = trace_builder.op_eq(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::EQ_16); + break; + case OpCode::LT_8: + error = trace_builder.op_lt(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::LT_8); + break; + case OpCode::LT_16: + error = trace_builder.op_lt(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::LT_16); + break; + case OpCode::LTE_8: + error = trace_builder.op_lte(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::LTE_8); + break; + case OpCode::LTE_16: + error = trace_builder.op_lte(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::LTE_16); + break; + case OpCode::AND_8: + error = trace_builder.op_and(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::AND_8); + break; + case OpCode::AND_16: + error = trace_builder.op_and(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::AND_16); + break; + case OpCode::OR_8: + error = trace_builder.op_or(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::OR_8); + break; + case OpCode::OR_16: + error = trace_builder.op_or(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::OR_16); + break; + case OpCode::XOR_8: + error = trace_builder.op_xor(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::XOR_8); + break; + case OpCode::XOR_16: + error = trace_builder.op_xor(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::XOR_16); + break; + case OpCode::NOT_8: + error = trace_builder.op_not(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + OpCode::NOT_8); + break; + case OpCode::NOT_16: + error = trace_builder.op_not(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + OpCode::NOT_16); + break; + case OpCode::SHL_8: + error = trace_builder.op_shl(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::SHL_8); + break; + case OpCode::SHL_16: + error = trace_builder.op_shl(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::SHL_16); + break; + case OpCode::SHR_8: + error = trace_builder.op_shr(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::SHR_8); + break; + case OpCode::SHR_16: + error = trace_builder.op_shr(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::SHR_16); + break; + + // Compute - Type Conversions + case OpCode::CAST_8: + error = trace_builder.op_cast(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::CAST_8); + break; + case OpCode::CAST_16: + error = trace_builder.op_cast(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + OpCode::CAST_16); + break; + + // Execution Environment + // TODO(https://github.com/AztecProtocol/aztec-packages/issues/6284): support indirect for below + case OpCode::GETENVVAR_16: + error = trace_builder.op_get_env_var(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2))); + break; + + // Execution Environment - Calldata + case OpCode::CALLDATACOPY: + error = trace_builder.op_calldata_copy(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3))); + break; + + case OpCode::RETURNDATASIZE: + error = trace_builder.op_returndata_size(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1))); + break; + + case OpCode::RETURNDATACOPY: + error = trace_builder.op_returndata_copy(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3))); + break; + + // Machine State - Internal Control Flow + case OpCode::JUMP_32: + error = trace_builder.op_jump(std::get(inst.operands.at(0))); + break; + case OpCode::JUMPI_32: + error = trace_builder.op_jumpi(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2))); + break; + case OpCode::INTERNALCALL: + error = trace_builder.op_internal_call(std::get(inst.operands.at(0))); + break; + case OpCode::INTERNALRETURN: + error = trace_builder.op_internal_return(); + break; + + // Machine State - Memory + case OpCode::SET_8: { + error = trace_builder.op_set(std::get(inst.operands.at(0)), + std::get(inst.operands.at(3)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + OpCode::SET_8); + break; + } + case OpCode::SET_16: { + error = trace_builder.op_set(std::get(inst.operands.at(0)), + std::get(inst.operands.at(3)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + OpCode::SET_16); + break; + } + case OpCode::SET_32: { + error = trace_builder.op_set(std::get(inst.operands.at(0)), + std::get(inst.operands.at(3)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + OpCode::SET_32); + break; + } + case OpCode::SET_64: { + error = trace_builder.op_set(std::get(inst.operands.at(0)), + std::get(inst.operands.at(3)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + OpCode::SET_64); + break; + } + case OpCode::SET_128: { + error = trace_builder.op_set(std::get(inst.operands.at(0)), + uint256_t::from_uint128(std::get(inst.operands.at(3))), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + OpCode::SET_128); + break; + } + case OpCode::SET_FF: { + error = trace_builder.op_set(std::get(inst.operands.at(0)), + std::get(inst.operands.at(3)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + OpCode::SET_FF); + break; + } + case OpCode::MOV_8: + error = trace_builder.op_mov(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + OpCode::MOV_8); + break; + case OpCode::MOV_16: + error = trace_builder.op_mov(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + OpCode::MOV_16); + break; + + // World State + case OpCode::SLOAD: + error = trace_builder.op_sload(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2))); + break; + case OpCode::SSTORE: + error = trace_builder.op_sstore(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2))); + break; + case OpCode::NOTEHASHEXISTS: + error = trace_builder.op_note_hash_exists(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3))); + break; + case OpCode::EMITNOTEHASH: + error = trace_builder.op_emit_note_hash(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1))); + break; + case OpCode::NULLIFIEREXISTS: + error = trace_builder.op_nullifier_exists(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3))); + break; + case OpCode::EMITNULLIFIER: + error = trace_builder.op_emit_nullifier(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1))); + break; + + case OpCode::L1TOL2MSGEXISTS: + error = trace_builder.op_l1_to_l2_msg_exists(std::get(inst.operands.at(0)), std::get(inst.operands.at(1)), std::get(inst.operands.at(2)), std::get(inst.operands.at(3))); - break; - - // Machine State - Internal Control Flow - case OpCode::JUMP_32: - error = trace_builder.op_jump(std::get(inst.operands.at(0))); - break; - case OpCode::JUMPI_32: - error = trace_builder.op_jumpi(std::get(inst.operands.at(0)), + break; + case OpCode::GETCONTRACTINSTANCE: + error = trace_builder.op_get_contract_instance(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + std::get(inst.operands.at(4))); + break; + + // Accrued Substate + case OpCode::EMITUNENCRYPTEDLOG: + error = trace_builder.op_emit_unencrypted_log(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2))); + break; + case OpCode::SENDL2TOL1MSG: + error = trace_builder.op_emit_l2_to_l1_msg(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2))); + break; + + // Control Flow - Contract Calls + case OpCode::CALL: + error = trace_builder.op_call(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + std::get(inst.operands.at(4)), + std::get(inst.operands.at(5))); + break; + case OpCode::STATICCALL: + error = trace_builder.op_static_call(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + std::get(inst.operands.at(4)), + std::get(inst.operands.at(5))); + break; + case OpCode::RETURN: { + auto ret = trace_builder.op_return(std::get(inst.operands.at(0)), std::get(inst.operands.at(1)), - std::get(inst.operands.at(2))); - break; - case OpCode::INTERNALCALL: - error = trace_builder.op_internal_call(std::get(inst.operands.at(0))); - break; - case OpCode::INTERNALRETURN: - error = trace_builder.op_internal_return(); - break; - - // Machine State - Memory - case OpCode::SET_8: { - error = trace_builder.op_set(std::get(inst.operands.at(0)), - std::get(inst.operands.at(3)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - OpCode::SET_8); - break; - } - case OpCode::SET_16: { - error = trace_builder.op_set(std::get(inst.operands.at(0)), - std::get(inst.operands.at(3)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - OpCode::SET_16); - break; - } - case OpCode::SET_32: { - error = trace_builder.op_set(std::get(inst.operands.at(0)), - std::get(inst.operands.at(3)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - OpCode::SET_32); - break; - } - case OpCode::SET_64: { - error = trace_builder.op_set(std::get(inst.operands.at(0)), - std::get(inst.operands.at(3)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - OpCode::SET_64); - break; - } - case OpCode::SET_128: { - error = trace_builder.op_set(std::get(inst.operands.at(0)), - uint256_t::from_uint128(std::get(inst.operands.at(3))), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - OpCode::SET_128); - break; - } - case OpCode::SET_FF: { - error = trace_builder.op_set(std::get(inst.operands.at(0)), - std::get(inst.operands.at(3)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - OpCode::SET_FF); - break; - } - case OpCode::MOV_8: - error = trace_builder.op_mov(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - OpCode::MOV_8); - break; - case OpCode::MOV_16: - error = trace_builder.op_mov(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - OpCode::MOV_16); - break; + std::get(inst.operands.at(2))); + error = ret.error; + returndata.insert(returndata.end(), ret.return_data.begin(), ret.return_data.end()); - // World State - case OpCode::SLOAD: - error = trace_builder.op_sload(std::get(inst.operands.at(0)), + break; + } + case OpCode::REVERT_8: { + info("HIT REVERT_8 ", "[PC=" + std::to_string(pc) + "] " + inst.to_string()); + auto ret = trace_builder.op_revert(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2))); + error = ret.error; + returndata.insert(returndata.end(), ret.return_data.begin(), ret.return_data.end()); + + break; + } + case OpCode::REVERT_16: { + info("HIT REVERT_16 ", "[PC=" + std::to_string(pc) + "] " + inst.to_string()); + auto ret = trace_builder.op_revert(std::get(inst.operands.at(0)), std::get(inst.operands.at(1)), std::get(inst.operands.at(2))); - break; - case OpCode::SSTORE: - error = trace_builder.op_sstore(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2))); - break; - case OpCode::NOTEHASHEXISTS: - error = trace_builder.op_note_hash_exists(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3))); - break; - case OpCode::EMITNOTEHASH: - error = trace_builder.op_emit_note_hash(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1))); - break; - case OpCode::NULLIFIEREXISTS: - error = trace_builder.op_nullifier_exists(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3))); - break; - case OpCode::EMITNULLIFIER: - error = trace_builder.op_emit_nullifier(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1))); - break; + error = ret.error; + returndata.insert(returndata.end(), ret.return_data.begin(), ret.return_data.end()); - case OpCode::L1TOL2MSGEXISTS: - error = trace_builder.op_l1_to_l2_msg_exists(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3))); - break; - case OpCode::GETCONTRACTINSTANCE: - error = trace_builder.op_get_contract_instance(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - std::get(inst.operands.at(4))); - break; + break; + } - // Accrued Substate - case OpCode::EMITUNENCRYPTEDLOG: - error = trace_builder.op_emit_unencrypted_log(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2))); - break; - case OpCode::SENDL2TOL1MSG: - error = trace_builder.op_emit_l2_to_l1_msg(std::get(inst.operands.at(0)), + // Misc + case OpCode::DEBUGLOG: + error = trace_builder.op_debug_log(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + std::get(inst.operands.at(4))); + break; + + // Gadgets + case OpCode::POSEIDON2PERM: + error = trace_builder.op_poseidon2_permutation(std::get(inst.operands.at(0)), std::get(inst.operands.at(1)), std::get(inst.operands.at(2))); - break; - // Control Flow - Contract Calls - case OpCode::CALL: - error = trace_builder.op_call(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - std::get(inst.operands.at(4)), - std::get(inst.operands.at(5))); - break; - case OpCode::STATICCALL: - error = trace_builder.op_static_call(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - std::get(inst.operands.at(4)), - std::get(inst.operands.at(5))); - break; - case OpCode::RETURN: { - auto ret = trace_builder.op_return(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2))); - error = ret.error; - returndata.insert(returndata.end(), ret.return_data.begin(), ret.return_data.end()); - - break; - } - case OpCode::REVERT_8: { - info("HIT REVERT_8 ", "[PC=" + std::to_string(pc) + "] " + inst.to_string()); - auto ret = trace_builder.op_revert(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2))); - error = ret.error; - returndata.insert(returndata.end(), ret.return_data.begin(), ret.return_data.end()); - - break; - } - case OpCode::REVERT_16: { - info("HIT REVERT_16 ", "[PC=" + std::to_string(pc) + "] " + inst.to_string()); - auto ret = trace_builder.op_revert(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2))); - error = ret.error; - returndata.insert(returndata.end(), ret.return_data.begin(), ret.return_data.end()); - - break; - } - - // Misc - case OpCode::DEBUGLOG: - error = trace_builder.op_debug_log(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - std::get(inst.operands.at(4))); - break; - - // Gadgets - case OpCode::POSEIDON2PERM: - error = trace_builder.op_poseidon2_permutation(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2))); - - break; - - case OpCode::SHA256COMPRESSION: - error = trace_builder.op_sha256_compression(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3))); - break; - - case OpCode::KECCAKF1600: - error = trace_builder.op_keccakf1600(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2))); + break; - break; + case OpCode::SHA256COMPRESSION: + error = trace_builder.op_sha256_compression(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3))); + break; - case OpCode::ECADD: - error = trace_builder.op_ec_add(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - std::get(inst.operands.at(4)), - std::get(inst.operands.at(5)), - std::get(inst.operands.at(6)), - std::get(inst.operands.at(7))); - break; - case OpCode::MSM: - error = trace_builder.op_variable_msm(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - std::get(inst.operands.at(4))); - break; + case OpCode::KECCAKF1600: + error = trace_builder.op_keccakf1600(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2))); - // Conversions - case OpCode::TORADIXBE: - error = trace_builder.op_to_radix_be(std::get(inst.operands.at(0)), - std::get(inst.operands.at(1)), - std::get(inst.operands.at(2)), - std::get(inst.operands.at(3)), - std::get(inst.operands.at(4)), - std::get(inst.operands.at(5))); - break; + break; - default: - throw_or_abort("Don't know how to execute opcode " + to_hex(inst.op_code) + " at pc " + - std::to_string(pc) + "."); - break; - } - } - - if (!is_ok(error)) { - info("AVM stopped due to exceptional halting condition. Error: ", - to_name(error), - " at PC: ", - pc, - " IC: ", - counter - 1); // Need adjustement as counter increment occurs in loop body + case OpCode::ECADD: + error = trace_builder.op_ec_add(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + std::get(inst.operands.at(4)), + std::get(inst.operands.at(5)), + std::get(inst.operands.at(6)), + std::get(inst.operands.at(7))); + break; + case OpCode::MSM: + error = trace_builder.op_variable_msm(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + std::get(inst.operands.at(4))); + break; + + // Conversions + case OpCode::TORADIXBE: + error = trace_builder.op_to_radix_be(std::get(inst.operands.at(0)), + std::get(inst.operands.at(1)), + std::get(inst.operands.at(2)), + std::get(inst.operands.at(3)), + std::get(inst.operands.at(4)), + std::get(inst.operands.at(5))); + break; + + default: + throw_or_abort("Don't know how to execute opcode " + to_hex(inst.op_code) + " at pc " + std::to_string(pc) + + "."); + break; } } - auto trace = trace_builder.finalize(); - - show_trace_info(trace); - return trace; + if (!is_ok(error)) { + auto const error_ic = counter - 1; // Need adjustement as counter increment occurs in loop body + std::string reason_prefix = exceptionally_halted(error) ? "exceptional halt" : "REVERT opcode"; + info("AVM enqueued call halted due to ", + reason_prefix, + ". Error: ", + to_name(error), + " at PC: ", + pc, + " IC: ", + error_ic); + } + return error; } } // namespace bb::avm_trace diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/execution.hpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/execution.hpp index a9f3ad6d695..84003f38ecc 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/execution.hpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/execution.hpp @@ -13,6 +13,14 @@ namespace bb::avm_trace { +enum class TxExecutionPhase : uint32_t { + SETUP, + APP_LOGIC, + TEARDOWN, +}; + +std::string to_name(TxExecutionPhase phase); + class Execution { public: static constexpr size_t SRS_SIZE = 1 << 22; @@ -29,10 +37,14 @@ class Execution { // Bytecode is currently the bytecode of the top-level function call // Eventually this will be the bytecode of the dispatch function of top-level contract - static std::vector gen_trace(std::vector const& calldata, - AvmPublicInputs const& new_public_inputs, + static std::vector gen_trace(AvmPublicInputs const& public_inputs, std::vector& returndata, - ExecutionHints const& execution_hints); + ExecutionHints const& execution_hints, + bool apply_end_gas_assertions = false); + + static AvmError execute_enqueued_call(AvmTraceBuilder& trace_builder, + PublicCallRequest& public_call_request, + std::vector& returndata); // For testing purposes only. static void set_trace_builder_constructor(TraceBuilderConstructor constructor) @@ -41,9 +53,7 @@ class Execution { } static std::tuple prove( - std::vector const& calldata = {}, - AvmPublicInputs const& public_inputs = AvmPublicInputs(), - ExecutionHints const& execution_hints = {}); + AvmPublicInputs const& public_inputs = AvmPublicInputs(), ExecutionHints const& execution_hints = {}); static bool verify(AvmFlavor::VerificationKey vk, HonkProof const& proof); private: diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/execution_hints.hpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/execution_hints.hpp index c2905791119..893bc53eba7 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/execution_hints.hpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/execution_hints.hpp @@ -217,7 +217,20 @@ inline void read(uint8_t const*& it, AvmContractBytecode& bytecode) read(it, bytecode.contract_class_id_preimage); } +struct AvmEnqueuedCallHint { + FF contract_address; + std::vector calldata; +}; + +inline void read(uint8_t const*& it, AvmEnqueuedCallHint& hint) +{ + using serialize::read; + read(it, hint.contract_address); + read(it, hint.calldata); +} + struct ExecutionHints { + std::vector enqueued_call_hints; std::vector> storage_value_hints; std::vector> note_hash_exists_hints; std::vector> nullifier_exists_hints; @@ -309,6 +322,9 @@ struct ExecutionHints { using serialize::read; const auto* it = data.data(); + std::vector enqueued_call_hints; + read(it, enqueued_call_hints); + read(it, storage_value_hints); read(it, note_hash_exists_hints); read(it, nullifier_exists_hints); @@ -353,19 +369,28 @@ struct ExecutionHints { " bytes out of " + std::to_string(data.size()) + " bytes"); } - return { std::move(storage_value_hints), std::move(note_hash_exists_hints), - std::move(nullifier_exists_hints), std::move(l1_to_l2_message_exists_hints), - std::move(externalcall_hints), std::move(contract_instance_hints), - std::move(all_contract_bytecode), std::move(storage_read_hints), - std::move(storage_write_hints), std::move(nullifier_read_hints), - std::move(nullifier_write_hints), std::move(note_hash_read_hints), - std::move(note_hash_write_hints), std::move(l1_to_l2_message_read_hints) + return { std::move(enqueued_call_hints), + std::move(storage_value_hints), + std::move(note_hash_exists_hints), + std::move(nullifier_exists_hints), + std::move(l1_to_l2_message_exists_hints), + std::move(externalcall_hints), + std::move(contract_instance_hints), + std::move(all_contract_bytecode), + std::move(storage_read_hints), + std::move(storage_write_hints), + std::move(nullifier_read_hints), + std::move(nullifier_write_hints), + std::move(note_hash_read_hints), + std::move(note_hash_write_hints), + std::move(l1_to_l2_message_read_hints) }; } private: - ExecutionHints(std::vector> storage_value_hints, + ExecutionHints(std::vector enqueued_call_hints, + std::vector> storage_value_hints, std::vector> note_hash_exists_hints, std::vector> nullifier_exists_hints, std::vector> l1_to_l2_message_exists_hints, @@ -380,7 +405,8 @@ struct ExecutionHints { std::vector note_hash_write_hints, std::vector l1_to_l2_message_read_hints) - : storage_value_hints(std::move(storage_value_hints)) + : enqueued_call_hints(std::move(enqueued_call_hints)) + , storage_value_hints(std::move(storage_value_hints)) , note_hash_exists_hints(std::move(note_hash_exists_hints)) , nullifier_exists_hints(std::move(nullifier_exists_hints)) , l1_to_l2_message_exists_hints(std::move(l1_to_l2_message_exists_hints)) diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/gadgets/merkle_tree.cpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/gadgets/merkle_tree.cpp index 32a1e7eec2b..87e46bf7128 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/gadgets/merkle_tree.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/gadgets/merkle_tree.cpp @@ -10,6 +10,15 @@ using Poseidon2 = crypto::Poseidon2; * UNCONSTRAINED TREE OPERATIONS **************************************************************************************************/ +void AvmMerkleTreeTraceBuilder::checkpoint_non_revertible_state() +{ + non_revertible_tree_snapshots = tree_snapshots.copy(); +} +void AvmMerkleTreeTraceBuilder::rollback_to_non_revertible_checkpoint() +{ + tree_snapshots = non_revertible_tree_snapshots; +} + FF AvmMerkleTreeTraceBuilder::unconstrained_hash_nullifier_preimage(const NullifierLeafPreimage& preimage) { return Poseidon2::hash({ preimage.nullifier, preimage.next_nullifier, preimage.next_index }); @@ -75,14 +84,13 @@ FF AvmMerkleTreeTraceBuilder::unconstrained_update_leaf_index(const FF& leaf_val bool AvmMerkleTreeTraceBuilder::perform_storage_read([[maybe_unused]] uint32_t clk, const PublicDataTreeLeafPreimage& preimage, const FF& leaf_index, - const std::vector& path, - const FF& root) + const std::vector& path) const { // Hash the preimage FF preimage_hash = unconstrained_hash_public_data_preimage(preimage); auto index = static_cast(leaf_index); // Check if the leaf is a member of the tree - return unconstrained_check_membership(preimage_hash, index, path, root); + return unconstrained_check_membership(preimage_hash, index, path, tree_snapshots.public_data_tree.root); } FF AvmMerkleTreeTraceBuilder::perform_storage_write([[maybe_unused]] uint32_t clk, @@ -91,19 +99,19 @@ FF AvmMerkleTreeTraceBuilder::perform_storage_write([[maybe_unused]] uint32_t cl const std::vector& low_path, const FF& slot, const FF& value, - const FF& insertion_index, - const std::vector& insertion_path, - const FF& initial_root) + const std::vector& insertion_path) { // Check membership of the low leaf - bool low_leaf_member = perform_storage_read(clk, low_preimage, low_index, low_path, initial_root); + bool low_leaf_member = perform_storage_read(clk, low_preimage, low_index, low_path); ASSERT(low_leaf_member); if (slot == low_preimage.slot) { // We update the low value low_preimage.value = value; FF low_preimage_hash = unconstrained_hash_public_data_preimage(low_preimage); // Update the low leaf - return unconstrained_update_leaf_index(low_preimage_hash, static_cast(low_index), low_path); + tree_snapshots.public_data_tree.root = + unconstrained_update_leaf_index(low_preimage_hash, static_cast(low_index), low_path); + return tree_snapshots.public_data_tree.root; } // The new leaf for an insertion is PublicDataTreeLeafPreimage new_preimage{ @@ -111,32 +119,34 @@ FF AvmMerkleTreeTraceBuilder::perform_storage_write([[maybe_unused]] uint32_t cl }; // Update the low preimage with the new leaf preimage low_preimage.next_slot = slot; - low_preimage.next_index = insertion_index; + low_preimage.next_index = tree_snapshots.public_data_tree.size; // Hash the low preimage FF low_preimage_hash = unconstrained_hash_public_data_preimage(low_preimage); // Compute the new root FF new_root = unconstrained_update_leaf_index(low_preimage_hash, static_cast(low_index), low_path); // Check membership of the zero leaf at the insertion index against the new root - auto index = static_cast(insertion_index); + auto index = static_cast(tree_snapshots.public_data_tree.size); bool zero_leaf_member = unconstrained_check_membership(FF::zero(), index, insertion_path, new_root); ASSERT(zero_leaf_member); // Hash the new preimage FF leaf_preimage_hash = unconstrained_hash_public_data_preimage(new_preimage); // Insert the new leaf into the tree - return unconstrained_update_leaf_index(leaf_preimage_hash, index, insertion_path); + tree_snapshots.public_data_tree.root = unconstrained_update_leaf_index(leaf_preimage_hash, index, insertion_path); + tree_snapshots.public_data_tree.size++; + return tree_snapshots.public_data_tree.root; } bool AvmMerkleTreeTraceBuilder::perform_nullifier_read([[maybe_unused]] uint32_t clk, const NullifierLeafPreimage& preimage, const FF& leaf_index, - const std::vector& path, - const FF& root) + const std::vector& path) const + { // Hash the preimage FF preimage_hash = unconstrained_hash_nullifier_preimage(preimage); auto index = static_cast(leaf_index); // Check if the leaf is a member of the tree - return unconstrained_check_membership(preimage_hash, index, path, root); + return unconstrained_check_membership(preimage_hash, index, path, tree_snapshots.nullifier_tree.root); } FF AvmMerkleTreeTraceBuilder::perform_nullifier_append([[maybe_unused]] uint32_t clk, @@ -144,22 +154,20 @@ FF AvmMerkleTreeTraceBuilder::perform_nullifier_append([[maybe_unused]] uint32_t const FF& low_index, const std::vector& low_path, const FF& nullifier, - const FF& insertion_index, - const std::vector& insertion_path, - const FF& root) + const std::vector& insertion_path) { bool is_update = low_preimage.nullifier == nullifier; FF low_preimage_hash = unconstrained_hash_nullifier_preimage(low_preimage); if (is_update) { // We need to raise an error here, since updates arent allowed in the nullifier tree - bool is_member = - unconstrained_check_membership(low_preimage_hash, static_cast(low_index), low_path, root); + bool is_member = unconstrained_check_membership( + low_preimage_hash, static_cast(low_index), low_path, tree_snapshots.nullifier_tree.root); ASSERT(is_member); - return root; + return tree_snapshots.nullifier_tree.root; } // Check membership of the low leaf - bool low_leaf_member = - unconstrained_check_membership(low_preimage_hash, static_cast(low_index), low_path, root); + bool low_leaf_member = unconstrained_check_membership( + low_preimage_hash, static_cast(low_index), low_path, tree_snapshots.nullifier_tree.root); ASSERT(low_leaf_member); // The new leaf for an insertion is NullifierLeafPreimage new_preimage{ .nullifier = nullifier, @@ -167,19 +175,52 @@ FF AvmMerkleTreeTraceBuilder::perform_nullifier_append([[maybe_unused]] uint32_t .next_index = low_preimage.next_index }; // Update the low preimage low_preimage.next_nullifier = nullifier; - low_preimage.next_index = insertion_index; + low_preimage.next_index = tree_snapshots.nullifier_tree.size; // Update hash of the low preimage low_preimage_hash = unconstrained_hash_nullifier_preimage(low_preimage); // Update the root with new low preimage FF updated_root = unconstrained_update_leaf_index(low_preimage_hash, static_cast(low_index), low_path); // Check membership of the zero leaf at the insertion index against the new root - auto index = static_cast(insertion_index); + auto index = static_cast(tree_snapshots.nullifier_tree.size); bool zero_leaf_member = unconstrained_check_membership(FF::zero(), index, insertion_path, updated_root); ASSERT(zero_leaf_member); // Hash the new preimage FF leaf_preimage_hash = unconstrained_hash_nullifier_preimage(new_preimage); // Insert the new leaf into the tree - return unconstrained_update_leaf_index(leaf_preimage_hash, index, insertion_path); + tree_snapshots.nullifier_tree.root = unconstrained_update_leaf_index(leaf_preimage_hash, index, insertion_path); + tree_snapshots.nullifier_tree.size++; + return tree_snapshots.nullifier_tree.root; +} + +bool AvmMerkleTreeTraceBuilder::perform_note_hash_read([[maybe_unused]] uint32_t clk, + const FF& note_hash, + const FF& leaf_index, + const std::vector& path) const +{ + auto index = static_cast(leaf_index); + return unconstrained_check_membership(note_hash, index, path, tree_snapshots.note_hash_tree.root); +} + +FF AvmMerkleTreeTraceBuilder::perform_note_hash_append([[maybe_unused]] uint32_t clk, + const FF& note_hash, + const std::vector& insertion_path) +{ + auto index = static_cast(tree_snapshots.note_hash_tree.size); + bool zero_leaf_member = + unconstrained_check_membership(FF::zero(), index, insertion_path, tree_snapshots.note_hash_tree.root); + ASSERT(zero_leaf_member); + tree_snapshots.note_hash_tree.root = unconstrained_update_leaf_index(note_hash, index, insertion_path); + tree_snapshots.note_hash_tree.size++; + return tree_snapshots.note_hash_tree.root; +} + +bool AvmMerkleTreeTraceBuilder::perform_l1_to_l2_message_read([[maybe_unused]] uint32_t clk, + const FF& leaf_value, + const FF leaf_index, + const std::vector& path) const +{ + auto index = static_cast(leaf_index); + return unconstrained_check_membership(leaf_value, index, path, tree_snapshots.l1_to_l2_message_tree.root); } /************************************************************************************************** diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/gadgets/merkle_tree.hpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/gadgets/merkle_tree.hpp index 382c49942fa..7b67db89313 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/gadgets/merkle_tree.hpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/gadgets/merkle_tree.hpp @@ -22,8 +22,14 @@ class AvmMerkleTreeTraceBuilder { }; AvmMerkleTreeTraceBuilder() = default; + AvmMerkleTreeTraceBuilder(TreeSnapshots& tree_snapshots) + : tree_snapshots(tree_snapshots){}; + void reset(); + void checkpoint_non_revertible_state(); + void rollback_to_non_revertible_checkpoint(); + bool check_membership( uint32_t clk, const FF& leaf_value, const uint64_t leaf_index, const std::vector& path, const FF& root); @@ -35,12 +41,13 @@ class AvmMerkleTreeTraceBuilder { FF compute_public_tree_leaf_slot(uint32_t clk, FF contract_address, FF leaf_index); - // These can be static, but not yet in-case we want to store the tree snapshots in this gadget + TreeSnapshots& get_tree_snapshots() { return tree_snapshots; } + + // Public Data Tree bool perform_storage_read(uint32_t clk, const PublicDataTreeLeafPreimage& preimage, const FF& leaf_index, - const std::vector& path, - const FF& root); + const std::vector& path) const; FF perform_storage_write(uint32_t clk, PublicDataTreeLeafPreimage& low_preimage, @@ -48,24 +55,34 @@ class AvmMerkleTreeTraceBuilder { const std::vector& low_path, const FF& slot, const FF& value, - const FF& insertion_index, - const std::vector& insertion_path, - const FF& initial_root); + const std::vector& insertion_path); + // Nullifier Tree bool perform_nullifier_read(uint32_t clk, const NullifierLeafPreimage& preimage, const FF& leaf_index, - const std::vector& path, - const FF& root); + const std::vector& path) const; FF perform_nullifier_append(uint32_t clk, NullifierLeafPreimage& low_preimage, const FF& low_index, const std::vector& low_path, const FF& nullifier, - const FF& insertion_index, - const std::vector& insertion_path, - const FF& root); + const std::vector& insertion_path); + + // Note Hash Tree + bool perform_note_hash_read(uint32_t clk, + const FF& note_hash, + const FF& leaf_index, + const std::vector& path) const; + + FF perform_note_hash_append(uint32_t clk, const FF& note_hash, const std::vector& insertion_path); + + // L1 to L2 Message Tree + bool perform_l1_to_l2_message_read(uint32_t clk, + const FF& leaf_value, + const FF leaf_index, + const std::vector& path) const; // Unconstrained variants while circuit stuff is being worked out static bool unconstrained_check_membership(const FF& leaf_value, @@ -86,11 +103,14 @@ class AvmMerkleTreeTraceBuilder { static FF unconstrained_compute_public_tree_leaf_slot(FF contract_address, FF leaf_index); void finalize(std::vector>& main_trace); + // We need access to the poseidon2 gadget AvmPoseidon2TraceBuilder poseidon2_builder; private: std::vector merkle_check_trace; + TreeSnapshots non_revertible_tree_snapshots; + TreeSnapshots tree_snapshots; MerkleEntry compute_root_from_path(uint32_t clk, const FF& leaf_value, const uint64_t leaf_index, diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/helper.cpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/helper.cpp index e40a90129d5..28a540f69aa 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/helper.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/helper.cpp @@ -105,6 +105,8 @@ std::string to_name(AvmError error) switch (error) { case AvmError::NO_ERROR: return "NO ERROR"; + case AvmError::REVERT_OPCODE: + return "REVERT OPCODE"; case AvmError::INVALID_PROGRAM_COUNTER: return "INVALID PROGRAM COUNTER"; case AvmError::INVALID_OPCODE: @@ -127,6 +129,10 @@ std::string to_name(AvmError error) return "CONTRACT INSTANCE MEMBER UNKNOWN"; case AvmError::RADIX_OUT_OF_BOUNDS: return "RADIX OUT OF BOUNDS"; + case AvmError::DUPLICATE_NULLIFIER: + return "DUPLICATE NULLIFIER"; + case AvmError::SIDE_EFFECT_LIMIT_REACHED: + return "SIDE EFFECT LIMIT REACHED"; default: throw std::runtime_error("Invalid error type"); break; @@ -138,6 +144,11 @@ bool is_ok(AvmError error) return error == AvmError::NO_ERROR; } +bool exceptionally_halted(AvmError error) +{ + return error != AvmError::NO_ERROR && error != AvmError::REVERT_OPCODE; +} + /** * * ONLY FOR TESTS - Required by dsl module and therefore cannot be moved to test/helpers.test.cpp diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/helper.hpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/helper.hpp index 1f3b845c8e4..1ba7375276b 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/helper.hpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/helper.hpp @@ -235,6 +235,7 @@ std::string to_name(bb::avm_trace::AvmMemoryTag tag); std::string to_name(AvmError error); bool is_ok(AvmError error); +bool exceptionally_halted(AvmError error); // Mutate the inputs void inject_end_gas_values(AvmPublicInputs& public_inputs, std::vector& trace); diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/public_inputs.hpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/public_inputs.hpp index fb2f236aa6a..c68d5e363fd 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/public_inputs.hpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/public_inputs.hpp @@ -18,8 +18,8 @@ struct Gas { inline void read(uint8_t const*& it, Gas& gas) { using serialize::read; - read(it, gas.l2_gas); read(it, gas.da_gas); + read(it, gas.l2_gas); } struct GasFees { @@ -87,6 +87,7 @@ inline void read(uint8_t const*& it, GlobalVariables& global_variables) struct AppendOnlyTreeSnapshot { FF root{}; uint32_t size = 0; + inline bool operator==(const AppendOnlyTreeSnapshot& rhs) const { return root == rhs.root && size == rhs.size; } }; inline void read(uint8_t const*& it, AppendOnlyTreeSnapshot& tree_snapshot) @@ -101,6 +102,32 @@ struct TreeSnapshots { AppendOnlyTreeSnapshot note_hash_tree; AppendOnlyTreeSnapshot nullifier_tree; AppendOnlyTreeSnapshot public_data_tree; + inline bool operator==(const TreeSnapshots& rhs) const + { + return l1_to_l2_message_tree == rhs.l1_to_l2_message_tree && note_hash_tree == rhs.note_hash_tree && + nullifier_tree == rhs.nullifier_tree && public_data_tree == rhs.public_data_tree; + } + inline TreeSnapshots copy() + { + return { + .l1_to_l2_message_tree = { + .root = l1_to_l2_message_tree.root, + .size = l1_to_l2_message_tree.size, + }, + .note_hash_tree = { + .root = note_hash_tree.root, + .size = note_hash_tree.size, + }, + .nullifier_tree = { + .root = nullifier_tree.root, + .size = nullifier_tree.size, + }, + .public_data_tree = { + .root = public_data_tree.root, + .size = public_data_tree.size, + }, + }; + } }; inline void read(uint8_t const*& it, TreeSnapshots& tree_snapshots) @@ -130,6 +157,10 @@ struct PublicCallRequest { */ bool is_static_call = false; FF args_hash{}; + inline bool is_empty() const + { + return msg_sender == 0 && contract_address == 0 && function_selector == 0 && !is_static_call && args_hash == 0; + } }; inline void read(uint8_t const*& it, PublicCallRequest& public_call_request) @@ -155,9 +186,24 @@ inline void read(uint8_t const*& it, PrivateToAvmAccumulatedDataArrayLengths& le read(it, lengths.nullifiers); read(it, lengths.l2_to_l1_msgs); } +struct L2ToL1Message { + FF recipient{}; // This is an eth address so it's actually only 20 bytes + FF content{}; + uint32_t counter = 0; +}; + +inline void read(uint8_t const*& it, L2ToL1Message& l2_to_l1_message) +{ + using serialize::read; + std::array recipient; + read(it, recipient); + l2_to_l1_message.recipient = FF::serialize_from_buffer(recipient.data()); + read(it, l2_to_l1_message.content); + read(it, l2_to_l1_message.counter); +} struct ScopedL2ToL1Message { - FF l2_to_l1_message{}; + L2ToL1Message l2_to_l1_message{}; FF contract_address{}; }; @@ -170,27 +216,21 @@ inline void read(uint8_t const*& it, ScopedL2ToL1Message& l2_to_l1_message) struct PrivateToAvmAccumulatedData { std::array note_hashes{}; - std::array nullifiers{}; - std::array l2_to_l1_msgs; + std::array nullifiers{}; + std::array l2_to_l1_msgs; }; inline void read(uint8_t const*& it, PrivateToAvmAccumulatedData& accumulated_data) { using serialize::read; - for (size_t i = 0; i < MAX_NOTE_HASHES_PER_TX; i++) { - read(it, accumulated_data.note_hashes[i]); - } - for (size_t i = 0; i < MAX_NULLIFIERS_PER_CALL; i++) { - read(it, accumulated_data.nullifiers[i]); - } - for (size_t i = 0; i < MAX_L2_TO_L1_MSGS_PER_CALL; i++) { - read(it, accumulated_data.l2_to_l1_msgs[i]); - } + read(it, accumulated_data.note_hashes); + read(it, accumulated_data.nullifiers); + read(it, accumulated_data.l2_to_l1_msgs); } struct LogHash { FF value{}; - FF counter{}; + uint32_t counter = 0; FF length{}; }; @@ -234,39 +274,30 @@ struct AvmAccumulatedData { /** * The nullifiers from private combining with those made in the AVM execution. */ - std::array nullifiers{}; + std::array nullifiers{}; /** * The L2 to L1 messages from private combining with those made in the AVM execution. */ - std::array l2_to_l1_msgs; + std::array l2_to_l1_msgs{}; /** * The unencrypted logs emitted from the AVM execution. */ - std::array unencrypted_logs_hashes; + std::array unencrypted_logs_hashes{}; /** * The public data writes made in the AVM execution. */ - std::array public_data_writes; + std::array public_data_writes{}; }; inline void read(uint8_t const*& it, AvmAccumulatedData& accumulated_data) { using serialize::read; - for (size_t i = 0; i < MAX_NOTE_HASHES_PER_TX; i++) { - read(it, accumulated_data.note_hashes[i]); - } - for (size_t i = 0; i < MAX_NULLIFIERS_PER_CALL; i++) { - read(it, accumulated_data.nullifiers[i]); - } - for (size_t i = 0; i < MAX_L2_TO_L1_MSGS_PER_CALL; i++) { - read(it, accumulated_data.l2_to_l1_msgs[i]); - } - for (size_t i = 0; i < MAX_UNENCRYPTED_LOGS_PER_CALL; i++) { - read(it, accumulated_data.unencrypted_logs_hashes[i]); - } - for (size_t i = 0; i < MAX_PUBLIC_DATA_UPDATE_REQUESTS_PER_TX; i++) { - read(it, accumulated_data.public_data_writes[i]); - } + + read(it, accumulated_data.note_hashes); + read(it, accumulated_data.nullifiers); + read(it, accumulated_data.l2_to_l1_msgs); + read(it, accumulated_data.unencrypted_logs_hashes); + read(it, accumulated_data.public_data_writes); }; class AvmPublicInputs { diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp index ecd740ca9ba..245e5188c93 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.cpp @@ -35,7 +35,9 @@ #include "barretenberg/vm/avm/trace/gadgets/slice_trace.hpp" #include "barretenberg/vm/avm/trace/helper.hpp" #include "barretenberg/vm/avm/trace/opcode.hpp" +#include "barretenberg/vm/avm/trace/public_inputs.hpp" #include "barretenberg/vm/avm/trace/trace.hpp" +#include "barretenberg/vm/aztec_constants.hpp" #include "barretenberg/vm/stats.hpp" namespace bb::avm_trace { @@ -134,6 +136,56 @@ bool check_tag_integral(AvmMemoryTag tag) * HELPERS **************************************************************************************************/ +void AvmTraceBuilder::checkpoint_non_revertible_state() +{ + merkle_tree_trace_builder.checkpoint_non_revertible_state(); +} +void AvmTraceBuilder::rollback_to_non_revertible_checkpoint() +{ + merkle_tree_trace_builder.rollback_to_non_revertible_checkpoint(); +} + +std::vector AvmTraceBuilder::get_bytecode(const FF contract_address, bool check_membership) +{ + // uint32_t clk = 0; + // auto clk = static_cast(main_trace.size()) + 1; + + // Find the bytecode based on contract address of the public call request + const AvmContractBytecode bytecode_hint = + *std::ranges::find_if(execution_hints.all_contract_bytecode, [contract_address](const auto& contract) { + return contract.contract_instance.address == contract_address; + }); + if (check_membership) { + // NullifierReadTreeHint nullifier_read_hint = bytecode_hint.contract_instance.membership_hint; + //// hinted nullifier should match the specified contract address + // ASSERT(nullifier_read_hint.low_leaf_preimage.nullifier == contract_address); + // bool is_member = merkle_tree_trace_builder.perform_nullifier_read(clk, + // nullifier_read_hint.low_leaf_preimage, + // nullifier_read_hint.low_leaf_index, + // nullifier_read_hint.low_leaf_sibling_path); + //// TODO(dbanks12): handle non-existent bytecode + //// if the contract address nullifier is hinted as "exists", the membership check should agree + // ASSERT(is_member); + } + + vinfo("Found bytecode for contract address: ", contract_address); + return bytecode_hint.bytecode; +} + +void AvmTraceBuilder::insert_private_state(const std::vector& siloed_nullifiers, + [[maybe_unused]] const std::vector& siloed_note_hashes) +{ + for (const auto& siloed_nullifier : siloed_nullifiers) { + auto hint = execution_hints.nullifier_write_hints[nullifier_write_counter++]; + merkle_tree_trace_builder.perform_nullifier_append(0, + hint.low_leaf_membership.low_leaf_preimage, + hint.low_leaf_membership.low_leaf_index, + hint.low_leaf_membership.low_leaf_sibling_path, + siloed_nullifier, + hint.insertion_path); + } +} + /** * @brief Loads a value from memory into a given intermediate register at a specified clock cycle. * Handles both direct and indirect memory access. @@ -306,17 +358,16 @@ AvmTraceBuilder::AvmTraceBuilder(AvmPublicInputs public_inputs, std::vector calldata) // NOTE: we initialise the environment builder here as it requires public inputs : calldata(std::move(calldata)) - , new_public_inputs(public_inputs) + , public_inputs(public_inputs) , side_effect_counter(side_effect_counter) , execution_hints(std::move(execution_hints_)) - , intermediate_tree_snapshots(public_inputs.start_tree_snapshots) , bytecode_trace_builder(execution_hints.all_contract_bytecode) + , merkle_tree_trace_builder(public_inputs.start_tree_snapshots) { // TODO: think about cast - gas_trace_builder.set_initial_gas(static_cast(new_public_inputs.gas_settings.gas_limits.l2_gas - - new_public_inputs.start_gas_used.l2_gas), - static_cast(new_public_inputs.gas_settings.gas_limits.da_gas - - new_public_inputs.start_gas_used.da_gas)); + gas_trace_builder.set_initial_gas( + static_cast(public_inputs.gas_settings.gas_limits.l2_gas - public_inputs.start_gas_used.l2_gas), + static_cast(public_inputs.gas_settings.gas_limits.da_gas - public_inputs.start_gas_used.da_gas)); } /************************************************************************************************** @@ -1626,7 +1677,7 @@ AvmError AvmTraceBuilder::op_function_selector(uint8_t indirect, uint32_t dst_of AvmError AvmTraceBuilder::op_transaction_fee(uint8_t indirect, uint32_t dst_offset) { - FF ia_value = new_public_inputs.transaction_fee; + FF ia_value = public_inputs.transaction_fee; auto [row, error] = create_kernel_lookup_opcode(indirect, dst_offset, ia_value, AvmMemoryTag::FF); row.main_sel_op_transaction_fee = FF(1); @@ -1656,7 +1707,7 @@ AvmError AvmTraceBuilder::op_is_static_call(uint8_t indirect, uint32_t dst_offse AvmError AvmTraceBuilder::op_chain_id(uint8_t indirect, uint32_t dst_offset) { - FF ia_value = new_public_inputs.global_variables.chain_id; + FF ia_value = public_inputs.global_variables.chain_id; auto [row, error] = create_kernel_lookup_opcode(indirect, dst_offset, ia_value, AvmMemoryTag::FF); row.main_sel_op_chain_id = FF(1); @@ -1669,7 +1720,7 @@ AvmError AvmTraceBuilder::op_chain_id(uint8_t indirect, uint32_t dst_offset) AvmError AvmTraceBuilder::op_version(uint8_t indirect, uint32_t dst_offset) { - FF ia_value = new_public_inputs.global_variables.version; + FF ia_value = public_inputs.global_variables.version; auto [row, error] = create_kernel_lookup_opcode(indirect, dst_offset, ia_value, AvmMemoryTag::FF); row.main_sel_op_version = FF(1); @@ -1682,7 +1733,7 @@ AvmError AvmTraceBuilder::op_version(uint8_t indirect, uint32_t dst_offset) AvmError AvmTraceBuilder::op_block_number(uint8_t indirect, uint32_t dst_offset) { - FF ia_value = new_public_inputs.global_variables.block_number; + FF ia_value = public_inputs.global_variables.block_number; auto [row, error] = create_kernel_lookup_opcode(indirect, dst_offset, ia_value, AvmMemoryTag::FF); row.main_sel_op_block_number = FF(1); @@ -1695,7 +1746,7 @@ AvmError AvmTraceBuilder::op_block_number(uint8_t indirect, uint32_t dst_offset) AvmError AvmTraceBuilder::op_timestamp(uint8_t indirect, uint32_t dst_offset) { - FF ia_value = new_public_inputs.global_variables.timestamp; + FF ia_value = public_inputs.global_variables.timestamp; auto [row, error] = create_kernel_lookup_opcode(indirect, dst_offset, ia_value, AvmMemoryTag::U64); row.main_sel_op_timestamp = FF(1); @@ -1708,7 +1759,7 @@ AvmError AvmTraceBuilder::op_timestamp(uint8_t indirect, uint32_t dst_offset) AvmError AvmTraceBuilder::op_fee_per_l2_gas(uint8_t indirect, uint32_t dst_offset) { - FF ia_value = new_public_inputs.global_variables.gas_fees.fee_per_l2_gas; + FF ia_value = public_inputs.global_variables.gas_fees.fee_per_l2_gas; auto [row, error] = create_kernel_lookup_opcode(indirect, dst_offset, ia_value, AvmMemoryTag::FF); row.main_sel_op_fee_per_l2_gas = FF(1); @@ -1721,7 +1772,7 @@ AvmError AvmTraceBuilder::op_fee_per_l2_gas(uint8_t indirect, uint32_t dst_offse AvmError AvmTraceBuilder::op_fee_per_da_gas(uint8_t indirect, uint32_t dst_offset) { - FF ia_value = new_public_inputs.global_variables.gas_fees.fee_per_da_gas; + FF ia_value = public_inputs.global_variables.gas_fees.fee_per_da_gas; auto [row, error] = create_kernel_lookup_opcode(indirect, dst_offset, ia_value, AvmMemoryTag::FF); row.main_sel_op_fee_per_da_gas = FF(1); @@ -2564,10 +2615,9 @@ AvmError AvmTraceBuilder::op_sload(uint8_t indirect, uint32_t slot_offset, uint3 // Sanity check that the computed slot using the value read from slot_offset should match the read hint ASSERT(computed_tree_slot == read_hint.leaf_preimage.slot); - FF public_data_tree_root = intermediate_tree_snapshots.public_data_tree.root; // Check that the leaf is a member of the public data tree bool is_member = merkle_tree_trace_builder.perform_storage_read( - clk, read_hint.leaf_preimage, read_hint.leaf_index, read_hint.sibling_path, public_data_tree_root); + clk, read_hint.leaf_preimage, read_hint.leaf_index, read_hint.sibling_path); ASSERT(is_member); FF value = read_hint.leaf_preimage.value; @@ -2613,10 +2663,24 @@ AvmError AvmTraceBuilder::op_sload(uint8_t indirect, uint32_t slot_offset, uint3 AvmError AvmTraceBuilder::op_sstore(uint8_t indirect, uint32_t src_offset, uint32_t slot_offset) { // We keep the first encountered error + AvmError error = AvmError::NO_ERROR; auto clk = static_cast(main_trace.size()) + 1; - // We keep the first encountered error - AvmError error = AvmError::NO_ERROR; + if (storage_write_counter >= MAX_PUBLIC_DATA_UPDATE_REQUESTS_PER_TX) { + error = AvmError::SIDE_EFFECT_LIMIT_REACHED; + auto row = Row{ + .main_clk = clk, + .main_internal_return_ptr = internal_return_ptr, + .main_op_err = FF(static_cast(!is_ok(error))), + .main_pc = pc, + .main_sel_op_sstore = FF(1), + }; + gas_trace_builder.constrain_gas(clk, OpCode::SSTORE); + main_trace.push_back(row); + pc += Deserialization::get_pc_increment(OpCode::SSTORE); + return error; + } + auto [resolved_addrs, res_error] = Addressing<2>::fromWire(indirect, call_ptr).resolve({ src_offset, slot_offset }, mem_trace_builder); auto [resolved_src, resolved_slot] = resolved_addrs; @@ -2642,17 +2706,13 @@ AvmError AvmTraceBuilder::op_sstore(uint8_t indirect, uint32_t src_offset, uint3 // (e) We create a new preimage for the new write // (f) We compute the new root by updating at the leaf index with the hash of the new preimage PublicDataWriteTreeHint write_hint = execution_hints.storage_write_hints.at(storage_write_counter++); - FF root = merkle_tree_trace_builder.perform_storage_write(clk, - write_hint.low_leaf_membership.leaf_preimage, - write_hint.low_leaf_membership.leaf_index, - write_hint.low_leaf_membership.sibling_path, - write_hint.new_leaf_preimage.slot, - write_hint.new_leaf_preimage.value, - intermediate_tree_snapshots.public_data_tree.size, - write_hint.insertion_path, - intermediate_tree_snapshots.public_data_tree.root); - intermediate_tree_snapshots.public_data_tree.root = root; - intermediate_tree_snapshots.public_data_tree.size++; + merkle_tree_trace_builder.perform_storage_write(clk, + write_hint.low_leaf_membership.leaf_preimage, + write_hint.low_leaf_membership.leaf_index, + write_hint.low_leaf_membership.sibling_path, + write_hint.new_leaf_preimage.slot, + write_hint.new_leaf_preimage.value, + write_hint.insertion_path); // TODO(8945): remove fake rows Row row = Row{ @@ -2662,6 +2722,7 @@ AvmError AvmTraceBuilder::op_sstore(uint8_t indirect, uint32_t src_offset, uint3 .main_ind_addr_a = read_a.indirect_address, .main_internal_return_ptr = internal_return_ptr, .main_mem_addr_a = read_a.direct_address, // direct address incremented at end of the loop + .main_op_err = FF(static_cast(!is_ok(error))), .main_pc = pc, .main_r_in_tag = static_cast(AvmMemoryTag::FF), .main_sel_mem_op_a = 1, @@ -2712,11 +2773,9 @@ AvmError AvmTraceBuilder::op_note_hash_exists(uint8_t indirect, bool exists = note_hash_value == note_hash_read_hint.leaf_value; // Check membership of the leaf index in the note hash tree const auto leaf_index = unconstrained_read_from_memory(resolved_leaf_index); - bool is_member = - AvmMerkleTreeTraceBuilder::unconstrained_check_membership(note_hash_read_hint.leaf_value, - static_cast(leaf_index), - note_hash_read_hint.sibling_path, - intermediate_tree_snapshots.note_hash_tree.root); + bool is_member = merkle_tree_trace_builder.perform_note_hash_read( + clk, note_hash_read_hint.leaf_value, leaf_index, note_hash_read_hint.sibling_path); + ASSERT(is_member); // This already does memory reads @@ -2782,23 +2841,36 @@ AvmError AvmTraceBuilder::op_emit_note_hash(uint8_t indirect, uint32_t note_hash { auto const clk = static_cast(main_trace.size()) + 1; + if (note_hash_write_counter >= MAX_NOTE_HASHES_PER_TX) { + AvmError error = AvmError::SIDE_EFFECT_LIMIT_REACHED; + auto row = Row{ + .main_clk = clk, + .main_internal_return_ptr = internal_return_ptr, + .main_op_err = FF(static_cast(!is_ok(error))), + .main_pc = pc, + .main_sel_op_emit_note_hash = FF(1), + }; + gas_trace_builder.constrain_gas(clk, OpCode::EMITNOTEHASH); + main_trace.push_back(row); + pc += Deserialization::get_pc_increment(OpCode::EMITNOTEHASH); + return error; + } + + auto [row, error] = create_kernel_output_opcode(indirect, clk, note_hash_offset); + row.main_sel_op_emit_note_hash = FF(1); + row.main_op_err = FF(static_cast(!is_ok(error))); + AppendTreeHint note_hash_write_hint = execution_hints.note_hash_write_hints.at(note_hash_write_counter++); + auto siloed_note_hash = AvmMerkleTreeTraceBuilder::unconstrained_silo_note_hash( + current_public_call_request.contract_address, row.main_ia); + ASSERT(row.main_ia == note_hash_write_hint.leaf_value); // We first check that the index is currently empty - auto insertion_index = static_cast(intermediate_tree_snapshots.note_hash_tree.size); - bool insert_index_is_empty = - AvmMerkleTreeTraceBuilder::unconstrained_check_membership(FF::zero(), - insertion_index, - note_hash_write_hint.sibling_path, - intermediate_tree_snapshots.note_hash_tree.root); + bool insert_index_is_empty = merkle_tree_trace_builder.perform_note_hash_read( + clk, FF::zero(), note_hash_write_hint.leaf_index, note_hash_write_hint.sibling_path); ASSERT(insert_index_is_empty); - // Update the root with the new leaf that is appended - FF new_root = AvmMerkleTreeTraceBuilder::unconstrained_update_leaf_index( - note_hash_write_hint.leaf_value, insertion_index, note_hash_write_hint.sibling_path); - intermediate_tree_snapshots.note_hash_tree.root = new_root; - intermediate_tree_snapshots.note_hash_tree.size++; - auto [row, error] = create_kernel_output_opcode(indirect, clk, note_hash_offset); - row.main_sel_op_emit_note_hash = FF(1); + // Update the root with the new leaf that is appended + merkle_tree_trace_builder.perform_note_hash_append(clk, siloed_note_hash, note_hash_write_hint.sibling_path); // Constrain gas cost gas_trace_builder.constrain_gas(clk, OpCode::EMITNOTEHASH); @@ -2840,12 +2912,10 @@ AvmError AvmTraceBuilder::op_nullifier_exists(uint8_t indirect, FF nullifier_value = unconstrained_read_from_memory(resolved_nullifier_offset); FF address_value = unconstrained_read_from_memory(resolved_address); FF siloed_nullifier = AvmMerkleTreeTraceBuilder::unconstrained_silo_nullifier(address_value, nullifier_value); - bool is_member = - merkle_tree_trace_builder.perform_nullifier_read(clk, - nullifier_read_hint.low_leaf_preimage, - nullifier_read_hint.low_leaf_index, - nullifier_read_hint.low_leaf_sibling_path, - intermediate_tree_snapshots.nullifier_tree.root); + bool is_member = merkle_tree_trace_builder.perform_nullifier_read(clk, + nullifier_read_hint.low_leaf_preimage, + nullifier_read_hint.low_leaf_index, + nullifier_read_hint.low_leaf_sibling_path); ASSERT(is_member); if (siloed_nullifier == nullifier_read_hint.low_leaf_preimage.nullifier) { @@ -2921,48 +2991,67 @@ AvmError AvmTraceBuilder::op_nullifier_exists(uint8_t indirect, AvmError AvmTraceBuilder::op_emit_nullifier(uint8_t indirect, uint32_t nullifier_offset) { + // We keep the first encountered error + AvmError error = AvmError::NO_ERROR; auto const clk = static_cast(main_trace.size()) + 1; - auto [row, error] = create_kernel_output_opcode(indirect, clk, nullifier_offset); + if (nullifier_write_counter >= MAX_NULLIFIERS_PER_TX) { + error = AvmError::SIDE_EFFECT_LIMIT_REACHED; + auto row = Row{ + .main_clk = clk, + .main_internal_return_ptr = internal_return_ptr, + .main_op_err = FF(static_cast(!is_ok(error))), + .main_pc = pc, + .main_sel_op_emit_nullifier = FF(1), + }; + gas_trace_builder.constrain_gas(clk, OpCode::EMITNULLIFIER); + main_trace.push_back(row); + pc += Deserialization::get_pc_increment(OpCode::EMITNULLIFIER); + return error; + } + + auto [row, output_error] = create_kernel_output_opcode(indirect, clk, nullifier_offset); row.main_sel_op_emit_nullifier = FF(1); + if (is_ok(error)) { + error = output_error; + } // Do merkle check FF nullifier_value = row.main_ia; FF siloed_nullifier = AvmMerkleTreeTraceBuilder::unconstrained_silo_nullifier( current_public_call_request.contract_address, nullifier_value); - // This is a little bit fragile - but we use the fact that if we traced a nullifier that already exists (which is - // invalid), we would have stored it under a read hint. - NullifierReadTreeHint nullifier_read_hint = execution_hints.nullifier_read_hints.at(nullifier_read_counter); - bool is_update = merkle_tree_trace_builder.perform_nullifier_read(clk, - nullifier_read_hint.low_leaf_preimage, - nullifier_read_hint.low_leaf_index, - nullifier_read_hint.low_leaf_sibling_path, - intermediate_tree_snapshots.nullifier_tree.root); + NullifierWriteTreeHint nullifier_write_hint = execution_hints.nullifier_write_hints.at(nullifier_write_counter++); + bool is_update = siloed_nullifier == nullifier_write_hint.low_leaf_membership.low_leaf_preimage.next_nullifier; if (is_update) { - // If we are in this branch, then the nullifier already exists in the tree - // WE NEED TO RAISE AN ERROR FLAG HERE - for now we do nothing, except increment the counter - + // hinted low-leaf points to the target nullifier, so it already exists + // prove membership of that low-leaf, which also proves membership of the target nullifier + bool exists = merkle_tree_trace_builder.perform_nullifier_read( + clk, + nullifier_write_hint.low_leaf_membership.low_leaf_preimage, + nullifier_write_hint.low_leaf_membership.low_leaf_index, + nullifier_write_hint.low_leaf_membership.low_leaf_sibling_path); + // if hinted low-leaf that skips the nullifier fails membership check, bad hint! + ASSERT(exists); nullifier_read_counter++; - error = AvmError::DUPLICATE_NULLIFIER; + // Cannot update an existing nullifier, and cannot emit a duplicate. Error! + if (is_ok(error)) { + error = AvmError::DUPLICATE_NULLIFIER; + } } else { - // This is a non-membership proof which means our insertion is valid - NullifierWriteTreeHint nullifier_write_hint = - execution_hints.nullifier_write_hints.at(nullifier_write_counter++); - FF new_root = merkle_tree_trace_builder.perform_nullifier_append( + // hinted low-leaf SKIPS the target nullifier, so it does NOT exist + // prove membership of the low leaf which also proves non-membership of the target nullifier + merkle_tree_trace_builder.perform_nullifier_append( clk, nullifier_write_hint.low_leaf_membership.low_leaf_preimage, nullifier_write_hint.low_leaf_membership.low_leaf_index, nullifier_write_hint.low_leaf_membership.low_leaf_sibling_path, siloed_nullifier, - intermediate_tree_snapshots.nullifier_tree.size, - nullifier_write_hint.insertion_path, - intermediate_tree_snapshots.nullifier_tree.root); - - intermediate_tree_snapshots.nullifier_tree.root = new_root; - intermediate_tree_snapshots.nullifier_tree.size++; + nullifier_write_hint.insertion_path); } + row.main_op_err = FF(static_cast(!is_ok(error))); + // Constrain gas cost gas_trace_builder.constrain_gas(clk, OpCode::EMITNULLIFIER); @@ -3006,11 +3095,8 @@ AvmError AvmTraceBuilder::op_l1_to_l2_msg_exists(uint8_t indirect, bool exists = l1_to_l2_msg_value == l1_to_l2_msg_read_hint.leaf_value; // Check membership of the leaf index in the l1_to_l2_msg tree - bool is_member = AvmMerkleTreeTraceBuilder::unconstrained_check_membership( - l1_to_l2_msg_read_hint.leaf_value, - static_cast(l1_to_l2_msg_read_hint.leaf_index), - l1_to_l2_msg_read_hint.sibling_path, - intermediate_tree_snapshots.l1_to_l2_message_tree.root); + bool is_member = merkle_tree_trace_builder.perform_l1_to_l2_message_read( + clk, l1_to_l2_msg_read_hint.leaf_value, leaf_index, l1_to_l2_msg_read_hint.sibling_path); ASSERT(is_member); auto read_a = constrained_read_from_memory( @@ -3235,6 +3321,26 @@ AvmError AvmTraceBuilder::op_emit_unencrypted_log(uint8_t indirect, uint32_t log }; } + // Can't return earlier as we do elsewhere for side-effect-limit because we need + // to at least retrieve log_size first to charge proper gas. + // This means a tag error could occur before side-effect-limit first. + if (is_ok(error) && unencrypted_log_write_counter >= MAX_UNENCRYPTED_LOGS_PER_TX) { + error = AvmError::SIDE_EFFECT_LIMIT_REACHED; + auto row = Row{ + .main_clk = clk, + .main_internal_return_ptr = internal_return_ptr, + .main_op_err = FF(static_cast(!is_ok(error))), + .main_pc = pc, + .main_sel_op_emit_unencrypted_log = FF(1), + }; + // Constrain gas cost + gas_trace_builder.constrain_gas(clk, OpCode::EMITUNENCRYPTEDLOG, static_cast(log_size)); + main_trace.push_back(row); + pc += Deserialization::get_pc_increment(OpCode::EMITUNENCRYPTEDLOG); + return error; + } + unencrypted_log_write_counter++; + if (is_ok(error)) { // We need to read the rest of the log_size number of elements for (uint32_t i = 0; i < log_size; i++) { @@ -3289,14 +3395,38 @@ AvmError AvmTraceBuilder::op_emit_unencrypted_log(uint8_t indirect, uint32_t log AvmError AvmTraceBuilder::op_emit_l2_to_l1_msg(uint8_t indirect, uint32_t recipient_offset, uint32_t content_offset) { + // We keep the first encountered error + AvmError error = AvmError::NO_ERROR; auto const clk = static_cast(main_trace.size()) + 1; + if (l2_to_l1_msg_write_counter >= MAX_L2_TO_L1_MSGS_PER_TX) { + error = AvmError::SIDE_EFFECT_LIMIT_REACHED; + auto row = Row{ + .main_clk = clk, + .main_internal_return_ptr = internal_return_ptr, + .main_op_err = FF(static_cast(!is_ok(error))), + .main_pc = pc, + .main_sel_op_emit_l2_to_l1_msg = FF(1), + }; + gas_trace_builder.constrain_gas(clk, OpCode::SENDL2TOL1MSG); + main_trace.push_back(row); + pc += Deserialization::get_pc_increment(OpCode::SENDL2TOL1MSG); + return error; + } + l2_to_l1_msg_write_counter++; + // Note: unorthodox order - as seen in L2ToL1Message struct in TS - auto [row, error] = create_kernel_output_opcode_with_metadata( + auto [row, output_error] = create_kernel_output_opcode_with_metadata( indirect, clk, content_offset, AvmMemoryTag::FF, recipient_offset, AvmMemoryTag::FF); + + if (is_ok(error)) { + error = output_error; + } + // Wtite to output // kernel_trace_builder.op_emit_l2_to_l1_msg(clk, side_effect_counter, row.main_ia, row.main_ib); row.main_sel_op_emit_l2_to_l1_msg = FF(1); + row.main_op_err = FF(static_cast(!is_ok(error))); // Constrain gas cost gas_trace_builder.constrain_gas(clk, OpCode::SENDL2TOL1MSG); @@ -3617,6 +3747,10 @@ ReturnDataError AvmTraceBuilder::op_revert(uint8_t indirect, uint32_t ret_offset pc = UINT32_MAX; // This ensures that no subsequent opcode will be executed. + if (is_ok(error)) { + error = AvmError::REVERT_OPCODE; + } + // op_valid == true otherwise, ret_size == 0 and we would have returned above. return ReturnDataError{ .return_data = returndata, @@ -4334,8 +4468,13 @@ AvmError AvmTraceBuilder::op_to_radix_be(uint8_t indirect, * * @return The main trace */ -std::vector AvmTraceBuilder::finalize() +std::vector AvmTraceBuilder::finalize(bool apply_end_gas_assertions) { + // Some sanity checks + // Check that the final merkle tree lines up with the public inputs + TreeSnapshots tree_snapshots = merkle_tree_trace_builder.get_tree_snapshots(); + ASSERT(tree_snapshots == public_inputs.end_tree_snapshots); + vinfo("range_check_required: ", range_check_required); vinfo("full_precomputed_tables: ", full_precomputed_tables); @@ -4356,7 +4495,6 @@ std::vector AvmTraceBuilder::finalize() size_t bin_trace_size = bin_trace_builder.size(); size_t gas_trace_size = gas_trace_builder.size(); size_t slice_trace_size = slice_trace.size(); - // size_t kernel_trace_size = kernel_trace_builder.size(); // Range check size is 1 less than it needs to be since we insert a "first row" at the top of the trace at the // end, with clk 0 (this doubles as our range check) @@ -4599,6 +4737,16 @@ std::vector AvmTraceBuilder::finalize() gas_trace_builder.finalize(main_trace); + if (apply_end_gas_assertions) { + // Sanity check that the amount of gas consumed matches what we expect from the public inputs + auto last_l2_gas_remaining = main_trace.back().main_l2_gas_remaining; + auto expected_end_gas_l2 = public_inputs.gas_settings.gas_limits.l2_gas - public_inputs.end_gas_used.l2_gas; + ASSERT(last_l2_gas_remaining == expected_end_gas_l2); + auto last_da_gas_remaining = main_trace.back().main_da_gas_remaining; + auto expected_end_gas_da = public_inputs.gas_settings.gas_limits.da_gas - public_inputs.end_gas_used.da_gas; + ASSERT(last_da_gas_remaining == expected_end_gas_da); + } + /********************************************************************************************** * KERNEL TRACE INCLUSION **********************************************************************************************/ diff --git a/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.hpp b/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.hpp index aed311ee443..d62fff47f0b 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.hpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm/trace/trace.hpp @@ -15,7 +15,6 @@ #include "barretenberg/vm/avm/trace/gadgets/sha256.hpp" #include "barretenberg/vm/avm/trace/gadgets/slice_trace.hpp" #include "barretenberg/vm/avm/trace/gas_trace.hpp" -// #include "barretenberg/vm/avm/trace/kernel_trace.hpp" #include "barretenberg/vm/avm/trace/mem_trace.hpp" #include "barretenberg/vm/avm/trace/opcode.hpp" #include "barretenberg/vm/avm/trace/public_inputs.hpp" @@ -42,7 +41,7 @@ struct RowWithError { class AvmTraceBuilder { public: - AvmTraceBuilder(AvmPublicInputs new_public_inputs = {}, + AvmTraceBuilder(AvmPublicInputs public_inputs, ExecutionHints execution_hints = {}, uint32_t side_effect_counter = 0, std::vector calldata = {}); @@ -222,9 +221,14 @@ class AvmTraceBuilder { uint32_t num_limbs, uint8_t output_bits); - std::vector finalize(); + std::vector finalize(bool apply_end_gas_assertions = false); void reset(); + void checkpoint_non_revertible_state(); + void rollback_to_non_revertible_checkpoint(); + std::vector get_bytecode(const FF contract_address, bool check_membership = false); + void insert_private_state(const std::vector& siloed_nullifiers, const std::vector& siloed_note_hashes); + // These are used for testing only. AvmTraceBuilder& set_range_check_required(bool required) { @@ -250,7 +254,7 @@ class AvmTraceBuilder { std::vector main_trace; std::vector calldata; - AvmPublicInputs new_public_inputs; + AvmPublicInputs public_inputs; PublicCallRequest current_public_call_request; std::vector returndata; @@ -261,16 +265,16 @@ class AvmTraceBuilder { uint32_t side_effect_counter = 0; uint32_t external_call_counter = 0; // Incremented both by OpCode::CALL and OpCode::STATICCALL ExecutionHints execution_hints; - // These are the tracked roots for intermediate steps - TreeSnapshots intermediate_tree_snapshots; // These are some counters for the tree acceess hints that we probably dont need in the future uint32_t note_hash_read_counter = 0; uint32_t note_hash_write_counter = 0; uint32_t nullifier_read_counter = 0; uint32_t nullifier_write_counter = 0; uint32_t l1_to_l2_msg_read_counter = 0; + uint32_t l2_to_l1_msg_write_counter = 0; uint32_t storage_read_counter = 0; uint32_t storage_write_counter = 0; + uint32_t unencrypted_log_write_counter = 0; // These exist due to testing only. bool range_check_required = true; diff --git a/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr b/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr index 93c07de02a2..1a4c3fee584 100644 --- a/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr +++ b/noir-projects/noir-contracts/contracts/avm_test_contract/src/main.nr @@ -473,6 +473,41 @@ contract AvmTest { context.push_nullifier(nullifier); } + #[public] + fn n_storage_writes(num: u32) { + for i in 0..num { + storage.map.at(AztecAddress::from_field(i as Field)).write(i); + } + } + + #[public] + fn n_new_note_hashes(num: u32) { + for i in 0..num { + context.push_note_hash(i as Field); + } + } + + #[public] + fn n_new_nullifiers(num: u32) { + for i in 0..num { + context.push_nullifier(i as Field); + } + } + + #[public] + fn n_new_l2_to_l1_msgs(num: u32) { + for i in 0..num { + context.message_portal(EthAddress::from_field(i as Field), i as Field) + } + } + + #[public] + fn n_new_unencrypted_logs(num: u32) { + for i in 0..num { + context.emit_unencrypted_log(/*message=*/ [i as Field]); + } + } + // Use the standard context interface to check for a nullifier #[public] fn nullifier_exists(nullifier: Field) -> bool { diff --git a/yarn-project/bb-prover/src/avm_proving.test.ts b/yarn-project/bb-prover/src/avm_proving.test.ts index 6195eb0850a..80c2f117205 100644 --- a/yarn-project/bb-prover/src/avm_proving.test.ts +++ b/yarn-project/bb-prover/src/avm_proving.test.ts @@ -1,4 +1,11 @@ -import { VerificationKeyData } from '@aztec/circuits.js'; +import { + MAX_L2_TO_L1_MSGS_PER_TX, + MAX_NOTE_HASHES_PER_TX, + MAX_NULLIFIERS_PER_TX, + MAX_PUBLIC_DATA_UPDATE_REQUESTS_PER_TX, + MAX_UNENCRYPTED_LOGS_PER_TX, + VerificationKeyData, +} from '@aztec/circuits.js'; import { Fr } from '@aztec/foundation/fields'; import { createDebugLogger } from '@aztec/foundation/log'; import { simulateAvmTestContractGenerateCircuitInputs } from '@aztec/simulator/public/fixtures'; @@ -10,17 +17,78 @@ import path from 'path'; import { type BBSuccess, BB_RESULT, generateAvmProof, verifyAvmProof } from './bb/execute.js'; import { extractAvmVkData } from './verification_key/verification_key_data.js'; +const TIMEOUT = 180_000; + describe('AVM WitGen, proof generation and verification', () => { - it('Should prove and verify bulk_testing', async () => { - await proveAndVerifyAvmTestContract( - 'bulk_testing', - [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10].map(x => new Fr(x)), - ); - }, 180_000); + it( + 'Should prove and verify bulk_testing', + async () => { + await proveAndVerifyAvmTestContract( + 'bulk_testing', + /*calldata=*/ [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10].map(x => new Fr(x)), + ); + }, + TIMEOUT, + ); + it( + 'Should prove and verify test that performs too many storage writes and reverts', + async () => { + await proveAndVerifyAvmTestContract( + 'n_storage_writes', + /*calldata=*/ [new Fr(MAX_PUBLIC_DATA_UPDATE_REQUESTS_PER_TX + 1)], + /*expectRevert=*/ true, + ); + }, + TIMEOUT, + ); + it( + 'Should prove and verify test that creates too many note hashes and reverts', + async () => { + await proveAndVerifyAvmTestContract( + 'n_new_note_hashes', + /*calldata=*/ [new Fr(MAX_NOTE_HASHES_PER_TX + 1)], + /*expectRevert=*/ true, + ); + }, + TIMEOUT, + ); + it( + 'Should prove and verify test that creates too many nullifiers and reverts', + async () => { + await proveAndVerifyAvmTestContract( + 'n_new_nullifiers', + /*calldata=*/ [new Fr(MAX_NULLIFIERS_PER_TX + 1)], + /*expectRevert=*/ true, + ); + }, + TIMEOUT, + ); + it( + 'Should prove and verify test that creates too many l2tol1 messages and reverts', + async () => { + await proveAndVerifyAvmTestContract( + 'n_new_l2_to_l1_msgs', + /*calldata=*/ [new Fr(MAX_L2_TO_L1_MSGS_PER_TX + 1)], + /*expectRevert=*/ true, + ); + }, + TIMEOUT, + ); + it( + 'Should prove and verify test that creates too many unencrypted logs and reverts', + async () => { + await proveAndVerifyAvmTestContract( + 'n_new_unencrypted_logs', + /*calldata=*/ [new Fr(MAX_UNENCRYPTED_LOGS_PER_TX + 1)], + /*expectRevert=*/ true, + ); + }, + TIMEOUT, + ); }); -async function proveAndVerifyAvmTestContract(functionName: string, calldata: Fr[] = []) { - const avmCircuitInputs = await simulateAvmTestContractGenerateCircuitInputs(functionName, calldata); +async function proveAndVerifyAvmTestContract(functionName: string, calldata: Fr[] = [], expectRevert = false) { + const avmCircuitInputs = await simulateAvmTestContractGenerateCircuitInputs(functionName, calldata, expectRevert); const internalLogger = createDebugLogger('aztec:avm-proving-test'); const logger = (msg: string, _data?: any) => internalLogger.verbose(msg); diff --git a/yarn-project/bb-prover/src/bb/execute.ts b/yarn-project/bb-prover/src/bb/execute.ts index b8acb266fd7..76261a8d31d 100644 --- a/yarn-project/bb-prover/src/bb/execute.ts +++ b/yarn-project/bb-prover/src/bb/execute.ts @@ -15,7 +15,6 @@ export const VK_FIELDS_FILENAME = 'vk_fields.json'; export const PROOF_FILENAME = 'proof'; export const PROOF_FIELDS_FILENAME = 'proof_fields.json'; export const AVM_BYTECODE_FILENAME = 'avm_bytecode.bin'; -export const AVM_CALLDATA_FILENAME = 'avm_calldata.bin'; export const AVM_PUBLIC_INPUTS_FILENAME = 'avm_public_inputs.bin'; export const AVM_HINTS_FILENAME = 'avm_hints.bin'; @@ -519,7 +518,6 @@ export async function generateAvmProof( } // Paths for the inputs - const calldataPath = join(workingDirectory, AVM_CALLDATA_FILENAME); const publicInputsPath = join(workingDirectory, AVM_PUBLIC_INPUTS_FILENAME); const avmHintsPath = join(workingDirectory, AVM_HINTS_FILENAME); @@ -539,13 +537,6 @@ export async function generateAvmProof( try { // Write the inputs to the working directory. - await fs.writeFile( - calldataPath, - input.calldata.map(fr => fr.toBuffer()), - ); - if (!filePresent(calldataPath)) { - return { status: BB_RESULT.FAILURE, reason: `Could not write calldata at ${calldataPath}` }; - } await fs.writeFile(publicInputsPath, input.output.toBuffer()); if (!filePresent(publicInputsPath)) { @@ -558,8 +549,6 @@ export async function generateAvmProof( } const args = [ - '--avm-calldata', - calldataPath, '--avm-public-inputs', publicInputsPath, '--avm-hints', diff --git a/yarn-project/bb-prover/src/test/index.ts b/yarn-project/bb-prover/src/test/index.ts index 555536e8cb7..3f84ad27da1 100644 --- a/yarn-project/bb-prover/src/test/index.ts +++ b/yarn-project/bb-prover/src/test/index.ts @@ -1,3 +1,2 @@ export * from './test_circuit_prover.js'; export * from './test_verifier.js'; -export * from './test_avm.js'; diff --git a/yarn-project/bb-prover/src/test/test_avm.ts b/yarn-project/bb-prover/src/test/test_avm.ts deleted file mode 100644 index 4cbac8bb1c4..00000000000 --- a/yarn-project/bb-prover/src/test/test_avm.ts +++ /dev/null @@ -1,85 +0,0 @@ -import { - AztecAddress, - ContractStorageRead, - ContractStorageUpdateRequest, - Gas, - GlobalVariables, - Header, - L2ToL1Message, - LogHash, - MAX_ENQUEUED_CALLS_PER_CALL, - MAX_L1_TO_L2_MSG_READ_REQUESTS_PER_CALL, - MAX_L2_TO_L1_MSGS_PER_CALL, - MAX_NOTE_HASHES_PER_CALL, - MAX_NOTE_HASH_READ_REQUESTS_PER_CALL, - MAX_NULLIFIERS_PER_CALL, - MAX_NULLIFIER_NON_EXISTENT_READ_REQUESTS_PER_CALL, - MAX_NULLIFIER_READ_REQUESTS_PER_CALL, - MAX_PUBLIC_DATA_READS_PER_CALL, - MAX_PUBLIC_DATA_UPDATE_REQUESTS_PER_CALL, - MAX_UNENCRYPTED_LOGS_PER_CALL, - NoteHash, - Nullifier, - PublicCircuitPublicInputs, - PublicInnerCallRequest, - ReadRequest, - RevertCode, - TreeLeafReadRequest, -} from '@aztec/circuits.js'; -import { computeVarArgsHash } from '@aztec/circuits.js/hash'; -import { padArrayEnd } from '@aztec/foundation/collection'; -import { type PublicFunctionCallResult } from '@aztec/simulator'; - -// TODO: pub somewhere more usable - copied from abstract phase manager -export function getPublicInputs(result: PublicFunctionCallResult): PublicCircuitPublicInputs { - return PublicCircuitPublicInputs.from({ - callContext: result.executionRequest.callContext, - proverAddress: AztecAddress.ZERO, - argsHash: computeVarArgsHash(result.executionRequest.args), - noteHashes: padArrayEnd(result.noteHashes, NoteHash.empty(), MAX_NOTE_HASHES_PER_CALL), - nullifiers: padArrayEnd(result.nullifiers, Nullifier.empty(), MAX_NULLIFIERS_PER_CALL), - l2ToL1Msgs: padArrayEnd(result.l2ToL1Messages, L2ToL1Message.empty(), MAX_L2_TO_L1_MSGS_PER_CALL), - startSideEffectCounter: result.startSideEffectCounter, - endSideEffectCounter: result.endSideEffectCounter, - returnsHash: computeVarArgsHash(result.returnValues), - noteHashReadRequests: padArrayEnd( - result.noteHashReadRequests, - TreeLeafReadRequest.empty(), - MAX_NOTE_HASH_READ_REQUESTS_PER_CALL, - ), - nullifierReadRequests: padArrayEnd( - result.nullifierReadRequests, - ReadRequest.empty(), - MAX_NULLIFIER_READ_REQUESTS_PER_CALL, - ), - nullifierNonExistentReadRequests: padArrayEnd( - result.nullifierNonExistentReadRequests, - ReadRequest.empty(), - MAX_NULLIFIER_NON_EXISTENT_READ_REQUESTS_PER_CALL, - ), - l1ToL2MsgReadRequests: padArrayEnd( - result.l1ToL2MsgReadRequests, - TreeLeafReadRequest.empty(), - MAX_L1_TO_L2_MSG_READ_REQUESTS_PER_CALL, - ), - contractStorageReads: padArrayEnd( - result.contractStorageReads, - ContractStorageRead.empty(), - MAX_PUBLIC_DATA_READS_PER_CALL, - ), - contractStorageUpdateRequests: padArrayEnd( - result.contractStorageUpdateRequests, - ContractStorageUpdateRequest.empty(), - MAX_PUBLIC_DATA_UPDATE_REQUESTS_PER_CALL, - ), - publicCallRequests: padArrayEnd([], PublicInnerCallRequest.empty(), MAX_ENQUEUED_CALLS_PER_CALL), - unencryptedLogsHashes: padArrayEnd(result.unencryptedLogsHashes, LogHash.empty(), MAX_UNENCRYPTED_LOGS_PER_CALL), - historicalHeader: Header.empty(), - globalVariables: GlobalVariables.empty(), - startGasLeft: Gas.from(result.startGasLeft), - endGasLeft: Gas.from(result.endGasLeft), - transactionFee: result.transactionFee, - // TODO(@just-mitch): need better mapping from simulator to revert code. - revertCode: result.reverted ? RevertCode.APP_LOGIC_REVERTED : RevertCode.OK, - }); -} diff --git a/yarn-project/circuits.js/src/structs/avm/avm.ts b/yarn-project/circuits.js/src/structs/avm/avm.ts index 30c37a7b132..e50308e2b73 100644 --- a/yarn-project/circuits.js/src/structs/avm/avm.ts +++ b/yarn-project/circuits.js/src/structs/avm/avm.ts @@ -851,13 +851,13 @@ export class AvmExecutionHints { public readonly contractInstances: Vector; public readonly contractBytecodeHints: Vector; - public readonly storageReadRequest: Vector; - public readonly storageUpdateRequest: Vector; - public readonly nullifierReadRequest: Vector; - public readonly nullifierWriteHints: Vector; - public readonly noteHashReadRequest: Vector; - public readonly noteHashWriteRequest: Vector; - public readonly l1ToL2MessageReadRequest: Vector; + public readonly publicDataReads: Vector; + public readonly publicDataWrites: Vector; + public readonly nullifierReads: Vector; + public readonly nullifierWrites: Vector; + public readonly noteHashReads: Vector; + public readonly noteHashWrites: Vector; + public readonly l1ToL2MessageReads: Vector; constructor( enqueuedCalls: AvmEnqueuedCallHint[], @@ -868,13 +868,13 @@ export class AvmExecutionHints { externalCalls: AvmExternalCallHint[], contractInstances: AvmContractInstanceHint[], contractBytecodeHints: AvmContractBytecodeHints[], - storageReadRequest: AvmPublicDataReadTreeHint[], - storageUpdateRequest: AvmPublicDataWriteTreeHint[], - nullifierReadRequest: AvmNullifierReadTreeHint[], - nullifierWriteHints: AvmNullifierWriteTreeHint[], - noteHashReadRequest: AvmAppendTreeHint[], - noteHashWriteRequest: AvmAppendTreeHint[], - l1ToL2MessageReadRequest: AvmAppendTreeHint[], + publicDataReads: AvmPublicDataReadTreeHint[], + publicDataWrites: AvmPublicDataWriteTreeHint[], + nullifierReads: AvmNullifierReadTreeHint[], + nullifierWrites: AvmNullifierWriteTreeHint[], + noteHashReads: AvmAppendTreeHint[], + noteHashWrites: AvmAppendTreeHint[], + l1ToL2MessageReads: AvmAppendTreeHint[], ) { this.enqueuedCalls = new Vector(enqueuedCalls); this.storageValues = new Vector(storageValues); @@ -884,14 +884,13 @@ export class AvmExecutionHints { this.externalCalls = new Vector(externalCalls); this.contractInstances = new Vector(contractInstances); this.contractBytecodeHints = new Vector(contractBytecodeHints); - this.storageReadRequest = new Vector(storageReadRequest); - this.storageUpdateRequest = new Vector(storageUpdateRequest); - this.noteHashReadRequest = new Vector(noteHashReadRequest); - this.nullifierReadRequest = new Vector(nullifierReadRequest); - this.nullifierWriteHints = new Vector(nullifierWriteHints); - this.noteHashReadRequest = new Vector(noteHashReadRequest); - this.noteHashWriteRequest = new Vector(noteHashWriteRequest); - this.l1ToL2MessageReadRequest = new Vector(l1ToL2MessageReadRequest); + this.publicDataReads = new Vector(publicDataReads); + this.publicDataWrites = new Vector(publicDataWrites); + this.nullifierReads = new Vector(nullifierReads); + this.nullifierWrites = new Vector(nullifierWrites); + this.noteHashReads = new Vector(noteHashReads); + this.noteHashWrites = new Vector(noteHashWrites); + this.l1ToL2MessageReads = new Vector(l1ToL2MessageReads); } /** @@ -932,13 +931,13 @@ export class AvmExecutionHints { this.externalCalls.items.length == 0 && this.contractInstances.items.length == 0 && this.contractBytecodeHints.items.length == 0 && - this.storageReadRequest.items.length == 0 && - this.storageUpdateRequest.items.length == 0 && - this.nullifierReadRequest.items.length == 0 && - this.nullifierWriteHints.items.length == 0 && - this.noteHashReadRequest.items.length == 0 && - this.noteHashWriteRequest.items.length == 0 && - this.l1ToL2MessageReadRequest.items.length == 0 + this.publicDataReads.items.length == 0 && + this.publicDataWrites.items.length == 0 && + this.nullifierReads.items.length == 0 && + this.nullifierWrites.items.length == 0 && + this.noteHashReads.items.length == 0 && + this.noteHashWrites.items.length == 0 && + this.l1ToL2MessageReads.items.length == 0 ); } @@ -949,8 +948,7 @@ export class AvmExecutionHints { */ static from(fields: FieldsOf): AvmExecutionHints { return new AvmExecutionHints( - // omit enqueued call hints until they're implemented in C++ - new Array(), + fields.enqueuedCalls.items, fields.storageValues.items, fields.noteHashExists.items, fields.nullifierExists.items, @@ -958,13 +956,13 @@ export class AvmExecutionHints { fields.externalCalls.items, fields.contractInstances.items, fields.contractBytecodeHints.items, - fields.storageReadRequest.items, - fields.storageUpdateRequest.items, - fields.nullifierReadRequest.items, - fields.nullifierWriteHints.items, - fields.noteHashReadRequest.items, - fields.noteHashWriteRequest.items, - fields.l1ToL2MessageReadRequest.items, + fields.publicDataReads.items, + fields.publicDataWrites.items, + fields.nullifierReads.items, + fields.nullifierWrites.items, + fields.noteHashReads.items, + fields.noteHashWrites.items, + fields.l1ToL2MessageReads.items, ); } @@ -975,8 +973,7 @@ export class AvmExecutionHints { */ static getFields(fields: FieldsOf) { return [ - // omit enqueued call hints until they're implemented in C++ - //fields.enqueuedCalls, + fields.enqueuedCalls, fields.storageValues, fields.noteHashExists, fields.nullifierExists, @@ -984,13 +981,13 @@ export class AvmExecutionHints { fields.externalCalls, fields.contractInstances, fields.contractBytecodeHints, - fields.storageReadRequest, - fields.storageUpdateRequest, - fields.nullifierReadRequest, - fields.nullifierWriteHints, - fields.noteHashReadRequest, - fields.noteHashWriteRequest, - fields.l1ToL2MessageReadRequest, + fields.publicDataReads, + fields.publicDataWrites, + fields.nullifierReads, + fields.nullifierWrites, + fields.noteHashReads, + fields.noteHashWrites, + fields.l1ToL2MessageReads, ] as const; } @@ -1002,8 +999,7 @@ export class AvmExecutionHints { static fromBuffer(buff: Buffer | BufferReader): AvmExecutionHints { const reader = BufferReader.asReader(buff); return new AvmExecutionHints( - // omit enqueued call hints until they're implemented in C++ - new Array(), + reader.readVector(AvmEnqueuedCallHint), reader.readVector(AvmKeyValueHint), reader.readVector(AvmKeyValueHint), reader.readVector(AvmKeyValueHint), diff --git a/yarn-project/circuits.js/src/tests/factories.ts b/yarn-project/circuits.js/src/tests/factories.ts index 79d5d63d2c6..4942e174577 100644 --- a/yarn-project/circuits.js/src/tests/factories.ts +++ b/yarn-project/circuits.js/src/tests/factories.ts @@ -1398,13 +1398,13 @@ export function makeAvmExecutionHints( externalCalls: makeVector(baseLength + 4, makeAvmExternalCallHint, seed + 0x4600), contractInstances: makeVector(baseLength + 5, makeAvmContractInstanceHint, seed + 0x4700), contractBytecodeHints: makeVector(baseLength + 6, makeAvmBytecodeHints, seed + 0x4800), - storageReadRequest: makeVector(baseLength + 7, makeAvmStorageReadTreeHints, seed + 0x4900), - storageUpdateRequest: makeVector(baseLength + 8, makeAvmStorageUpdateTreeHints, seed + 0x4a00), - nullifierReadRequest: makeVector(baseLength + 9, makeAvmNullifierReadTreeHints, seed + 0x4b00), - nullifierWriteHints: makeVector(baseLength + 10, makeAvmNullifierInsertionTreeHints, seed + 0x4c00), - noteHashReadRequest: makeVector(baseLength + 11, makeAvmTreeHints, seed + 0x4d00), - noteHashWriteRequest: makeVector(baseLength + 12, makeAvmTreeHints, seed + 0x4e00), - l1ToL2MessageReadRequest: makeVector(baseLength + 13, makeAvmTreeHints, seed + 0x4f00), + publicDataReads: makeVector(baseLength + 7, makeAvmStorageReadTreeHints, seed + 0x4900), + publicDataWrites: makeVector(baseLength + 8, makeAvmStorageUpdateTreeHints, seed + 0x4a00), + nullifierReads: makeVector(baseLength + 9, makeAvmNullifierReadTreeHints, seed + 0x4b00), + nullifierWrites: makeVector(baseLength + 10, makeAvmNullifierInsertionTreeHints, seed + 0x4c00), + noteHashReads: makeVector(baseLength + 11, makeAvmTreeHints, seed + 0x4d00), + noteHashWrites: makeVector(baseLength + 12, makeAvmTreeHints, seed + 0x4e00), + l1ToL2MessageReads: makeVector(baseLength + 13, makeAvmTreeHints, seed + 0x4f00), ...overrides, }); } diff --git a/yarn-project/simulator/src/avm/journal/journal.ts b/yarn-project/simulator/src/avm/journal/journal.ts index 63dbf59f09e..452e9d7267c 100644 --- a/yarn-project/simulator/src/avm/journal/journal.ts +++ b/yarn-project/simulator/src/avm/journal/journal.ts @@ -386,6 +386,11 @@ export class AvmPersistableStateManager { // Cache pending nullifiers for later access await this.nullifiers.append(siloedNullifier); // We append the new nullifier + this.log.debug( + `Nullifier tree root before insertion ${this.merkleTrees.treeMap + .get(MerkleTreeId.NULLIFIER_TREE)! + .getRoot()}`, + ); const appendResult = await this.merkleTrees.appendNullifier(siloedNullifier); this.log.debug( `Nullifier tree root after insertion ${this.merkleTrees.treeMap.get(MerkleTreeId.NULLIFIER_TREE)!.getRoot()}`, diff --git a/yarn-project/simulator/src/public/enqueued_call_side_effect_trace.test.ts b/yarn-project/simulator/src/public/enqueued_call_side_effect_trace.test.ts index 6f84f4de2ad..d21f38dee71 100644 --- a/yarn-project/simulator/src/public/enqueued_call_side_effect_trace.test.ts +++ b/yarn-project/simulator/src/public/enqueued_call_side_effect_trace.test.ts @@ -59,7 +59,7 @@ describe('Enqueued-call Side Effect Trace', () => { expect(trace.getCounter()).toBe(startCounterPlus1); const expected = new AvmPublicDataReadTreeHint(leafPreimage, leafIndex, siblingPath); - expect(trace.getAvmCircuitHints().storageReadRequest.items).toEqual([expected]); + expect(trace.getAvmCircuitHints().publicDataReads.items).toEqual([expected]); }); it('Should trace storage writes', () => { @@ -84,14 +84,14 @@ describe('Enqueued-call Side Effect Trace', () => { const readHint = new AvmPublicDataReadTreeHint(lowLeafPreimage, lowLeafIndex, lowLeafSiblingPath); const expectedHint = new AvmPublicDataWriteTreeHint(readHint, newLeafPreimage, siblingPath); - expect(trace.getAvmCircuitHints().storageUpdateRequest.items).toEqual([expectedHint]); + expect(trace.getAvmCircuitHints().publicDataWrites.items).toEqual([expectedHint]); }); it('Should trace note hash checks', () => { const exists = true; trace.traceNoteHashCheck(address, utxo, leafIndex, exists, siblingPath); const expected = new AvmAppendTreeHint(leafIndex, utxo, siblingPath); - expect(trace.getAvmCircuitHints().noteHashReadRequest.items).toEqual([expected]); + expect(trace.getAvmCircuitHints().noteHashReads.items).toEqual([expected]); }); it('Should trace note hashes', () => { @@ -102,7 +102,7 @@ describe('Enqueued-call Side Effect Trace', () => { expect(trace.getSideEffects().noteHashes).toEqual(expected); const expectedHint = new AvmAppendTreeHint(leafIndex, utxo, siblingPath); - expect(trace.getAvmCircuitHints().noteHashWriteRequest.items).toEqual([expectedHint]); + expect(trace.getAvmCircuitHints().noteHashWrites.items).toEqual([expectedHint]); }); it('Should trace nullifier checks', () => { @@ -112,7 +112,7 @@ describe('Enqueued-call Side Effect Trace', () => { expect(trace.getCounter()).toBe(startCounterPlus1); const expected = new AvmNullifierReadTreeHint(lowLeafPreimage, leafIndex, siblingPath); - expect(trace.getAvmCircuitHints().nullifierReadRequest.items).toEqual([expected]); + expect(trace.getAvmCircuitHints().nullifierReads.items).toEqual([expected]); }); it('Should trace nullifiers', () => { @@ -125,14 +125,14 @@ describe('Enqueued-call Side Effect Trace', () => { const readHint = new AvmNullifierReadTreeHint(lowLeafPreimage, lowLeafIndex, lowLeafSiblingPath); const expectedHint = new AvmNullifierWriteTreeHint(readHint, siblingPath); - expect(trace.getAvmCircuitHints().nullifierWriteHints.items).toEqual([expectedHint]); + expect(trace.getAvmCircuitHints().nullifierWrites.items).toEqual([expectedHint]); }); it('Should trace L1ToL2 Message checks', () => { const exists = true; trace.traceL1ToL2MessageCheck(address, utxo, leafIndex, exists, siblingPath); const expected = new AvmAppendTreeHint(leafIndex, utxo, siblingPath); - expect(trace.getAvmCircuitHints().l1ToL2MessageReadRequest.items).toEqual([expected]); + expect(trace.getAvmCircuitHints().l1ToL2MessageReads.items).toEqual([expected]); }); it('Should trace new L2ToL1 messages', () => { @@ -321,13 +321,13 @@ describe('Enqueued-call Side Effect Trace', () => { expect(parentHints.externalCalls.items).toEqual(childHints.externalCalls.items); expect(parentHints.contractInstances.items).toEqual(childHints.contractInstances.items); expect(parentHints.contractBytecodeHints.items).toEqual(childHints.contractBytecodeHints.items); - expect(parentHints.storageReadRequest.items).toEqual(childHints.storageReadRequest.items); - expect(parentHints.storageUpdateRequest.items).toEqual(childHints.storageUpdateRequest.items); - expect(parentHints.nullifierReadRequest.items).toEqual(childHints.nullifierReadRequest.items); - expect(parentHints.nullifierWriteHints.items).toEqual(childHints.nullifierWriteHints.items); - expect(parentHints.noteHashReadRequest.items).toEqual(childHints.noteHashReadRequest.items); - expect(parentHints.noteHashWriteRequest.items).toEqual(childHints.noteHashWriteRequest.items); - expect(parentHints.l1ToL2MessageReadRequest.items).toEqual(childHints.l1ToL2MessageReadRequest.items); + expect(parentHints.publicDataReads.items).toEqual(childHints.publicDataReads.items); + expect(parentHints.publicDataWrites.items).toEqual(childHints.publicDataWrites.items); + expect(parentHints.nullifierReads.items).toEqual(childHints.nullifierReads.items); + expect(parentHints.nullifierWrites.items).toEqual(childHints.nullifierWrites.items); + expect(parentHints.noteHashReads.items).toEqual(childHints.noteHashReads.items); + expect(parentHints.noteHashWrites.items).toEqual(childHints.noteHashWrites.items); + expect(parentHints.l1ToL2MessageReads.items).toEqual(childHints.l1ToL2MessageReads.items); }); }); }); diff --git a/yarn-project/simulator/src/public/enqueued_call_side_effect_trace.ts b/yarn-project/simulator/src/public/enqueued_call_side_effect_trace.ts index 84e85adcd64..a7e24ac5520 100644 --- a/yarn-project/simulator/src/public/enqueued_call_side_effect_trace.ts +++ b/yarn-project/simulator/src/public/enqueued_call_side_effect_trace.ts @@ -179,15 +179,13 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI this.avmCircuitHints.contractInstances.items.push(...forkedTrace.avmCircuitHints.contractInstances.items); this.avmCircuitHints.contractBytecodeHints.items.push(...forkedTrace.avmCircuitHints.contractBytecodeHints.items); - this.avmCircuitHints.storageReadRequest.items.push(...forkedTrace.avmCircuitHints.storageReadRequest.items); - this.avmCircuitHints.storageUpdateRequest.items.push(...forkedTrace.avmCircuitHints.storageUpdateRequest.items); - this.avmCircuitHints.nullifierReadRequest.items.push(...forkedTrace.avmCircuitHints.nullifierReadRequest.items); - this.avmCircuitHints.nullifierWriteHints.items.push(...forkedTrace.avmCircuitHints.nullifierWriteHints.items); - this.avmCircuitHints.noteHashReadRequest.items.push(...forkedTrace.avmCircuitHints.noteHashReadRequest.items); - this.avmCircuitHints.noteHashWriteRequest.items.push(...forkedTrace.avmCircuitHints.noteHashWriteRequest.items); - this.avmCircuitHints.l1ToL2MessageReadRequest.items.push( - ...forkedTrace.avmCircuitHints.l1ToL2MessageReadRequest.items, - ); + this.avmCircuitHints.publicDataReads.items.push(...forkedTrace.avmCircuitHints.publicDataReads.items); + this.avmCircuitHints.publicDataWrites.items.push(...forkedTrace.avmCircuitHints.publicDataWrites.items); + this.avmCircuitHints.nullifierReads.items.push(...forkedTrace.avmCircuitHints.nullifierReads.items); + this.avmCircuitHints.nullifierWrites.items.push(...forkedTrace.avmCircuitHints.nullifierWrites.items); + this.avmCircuitHints.noteHashReads.items.push(...forkedTrace.avmCircuitHints.noteHashReads.items); + this.avmCircuitHints.noteHashWrites.items.push(...forkedTrace.avmCircuitHints.noteHashWrites.items); + this.avmCircuitHints.l1ToL2MessageReads.items.push(...forkedTrace.avmCircuitHints.l1ToL2MessageReads.items); } public getCounter() { @@ -211,7 +209,7 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI assert(leafPreimage.value.equals(value), 'Value mismatch when tracing in public data write'); } - this.avmCircuitHints.storageReadRequest.items.push(new AvmPublicDataReadTreeHint(leafPreimage, leafIndex, path)); + this.avmCircuitHints.publicDataReads.items.push(new AvmPublicDataReadTreeHint(leafPreimage, leafIndex, path)); this.log.debug(`SLOAD cnt: ${this.sideEffectCounter} val: ${value} slot: ${slot}`); this.incrementSideEffectCounter(); } @@ -245,7 +243,7 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI // New hinting const readHint = new AvmPublicDataReadTreeHint(lowLeafPreimage, lowLeafIndex, lowLeafPath); - this.avmCircuitHints.storageUpdateRequest.items.push( + this.avmCircuitHints.publicDataWrites.items.push( new AvmPublicDataWriteTreeHint(readHint, newLeafPreimage, insertionPath), ); @@ -264,7 +262,7 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI path: Fr[] = emptyNoteHashPath(), ) { // New Hinting - this.avmCircuitHints.noteHashReadRequest.items.push(new AvmAppendTreeHint(leafIndex, noteHash, path)); + this.avmCircuitHints.noteHashReads.items.push(new AvmAppendTreeHint(leafIndex, noteHash, path)); // NOTE: counter does not increment for note hash checks (because it doesn't rely on pending note hashes) } @@ -282,7 +280,7 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI //const siloedNoteHash = siloNoteHash(contractAddress, noteHash); this.noteHashes.push(new NoteHash(noteHash, this.sideEffectCounter).scope(contractAddress)); this.log.debug(`NEW_NOTE_HASH cnt: ${this.sideEffectCounter}`); - this.avmCircuitHints.noteHashWriteRequest.items.push(new AvmAppendTreeHint(leafIndex, noteHash, path)); + this.avmCircuitHints.noteHashWrites.items.push(new AvmAppendTreeHint(leafIndex, noteHash, path)); this.incrementSideEffectCounter(); } @@ -293,7 +291,7 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI lowLeafIndex: Fr = Fr.zero(), lowLeafPath: Fr[] = emptyNullifierPath(), ) { - this.avmCircuitHints.nullifierReadRequest.items.push( + this.avmCircuitHints.nullifierReads.items.push( new AvmNullifierReadTreeHint(lowLeafPreimage, lowLeafIndex, lowLeafPath), ); this.log.debug(`NULLIFIER_EXISTS cnt: ${this.sideEffectCounter}`); @@ -314,7 +312,7 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI this.nullifiers.push(new Nullifier(siloedNullifier, this.sideEffectCounter, /*noteHash=*/ Fr.ZERO)); const lowLeafReadHint = new AvmNullifierReadTreeHint(lowLeafPreimage, lowLeafIndex, lowLeafPath); - this.avmCircuitHints.nullifierWriteHints.items.push(new AvmNullifierWriteTreeHint(lowLeafReadHint, insertionPath)); + this.avmCircuitHints.nullifierWrites.items.push(new AvmNullifierWriteTreeHint(lowLeafReadHint, insertionPath)); this.log.debug(`NEW_NULLIFIER cnt: ${this.sideEffectCounter}`); this.incrementSideEffectCounter(); } @@ -327,7 +325,7 @@ export class PublicEnqueuedCallSideEffectTrace implements PublicSideEffectTraceI _exists: boolean, path: Fr[] = emptyL1ToL2MessagePath(), ) { - this.avmCircuitHints.l1ToL2MessageReadRequest.items.push(new AvmAppendTreeHint(msgLeafIndex, msgHash, path)); + this.avmCircuitHints.l1ToL2MessageReads.items.push(new AvmAppendTreeHint(msgLeafIndex, msgHash, path)); } public traceNewL2ToL1Message(contractAddress: AztecAddress, recipient: Fr, content: Fr) { diff --git a/yarn-project/simulator/src/public/fixtures/index.ts b/yarn-project/simulator/src/public/fixtures/index.ts index 512cbf93d30..68eec22d6e1 100644 --- a/yarn-project/simulator/src/public/fixtures/index.ts +++ b/yarn-project/simulator/src/public/fixtures/index.ts @@ -34,14 +34,10 @@ import { MerkleTrees } from '@aztec/world-state'; import { strict as assert } from 'assert'; -/** - * If assertionErrString is set, we expect a (non exceptional halting) revert due to a failing assertion and - * we check that the revert reason error contains this string. However, the circuit must correctly prove the - * execution. - */ export async function simulateAvmTestContractGenerateCircuitInputs( functionName: string, calldata: Fr[] = [], + expectRevert: boolean = false, assertionErrString?: string, ): Promise { const sender = AztecAddress.random(); @@ -80,13 +76,15 @@ export async function simulateAvmTestContractGenerateCircuitInputs( const avmResult = await simulator.simulate(tx); - if (assertionErrString == undefined) { + if (!expectRevert) { expect(avmResult.revertCode.isOK()).toBe(true); } else { // Explicit revert when an assertion failed. expect(avmResult.revertCode.isOK()).toBe(false); expect(avmResult.revertReason).toBeDefined(); - expect(avmResult.revertReason?.getMessage()).toContain(assertionErrString); + if (assertionErrString !== undefined) { + expect(avmResult.revertReason?.getMessage()).toContain(assertionErrString); + } } const avmCircuitInputs: AvmCircuitInputs = avmResult.avmProvingRequest.inputs; diff --git a/yarn-project/simulator/src/public/public_tx_context.ts b/yarn-project/simulator/src/public/public_tx_context.ts index f6fd8e7c9e1..94057597a18 100644 --- a/yarn-project/simulator/src/public/public_tx_context.ts +++ b/yarn-project/simulator/src/public/public_tx_context.ts @@ -2,6 +2,7 @@ import { type AvmProvingRequest, MerkleTreeId, type MerkleTreeReadOperations, + ProvingRequestType, type PublicExecutionRequest, type SimulationError, type Tx, @@ -9,16 +10,19 @@ import { TxHash, } from '@aztec/circuit-types'; import { + AppendOnlyTreeSnapshot, + AvmCircuitInputs, type AvmCircuitPublicInputs, Fr, Gas, type GasSettings, type GlobalVariables, - type Header, type PrivateToPublicAccumulatedData, type PublicCallRequest, + PublicCircuitPublicInputs, RevertCode, type StateReference, + TreeSnapshots, countAccumulatedItems, } from '@aztec/circuits.js'; import { type DebugLogger, createDebugLogger } from '@aztec/foundation/log'; @@ -26,13 +30,12 @@ import { type DebugLogger, createDebugLogger } from '@aztec/foundation/log'; import { strict as assert } from 'assert'; import { inspect } from 'util'; -import { type AvmFinalizedCallResult } from '../avm/avm_contract_call_result.js'; import { AvmPersistableStateManager } from '../avm/index.js'; import { DualSideEffectTrace } from './dual_side_effect_trace.js'; import { PublicEnqueuedCallSideEffectTrace, SideEffectArrayLengths } from './enqueued_call_side_effect_trace.js'; import { type WorldStateDB } from './public_db_sources.js'; import { PublicSideEffectTrace } from './side_effect_trace.js'; -import { generateAvmCircuitPublicInputs, generateAvmProvingRequest } from './transitional_adapters.js'; +import { generateAvmCircuitPublicInputs } from './transitional_adapters.js'; import { getCallRequestsByPhase, getExecutionRequestsByPhase } from './utils.js'; /** @@ -58,7 +61,6 @@ export class PublicTxContext { constructor( public readonly state: PhaseStateManager, private readonly globalVariables: GlobalVariables, - private readonly historicalHeader: Header, // FIXME(dbanks12): remove private readonly startStateReference: StateReference, private readonly startGasUsed: Gas, private readonly gasSettings: GasSettings, @@ -89,7 +91,7 @@ export class PublicTxContext { const previousAccumulatedDataArrayLengths = new SideEffectArrayLengths( /*publicDataWrites*/ 0, countAccumulatedItems(nonRevertibleAccumulatedDataFromPrivate.noteHashes), - countAccumulatedItems(nonRevertibleAccumulatedDataFromPrivate.nullifiers), + /*nullifiers=*/ 0, countAccumulatedItems(nonRevertibleAccumulatedDataFromPrivate.l2ToL1Msgs), /*unencryptedLogsHashes*/ 0, ); @@ -105,7 +107,6 @@ export class PublicTxContext { return new PublicTxContext( new PhaseStateManager(txStateManager), globalVariables, - tx.data.constants.historicalHeader, await db.getStateReference(), tx.data.gasUsed, tx.data.constants.txContext.gasSettings, @@ -301,11 +302,24 @@ export class PublicTxContext { */ private generateAvmCircuitPublicInputs(endStateReference: StateReference): AvmCircuitPublicInputs { assert(this.halted, 'Can only get AvmCircuitPublicInputs after tx execution ends'); - // TODO(dbanks12): use the state roots from ephemeral trees - endStateReference.partial.nullifierTree.root = this.state - .getActiveStateManager() - .merkleTrees.treeMap.get(MerkleTreeId.NULLIFIER_TREE)! - .getRoot(); + const ephemeralTrees = this.state.getActiveStateManager().merkleTrees.treeMap; + + const getAppendSnaphot = (id: MerkleTreeId) => { + const tree = ephemeralTrees.get(id)!; + return new AppendOnlyTreeSnapshot(tree.getRoot(), Number(tree.leafCount)); + }; + + const noteHashTree = getAppendSnaphot(MerkleTreeId.NOTE_HASH_TREE); + const nullifierTree = getAppendSnaphot(MerkleTreeId.NULLIFIER_TREE); + const publicDataTree = getAppendSnaphot(MerkleTreeId.PUBLIC_DATA_TREE); + + const endTreeSnapshots = new TreeSnapshots( + endStateReference.l1ToL2MessageTree, + noteHashTree, + nullifierTree, + publicDataTree, + ); + return generateAvmCircuitPublicInputs( this.trace, this.globalVariables, @@ -317,7 +331,7 @@ export class PublicTxContext { this.teardownCallRequests, this.nonRevertibleAccumulatedDataFromPrivate, this.revertibleAccumulatedDataFromPrivate, - endStateReference, + endTreeSnapshots, /*endGasUsed=*/ this.gasUsed, this.getTransactionFeeUnsafe(), this.revertCode, @@ -328,38 +342,17 @@ export class PublicTxContext { * Generate the proving request for the AVM circuit. */ generateProvingRequest(endStateReference: StateReference): AvmProvingRequest { - // TODO(dbanks12): Once we actually have tx-level proving, this will generate the entire - // proving request for the first time - this.avmProvingRequest!.inputs.output = this.generateAvmCircuitPublicInputs(endStateReference); - return this.avmProvingRequest!; - } - - // TODO(dbanks12): remove once AVM proves entire public tx - updateProvingRequest( - real: boolean, - phase: TxExecutionPhase, - fnName: string, - stateManager: AvmPersistableStateManager, - executionRequest: PublicExecutionRequest, - result: AvmFinalizedCallResult, - allocatedGas: Gas, - ) { - if (this.avmProvingRequest === undefined) { - // Propagate the very first avmProvingRequest of the tx for now. - // Eventually this will be the proof for the entire public portion of the transaction. - this.avmProvingRequest = generateAvmProvingRequest( - real, - fnName, - stateManager, - this.historicalHeader, - this.globalVariables, - executionRequest, - // TODO(dbanks12): do we need this return type unless we are doing an isolated call? - stateManager.trace.toPublicEnqueuedCallExecutionResult(result), - allocatedGas, - this.getTransactionFee(phase), - ); - } + const hints = this.trace.getAvmCircuitHints(); + return { + type: ProvingRequestType.PUBLIC_VM, + inputs: new AvmCircuitInputs( + 'public_dispatch', + [], + PublicCircuitPublicInputs.empty(), + hints, + this.generateAvmCircuitPublicInputs(endStateReference), + ), + }; } } diff --git a/yarn-project/simulator/src/public/public_tx_simulator.ts b/yarn-project/simulator/src/public/public_tx_simulator.ts index 7cf250cc1d5..7c7546e9b0c 100644 --- a/yarn-project/simulator/src/public/public_tx_simulator.ts +++ b/yarn-project/simulator/src/public/public_tx_simulator.ts @@ -290,17 +290,6 @@ export class PublicTxSimulator { `[AVM] Enqueued public call consumed ${gasUsed.l2Gas} L2 gas ending with ${result.gasLeft.l2Gas} L2 gas left.`, ); - // TODO(dbanks12): remove once AVM proves entire public tx - context.updateProvingRequest( - this.realAvmProvingRequests, - phase, - fnName, - stateManager, - executionRequest, - result, - allocatedGas, - ); - stateManager.traceEnqueuedCall(callRequest, executionRequest.args, result.reverted); if (result.reverted) { diff --git a/yarn-project/simulator/src/public/side_effect_trace.ts b/yarn-project/simulator/src/public/side_effect_trace.ts index 474e3ff155d..8e9f93256d0 100644 --- a/yarn-project/simulator/src/public/side_effect_trace.ts +++ b/yarn-project/simulator/src/public/side_effect_trace.ts @@ -138,7 +138,7 @@ export class PublicSideEffectTrace implements PublicSideEffectTraceInterface { ); // New hinting - this.avmCircuitHints.storageReadRequest.items.push(new AvmPublicDataReadTreeHint(leafPreimage, leafIndex, path)); + this.avmCircuitHints.publicDataReads.items.push(new AvmPublicDataReadTreeHint(leafPreimage, leafIndex, path)); this.log.debug(`SLOAD cnt: ${this.sideEffectCounter} val: ${value} slot: ${slot}`); this.incrementSideEffectCounter(); @@ -168,7 +168,7 @@ export class PublicSideEffectTrace implements PublicSideEffectTraceInterface { // New hinting const readHint = new AvmPublicDataReadTreeHint(lowLeafPreimage, lowLeafIndex, lowLeafPath); - this.avmCircuitHints.storageUpdateRequest.items.push( + this.avmCircuitHints.publicDataWrites.items.push( new AvmPublicDataWriteTreeHint(readHint, newLeafPreimage, insertionPath), ); this.log.debug(`SSTORE cnt: ${this.sideEffectCounter} val: ${value} slot: ${slot}`); @@ -193,7 +193,7 @@ export class PublicSideEffectTrace implements PublicSideEffectTraceInterface { new AvmKeyValueHint(/*key=*/ new Fr(leafIndex), /*value=*/ exists ? Fr.ONE : Fr.ZERO), ); // New Hinting - this.avmCircuitHints.noteHashReadRequest.items.push(new AvmAppendTreeHint(leafIndex, noteHash, path)); + this.avmCircuitHints.noteHashReads.items.push(new AvmAppendTreeHint(leafIndex, noteHash, path)); // NOTE: counter does not increment for note hash checks (because it doesn't rely on pending note hashes) } @@ -210,7 +210,7 @@ export class PublicSideEffectTrace implements PublicSideEffectTraceInterface { this.log.debug(`NEW_NOTE_HASH cnt: ${this.sideEffectCounter}`); // New Hinting - this.avmCircuitHints.noteHashWriteRequest.items.push(new AvmAppendTreeHint(leafIndex, noteHash, path)); + this.avmCircuitHints.noteHashWrites.items.push(new AvmAppendTreeHint(leafIndex, noteHash, path)); this.incrementSideEffectCounter(); } @@ -237,7 +237,7 @@ export class PublicSideEffectTrace implements PublicSideEffectTraceInterface { ); // New Hints - this.avmCircuitHints.nullifierReadRequest.items.push( + this.avmCircuitHints.nullifierReads.items.push( new AvmNullifierReadTreeHint(lowLeafPreimage, lowLeafIndex, lowLeafPath), ); this.log.debug(`NULLIFIER_EXISTS cnt: ${this.sideEffectCounter}`); @@ -259,7 +259,7 @@ export class PublicSideEffectTrace implements PublicSideEffectTraceInterface { this.nullifiers.push(new Nullifier(siloedNullifier, this.sideEffectCounter, /*noteHash=*/ Fr.ZERO)); // New hinting const lowLeafReadHint = new AvmNullifierReadTreeHint(lowLeafPreimage, lowLeafIndex, lowLeafPath); - this.avmCircuitHints.nullifierWriteHints.items.push(new AvmNullifierWriteTreeHint(lowLeafReadHint, insertionPath)); + this.avmCircuitHints.nullifierWrites.items.push(new AvmNullifierWriteTreeHint(lowLeafReadHint, insertionPath)); this.log.debug(`NEW_NULLIFIER cnt: ${this.sideEffectCounter}`); this.incrementSideEffectCounter(); } @@ -282,7 +282,7 @@ export class PublicSideEffectTrace implements PublicSideEffectTraceInterface { ); // New Hinting - this.avmCircuitHints.l1ToL2MessageReadRequest.items.push(new AvmAppendTreeHint(msgLeafIndex, msgHash, path)); + this.avmCircuitHints.l1ToL2MessageReads.items.push(new AvmAppendTreeHint(msgLeafIndex, msgHash, path)); // NOTE: counter does not increment for l1tol2 message checks (because it doesn't rely on pending messages) } diff --git a/yarn-project/simulator/src/public/transitional_adapters.ts b/yarn-project/simulator/src/public/transitional_adapters.ts index 63470f6fc18..09ec0094110 100644 --- a/yarn-project/simulator/src/public/transitional_adapters.ts +++ b/yarn-project/simulator/src/public/transitional_adapters.ts @@ -1,57 +1,28 @@ -import { type AvmProvingRequest, ProvingRequestType, type PublicExecutionRequest } from '@aztec/circuit-types'; import { - AvmCircuitInputs, - AvmCircuitPublicInputs, - AztecAddress, - ContractStorageRead, - ContractStorageUpdateRequest, - Fr, - Gas, + type AvmCircuitPublicInputs, + type Fr, + type Gas, type GasSettings, type GlobalVariables, - type Header, - L2ToL1Message, - LogHash, - MAX_ENQUEUED_CALLS_PER_CALL, - MAX_L1_TO_L2_MSG_READ_REQUESTS_PER_CALL, - MAX_L2_TO_L1_MSGS_PER_CALL, MAX_L2_TO_L1_MSGS_PER_TX, - MAX_NOTE_HASHES_PER_CALL, MAX_NOTE_HASHES_PER_TX, - MAX_NOTE_HASH_READ_REQUESTS_PER_CALL, - MAX_NULLIFIERS_PER_CALL, - MAX_NULLIFIER_NON_EXISTENT_READ_REQUESTS_PER_CALL, - MAX_NULLIFIER_READ_REQUESTS_PER_CALL, - MAX_PUBLIC_DATA_READS_PER_CALL, - MAX_PUBLIC_DATA_UPDATE_REQUESTS_PER_CALL, MAX_PUBLIC_DATA_UPDATE_REQUESTS_PER_TX, - MAX_UNENCRYPTED_LOGS_PER_CALL, - NoteHash, - Nullifier, PrivateToAvmAccumulatedData, PrivateToAvmAccumulatedDataArrayLengths, type PrivateToPublicAccumulatedData, PublicCallRequest, - PublicCircuitPublicInputs, PublicDataWrite, - PublicInnerCallRequest, - ReadRequest, - RevertCode, + type RevertCode, type StateReference, - TreeLeafReadRequest, TreeSnapshots, countAccumulatedItems, mergeAccumulatedData, } from '@aztec/circuits.js'; -import { computeNoteHashNonce, computeUniqueNoteHash, computeVarArgsHash, siloNoteHash } from '@aztec/circuits.js/hash'; +import { computeNoteHashNonce, computeUniqueNoteHash, siloNoteHash } from '@aztec/circuits.js/hash'; import { padArrayEnd } from '@aztec/foundation/collection'; import { assertLength } from '@aztec/foundation/serialize'; -import { AvmFinalizedCallResult } from '../avm/avm_contract_call_result.js'; -import { AvmExecutionEnvironment } from '../avm/avm_execution_environment.js'; -import { type AvmPersistableStateManager } from '../avm/journal/journal.js'; import { type PublicEnqueuedCallSideEffectTrace } from './enqueued_call_side_effect_trace.js'; -import { type EnqueuedPublicCallExecutionResult, type PublicFunctionCallResult } from './execution.js'; export function generateAvmCircuitPublicInputs( trace: PublicEnqueuedCallSideEffectTrace, @@ -64,7 +35,7 @@ export function generateAvmCircuitPublicInputs( teardownCallRequests: PublicCallRequest[], nonRevertibleAccumulatedDataFromPrivate: PrivateToPublicAccumulatedData, revertibleAccumulatedDataFromPrivate: PrivateToPublicAccumulatedData, - endStateReference: StateReference, + endTreeSnapshots: TreeSnapshots, endGasUsed: Gas, transactionFee: Fr, revertCode: RevertCode, @@ -75,12 +46,6 @@ export function generateAvmCircuitPublicInputs( startStateReference.partial.nullifierTree, startStateReference.partial.publicDataTree, ); - const endTreeSnapshots = new TreeSnapshots( - endStateReference.l1ToL2MessageTree, - endStateReference.partial.noteHashTree, - endStateReference.partial.nullifierTree, - endStateReference.partial.publicDataTree, - ); const avmCircuitPublicInputs = trace.toAvmCircuitPublicInputs( globalVariables, @@ -182,155 +147,3 @@ export function generateAvmCircuitPublicInputs( //console.log(`AvmCircuitPublicInputs:\n${inspect(avmCircuitPublicInputs)}`); return avmCircuitPublicInputs; } - -export function generateAvmProvingRequest( - real: boolean, - fnName: string, - stateManager: AvmPersistableStateManager, - historicalHeader: Header, - globalVariables: GlobalVariables, - executionRequest: PublicExecutionRequest, - result: EnqueuedPublicCallExecutionResult, - allocatedGas: Gas, - transactionFee: Fr, -): AvmProvingRequest { - const avmExecutionEnv = new AvmExecutionEnvironment( - executionRequest.callContext.contractAddress, - executionRequest.callContext.msgSender, - executionRequest.callContext.functionSelector, - /*contractCallDepth=*/ Fr.zero(), - transactionFee, - globalVariables, - executionRequest.callContext.isStaticCall, - executionRequest.args, - ); - - const avmCallResult = new AvmFinalizedCallResult(result.reverted, result.returnValues, result.endGasLeft); - - // Generate an AVM proving request - let avmProvingRequest: AvmProvingRequest; - if (real) { - const deprecatedFunctionCallResult = stateManager.trace.toPublicFunctionCallResult( - avmExecutionEnv, - /*startGasLeft=*/ allocatedGas, - Buffer.alloc(0), - avmCallResult, - fnName, - ); - const publicInputs = getPublicCircuitPublicInputs(historicalHeader, globalVariables, deprecatedFunctionCallResult); - avmProvingRequest = makeAvmProvingRequest(publicInputs, deprecatedFunctionCallResult); - } else { - avmProvingRequest = emptyAvmProvingRequest(); - } - return avmProvingRequest; -} - -function emptyAvmProvingRequest(): AvmProvingRequest { - return { - type: ProvingRequestType.PUBLIC_VM, - inputs: AvmCircuitInputs.empty(), - }; -} -function makeAvmProvingRequest(inputs: PublicCircuitPublicInputs, result: PublicFunctionCallResult): AvmProvingRequest { - return { - type: ProvingRequestType.PUBLIC_VM, - inputs: new AvmCircuitInputs( - result.functionName, - result.calldata, - inputs, - result.avmCircuitHints, - AvmCircuitPublicInputs.empty(), - ), - }; -} - -function getPublicCircuitPublicInputs( - historicalHeader: Header, - globalVariables: GlobalVariables, - result: PublicFunctionCallResult, -) { - const header = historicalHeader.clone(); // don't modify the original - header.state.partial.publicDataTree.root = Fr.zero(); // AVM doesn't check this yet - - return PublicCircuitPublicInputs.from({ - callContext: result.executionRequest.callContext, - proverAddress: AztecAddress.ZERO, - argsHash: computeVarArgsHash(result.executionRequest.args), - noteHashes: padArrayEnd( - result.noteHashes, - NoteHash.empty(), - MAX_NOTE_HASHES_PER_CALL, - `Too many note hashes. Got ${result.noteHashes.length} with max being ${MAX_NOTE_HASHES_PER_CALL}`, - ), - nullifiers: padArrayEnd( - result.nullifiers, - Nullifier.empty(), - MAX_NULLIFIERS_PER_CALL, - `Too many nullifiers. Got ${result.nullifiers.length} with max being ${MAX_NULLIFIERS_PER_CALL}`, - ), - l2ToL1Msgs: padArrayEnd( - result.l2ToL1Messages, - L2ToL1Message.empty(), - MAX_L2_TO_L1_MSGS_PER_CALL, - `Too many L2 to L1 messages. Got ${result.l2ToL1Messages.length} with max being ${MAX_L2_TO_L1_MSGS_PER_CALL}`, - ), - startSideEffectCounter: result.startSideEffectCounter, - endSideEffectCounter: result.endSideEffectCounter, - returnsHash: computeVarArgsHash(result.returnValues), - noteHashReadRequests: padArrayEnd( - result.noteHashReadRequests, - TreeLeafReadRequest.empty(), - MAX_NOTE_HASH_READ_REQUESTS_PER_CALL, - `Too many note hash read requests. Got ${result.noteHashReadRequests.length} with max being ${MAX_NOTE_HASH_READ_REQUESTS_PER_CALL}`, - ), - nullifierReadRequests: padArrayEnd( - result.nullifierReadRequests, - ReadRequest.empty(), - MAX_NULLIFIER_READ_REQUESTS_PER_CALL, - `Too many nullifier read requests. Got ${result.nullifierReadRequests.length} with max being ${MAX_NULLIFIER_READ_REQUESTS_PER_CALL}`, - ), - nullifierNonExistentReadRequests: padArrayEnd( - result.nullifierNonExistentReadRequests, - ReadRequest.empty(), - MAX_NULLIFIER_NON_EXISTENT_READ_REQUESTS_PER_CALL, - `Too many nullifier non-existent read requests. Got ${result.nullifierNonExistentReadRequests.length} with max being ${MAX_NULLIFIER_NON_EXISTENT_READ_REQUESTS_PER_CALL}`, - ), - l1ToL2MsgReadRequests: padArrayEnd( - result.l1ToL2MsgReadRequests, - TreeLeafReadRequest.empty(), - MAX_L1_TO_L2_MSG_READ_REQUESTS_PER_CALL, - `Too many L1 to L2 message read requests. Got ${result.l1ToL2MsgReadRequests.length} with max being ${MAX_L1_TO_L2_MSG_READ_REQUESTS_PER_CALL}`, - ), - contractStorageReads: padArrayEnd( - result.contractStorageReads, - ContractStorageRead.empty(), - MAX_PUBLIC_DATA_READS_PER_CALL, - `Too many public data reads. Got ${result.contractStorageReads.length} with max being ${MAX_PUBLIC_DATA_READS_PER_CALL}`, - ), - contractStorageUpdateRequests: padArrayEnd( - result.contractStorageUpdateRequests, - ContractStorageUpdateRequest.empty(), - MAX_PUBLIC_DATA_UPDATE_REQUESTS_PER_CALL, - `Too many public data update requests. Got ${result.contractStorageUpdateRequests.length} with max being ${MAX_PUBLIC_DATA_UPDATE_REQUESTS_PER_CALL}`, - ), - publicCallRequests: padArrayEnd( - result.publicCallRequests, - PublicInnerCallRequest.empty(), - MAX_ENQUEUED_CALLS_PER_CALL, - `Too many public call requests. Got ${result.publicCallRequests.length} with max being ${MAX_ENQUEUED_CALLS_PER_CALL}`, - ), - unencryptedLogsHashes: padArrayEnd( - result.unencryptedLogsHashes, - LogHash.empty(), - MAX_UNENCRYPTED_LOGS_PER_CALL, - `Too many unencrypted logs. Got ${result.unencryptedLogsHashes.length} with max being ${MAX_UNENCRYPTED_LOGS_PER_CALL}`, - ), - historicalHeader: header, - globalVariables: globalVariables, - startGasLeft: Gas.from(result.startGasLeft), - endGasLeft: Gas.from(result.endGasLeft), - transactionFee: result.transactionFee, - // TODO(@just-mitch): need better mapping from simulator to revert code. - revertCode: result.reverted ? RevertCode.APP_LOGIC_REVERTED : RevertCode.OK, - }); -}