diff --git a/.clang-tidy b/.clang-tidy index b274ccde..3fb2ff7b 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -19,6 +19,7 @@ Checks: '*, -google-runtime-references, -hicpp-no-assembler, -hicpp-special-member-functions,-warnings-as-errors, + -hicpp-signed-bitwise, -llvm-include-order, -llvmlibc-callee-namespace, -llvmlibc-implementation-in-namespace, diff --git a/include/mls/common.h b/include/mls/common.h index 6dff1b94..03cda851 100644 --- a/include/mls/common.h +++ b/include/mls/common.h @@ -216,11 +216,21 @@ std::vector transform(const Container& c, const UnaryOperation& op) { auto out = std::vector{}; - auto ins = std::inserter(out, out.begin()); + auto ins = std::back_inserter(out); std::transform(c.begin(), c.end(), ins, op); return out; } +template +std::vector +filter(const Container& c, const UnaryOperation& op) +{ + auto out = std::vector{}; + auto ins = std::back_inserter(out); + std::copy_if(c.begin(), c.end(), ins, op); + return out; +} + template bool any_of(const Container& c, const UnaryPredicate& pred) diff --git a/include/mls/messages.h b/include/mls/messages.h index c7442a8f..4042ee53 100644 --- a/include/mls/messages.h +++ b/include/mls/messages.h @@ -691,6 +691,47 @@ external_proposal(CipherSuite suite, uint32_t signer_index, const SignaturePrivateKey& sig_priv); +struct AnnotatedWelcome +{ + Welcome welcome; + + TreeSlice sender_membership_proof; + TreeSlice receiver_membership_proof; + + static AnnotatedWelcome from(Welcome welcome, + const TreeKEMPublicKey& tree, + LeafIndex sender, + LeafIndex joiner); + + TreeKEMPublicKey tree() const; + + TLS_SERIALIZABLE(welcome, sender_membership_proof, receiver_membership_proof); +}; + +struct AnnotatedCommit +{ + MLSMessage commit_message; + std::optional sender_membership_proof_before; + std::optional resolution_index; + + bytes tree_hash_after; + TreeSlice sender_membership_proof_after; + TreeSlice receiver_membership_proof_after; + + static AnnotatedCommit from(LeafIndex receiver, + const std::vector& proposals, + const MLSMessage& commit_message, + const TreeKEMPublicKey& tree_before, + const TreeKEMPublicKey& tree_after); + + TLS_SERIALIZABLE(commit_message, + sender_membership_proof_before, + resolution_index, + tree_hash_after, + sender_membership_proof_after, + receiver_membership_proof_after); +}; + } // namespace MLS_NAMESPACE namespace MLS_NAMESPACE::tls { diff --git a/include/mls/state.h b/include/mls/state.h index d1d77ea7..ae5dc1ea 100644 --- a/include/mls/state.h +++ b/include/mls/state.h @@ -19,9 +19,16 @@ struct RosterIndex : public UInt32 struct CommitOpts { + // Include these proposals in the commit by value std::vector extra_proposals; - bool inline_tree; - bool force_path; + + // Send a ratchet_tree extension in the Welcome + bool inline_tree = false; + + // Send an UpdatePath even if none is required + bool force_path = false; + + // Update the committer's LeafNode in the following way LeafNodeOptions leaf_node_opts; }; @@ -127,6 +134,14 @@ class State std::optional handle(const ValidatedContent& content_auth, std::optional cached_state); + /// + /// Light MLS + /// + void implant_tree_slice(const TreeSlice& slice); + State handle(const AnnotatedCommit& annotated_commit); + bool is_full_client() const { return _tree.is_complete(); } + void upgrade_to_full_client(TreeKEMPublicKey tree); + /// /// PSK management /// @@ -327,6 +342,7 @@ class State CommitMaterials prepare_commit(const bytes& leaf_secret, const std::optional& opts, const CommitParams& params) const; + GroupInfo group_info(bool external_pub, bool inline_tree) const; Welcome welcome(bool inline_tree, const std::vector& psks, const std::vector& joiners, diff --git a/include/mls/tree_math.h b/include/mls/tree_math.h index 19e12d7b..8d9d674a 100644 --- a/include/mls/tree_math.h +++ b/include/mls/tree_math.h @@ -99,8 +99,8 @@ struct NodeIndex : public UInt32 // of `ancestor` that is not in the direct path of this node. NodeIndex sibling(NodeIndex ancestor) const; - std::vector dirpath(LeafCount n); - std::vector copath(LeafCount n); + std::vector dirpath(LeafCount n) const; + std::vector copath(LeafCount n) const; uint32_t level() const; }; diff --git a/include/mls/treekem.h b/include/mls/treekem.h index 71e12335..7825f43f 100644 --- a/include/mls/treekem.h +++ b/include/mls/treekem.h @@ -59,6 +59,18 @@ struct OptionalNode TLS_SERIALIZABLE(node) }; +struct TreeSlice +{ + LeafIndex leaf_index; + LeafCount n_leaves; + std::vector direct_path_nodes; + std::vector copath_hashes; + + bytes tree_hash(CipherSuite suite) const; + + TLS_SERIALIZABLE(leaf_index, n_leaves, direct_path_nodes, copath_hashes); +}; + struct TreeKEMPublicKey; struct TreeKEMPrivateKey @@ -113,15 +125,19 @@ struct TreeKEMPrivateKey void implant(const TreeKEMPublicKey& pub, NodeIndex start, const bytes& path_secret); + void implant_matching(const TreeKEMPublicKey& pub, + NodeIndex start, + const bytes& path_secret); }; struct TreeKEMPublicKey { CipherSuite suite; LeafCount size{ 0 }; - std::vector nodes; + std::map nodes; explicit TreeKEMPublicKey(CipherSuite suite); + TreeKEMPublicKey(CipherSuite suite, const TreeSlice& slice); TreeKEMPublicKey() = default; TreeKEMPublicKey(const TreeKEMPublicKey& other) = default; @@ -148,14 +164,29 @@ struct TreeKEMPublicKey const bytes& get_hash(NodeIndex index); bytes root_hash() const; + bool parent_hash_valid(LeafIndex from) const; bool parent_hash_valid(LeafIndex from, const UpdatePath& path) const; bool parent_hash_valid() const; + bool is_complete() const; bool has_leaf(LeafIndex index) const; std::optional find(const LeafNode& leaf) const; std::optional leaf_node(LeafIndex index) const; std::vector resolve(NodeIndex index) const; + TreeSlice extract_slice(LeafIndex leaf) const; + void implant_slice(const TreeSlice& slice); + std::tuple slice_path(UpdatePath path, + LeafIndex from, + LeafIndex to) const; + + struct AncestorIndex + { + size_t ancestor_node_index; + NodeIndex resolution_node; + }; + AncestorIndex ancestor_index(LeafIndex to, LeafIndex from) const; + struct DecapCoords { size_t ancestor_node_index; @@ -171,6 +202,13 @@ struct TreeKEMPublicKey bool all_leaves(const UnaryPredicate& pred) const { for (LeafIndex i{ 0 }; i < size; i.val++) { + // Only test known nodes + // XXX(RLB) This could be dangerous, since it allows for nodes to fail the + // predicate as long as they are unknown. + if (nodes.count(NodeIndex(i)) == 0) { + continue; + } + const auto& node = node_at(i); if (node.blank()) { continue; @@ -201,8 +239,8 @@ struct TreeKEMPublicKey return false; } - using FilteredDirectPath = - std::vector>>; + using FilteredDirectPathEntry = std::tuple>; + using FilteredDirectPath = std::vector; FilteredDirectPath filtered_direct_path(NodeIndex index) const; void truncate(); @@ -225,6 +263,9 @@ struct TreeKEMPublicKey void clear_hash_path(LeafIndex index); bool has_parent_hash(NodeIndex child, const bytes& target_ph) const; + bool parent_hash_valid(LeafIndex from, + const UpdatePath& path, + const FilteredDirectPath& fdp) const; bytes parent_hash(const ParentNode& parent, NodeIndex copath_child) const; std::vector parent_hashes( @@ -245,6 +286,8 @@ struct TreeKEMPublicKey bool exists_in_tree(const SignaturePublicKey& key, std::optional except) const; + void implant_slice_unchecked(const TreeSlice& slice); + OptionalNode blank_node; friend struct TreeKEMPrivateKey; diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 31f9546e..e0fc4ef8 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(bytes) add_subdirectory(hpke) add_subdirectory(tls_syntax) +add_subdirectory(mls_ds) add_subdirectory(mls_vectors) diff --git a/lib/bytes/src/bytes.cpp b/lib/bytes/src/bytes.cpp index 2bb2e93e..55aeb756 100644 --- a/lib/bytes/src/bytes.cpp +++ b/lib/bytes/src/bytes.cpp @@ -30,8 +30,6 @@ bytes::operator==(const std::vector& other) const unsigned char diff = 0; for (size_t i = 0; i < size; ++i) { - // Not sure why the linter thinks `diff` is signed - // NOLINTNEXTLINE(hicpp-signed-bitwise) diff |= (_data.at(i) ^ other.at(i)); } return (diff == 0); diff --git a/lib/hpke/src/certificate.cpp b/lib/hpke/src/certificate.cpp index 3d9737f8..596984b7 100644 --- a/lib/hpke/src/certificate.cpp +++ b/lib/hpke/src/certificate.cpp @@ -404,7 +404,6 @@ Certificate::parse_pem(const bytes& pem) auto x509 = make_typed_unique( PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr)); if (!x509) { - // NOLINTNEXTLINE(hicpp-signed-bitwise) auto err = ERR_GET_REASON(ERR_peek_last_error()); if (err == PEM_R_NO_START_LINE) { // No more objects to read diff --git a/lib/hpke/src/rsa.cpp b/lib/hpke/src/rsa.cpp index 2eea958d..69319078 100644 --- a/lib/hpke/src/rsa.cpp +++ b/lib/hpke/src/rsa.cpp @@ -31,7 +31,6 @@ RSASignature::generate_key_pair(size_t bits) throw openssl_error(); } - // NOLINTNEXTLINE(hicpp-signed-bitwise) if (EVP_PKEY_CTX_set_rsa_keygen_bits(ctx.get(), static_cast(bits)) <= 0) { throw openssl_error(); diff --git a/lib/mls_ds/CMakeLists.txt b/lib/mls_ds/CMakeLists.txt new file mode 100644 index 00000000..b6440630 --- /dev/null +++ b/lib/mls_ds/CMakeLists.txt @@ -0,0 +1,38 @@ +set(CURRENT_LIB_NAME mls_ds) + +### +### Library Config +### + +file(GLOB_RECURSE LIB_HEADERS CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/include/*.h") +file(GLOB_RECURSE LIB_SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/src/*.cpp") + +add_library(${CURRENT_LIB_NAME} ${LIB_HEADERS} ${LIB_SOURCES}) +add_dependencies(${CURRENT_LIB_NAME} mlspp) +target_link_libraries(${CURRENT_LIB_NAME} mlspp bytes tls_syntax) +target_include_directories(${CURRENT_LIB_NAME} + PUBLIC + $ + $ + $ +) + +### +### Install +### + +install(TARGETS ${CURRENT_LIB_NAME} EXPORT mlspp-targets) +install( + DIRECTORY + include/ + DESTINATION + ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME} +) + +### +### Tests +### + +if (TESTING) + add_subdirectory(test) +endif() diff --git a/lib/mls_ds/README.md b/lib/mls_ds/README.md new file mode 100644 index 00000000..4c20c742 --- /dev/null +++ b/lib/mls_ds/README.md @@ -0,0 +1,11 @@ +# MLS Delivery Service Tools + +This library provides tools that can be convenient for an MLS Delivery Service +(DS). We do not cover the actual delivery mechanics, but instead on more +advanced functions where the DS needs to be aware of the internals of the MLS +protocol. + +For example, it is sometimes useful for the DS to maintain a view of a group's +ratchet tree based on seeing the group's Commits (sent as PublicMessage). To do +this, the DS needs to parse commits and know how to apply them to the tree. +The `TreeFollower` class provided in this library implements this functionality. diff --git a/lib/mls_ds/include/mls_ds/tree_follower.h b/lib/mls_ds/include/mls_ds/tree_follower.h new file mode 100644 index 00000000..899c828f --- /dev/null +++ b/lib/mls_ds/include/mls_ds/tree_follower.h @@ -0,0 +1,33 @@ +#pragma once + +#include +#include +#include + +namespace MLS_NAMESPACE::mls_ds { + +using namespace MLS_NAMESPACE; + +class TreeFollower +{ +public: + // Construct a one-member tree + TreeFollower(const KeyPackage& key_package); + + // Import a tree as a starting point for future updates + TreeFollower(TreeKEMPublicKey tree); + + // Update the tree with a set of proposals applied by a commit + void update(const MLSMessage& commit_message, + const std::vector& extra_proposals); + + // Accessors + CipherSuite cipher_suite() const { return _suite; } + const TreeKEMPublicKey& tree() const { return _tree; } + +private: + CipherSuite _suite; + TreeKEMPublicKey _tree; +}; + +} // namespace MLS_NAMESPACE::mls_ds diff --git a/lib/mls_ds/src/tree_follower.cpp b/lib/mls_ds/src/tree_follower.cpp new file mode 100644 index 00000000..111b078e --- /dev/null +++ b/lib/mls_ds/src/tree_follower.cpp @@ -0,0 +1,166 @@ +#include + +namespace MLS_NAMESPACE::mls_ds { + +/// +/// Resolving & Applying Proposals +/// + +using SenderAndProposal = std::tuple; + +static std::vector +resolve(CipherSuite suite, + Sender commit_sender, + const std::vector& proposals, + const std::vector& extra_proposals) +{ + auto cache = std::map{}; + for (const auto& proposal_msg : extra_proposals) { + const auto& public_message = var::get(proposal_msg.message); + const auto content_auth = public_message.authenticated_content(); + const auto sender = content_auth.content.sender; + const auto proposal = var::get(content_auth.content.content); + + const auto ref = suite.ref(content_auth); + cache.insert_or_assign(ref, std::make_tuple(sender, proposal)); + } + + // Resolve the proposals vector + return stdx::transform( + proposals, [&](const auto& proposal_or_ref) { + static const auto resolver = + overloaded{ [&](const Proposal& proposal) -> SenderAndProposal { + return { commit_sender, proposal }; + }, + [&](const ProposalRef& ref) -> SenderAndProposal { + return cache.at(ref); + } }; + return var::visit(resolver, proposal_or_ref.content); + }); +} + +static void +apply(TreeKEMPublicKey& tree, Sender /* sender */, const Add& add) +{ + tree.add_leaf(add.key_package.leaf_node); +} + +static void +apply(TreeKEMPublicKey& tree, Sender /* sender */, const Remove& remove) +{ + tree.blank_path(remove.removed); +} + +static void +apply(TreeKEMPublicKey& tree, Sender sender, const Update& update) +{ + const auto sender_index = var::get(sender.sender).sender; + tree.update_leaf(sender_index, update.leaf_node); +} + +static void +apply(TreeKEMPublicKey& /* tree */, + Sender /* sender */, + const PreSharedKey& /* pre_shared_key */) +{ +} + +static void +apply(TreeKEMPublicKey& /* tree */, + Sender /* sender */, + const ReInit& /* re_init */) +{ +} + +static void +apply(TreeKEMPublicKey& /* tree */, + Sender /* sender */, + const ExternalInit& /* external_init */) +{ +} + +static void +apply(TreeKEMPublicKey& /* tree */, + Sender /* sender */, + const GroupContextExtensions& /* gce */) +{ +} + +static void +apply(TreeKEMPublicKey& tree, + const std::vector& proposals, + Proposal::Type proposal_type) +{ + for (const auto& [sender_, proposal] : proposals) { + const auto& sender = sender_; + if (proposal.proposal_type() != proposal_type) { + continue; + } + + std::visit([&](const auto& pr) { apply(tree, sender, pr); }, + proposal.content); + } +} + +static void +apply(TreeKEMPublicKey& tree, + CipherSuite suite, + Sender commit_sender, + const std::vector& proposals, + const std::vector& extra_proposals) +{ + const auto resolved = + resolve(suite, commit_sender, proposals, extra_proposals); + + apply(tree, resolved, ProposalType::update); + apply(tree, resolved, ProposalType::remove); + apply(tree, resolved, ProposalType::add); +} + +/// +/// TreeFollower +/// + +TreeFollower::TreeFollower(const KeyPackage& key_package) + : _suite(key_package.cipher_suite) + , _tree(key_package.cipher_suite) +{ + _tree.add_leaf(key_package.leaf_node); + _tree.set_hash_all(); +} + +TreeFollower::TreeFollower(TreeKEMPublicKey tree) + : _suite(tree.suite) + , _tree(std::move(tree)) +{ +} + +void +TreeFollower::update(const mls::MLSMessage& commit_message, + const std::vector& extra_proposals) +{ + // Unwrap the Commit + const auto& commit_public_message = + var::get(commit_message.message); + const auto commit_auth_content = + commit_public_message.authenticated_content(); + const auto group_content = commit_auth_content.content; + const auto& commit = + var::get(commit_auth_content.content.content); + + // Apply proposals + apply(_tree, _suite, group_content.sender, commit.proposals, extra_proposals); + _tree.truncate(); + _tree.set_hash_all(); + + // Merge the update path + if (commit.path) { + const auto sender = + var::get(group_content.sender.sender); + const auto from = LeafIndex(sender.sender); + const auto& path = opt::get(commit.path); + _tree.merge(from, path); + } +} + +} // namespace MLS_NAMESPACE::mls_ds diff --git a/lib/mls_ds/test/CMakeLists.txt b/lib/mls_ds/test/CMakeLists.txt new file mode 100644 index 00000000..e70e9dcb --- /dev/null +++ b/lib/mls_ds/test/CMakeLists.txt @@ -0,0 +1,11 @@ +set(TEST_APP_NAME "${CURRENT_LIB_NAME}_test") + +# Test Binary +file(GLOB TEST_SOURCES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) + +add_executable(${TEST_APP_NAME} ${TEST_SOURCES}) +add_dependencies(${TEST_APP_NAME} ${CURRENT_LIB_NAME} bytes tls_syntax) +target_link_libraries(${TEST_APP_NAME} PRIVATE ${CURRENT_LIB_NAME} Catch2::Catch2WithMain) + +# Enable CTest +catch_discover_tests(${TEST_APP_NAME}) diff --git a/lib/mls_ds/test/tree_follower.cpp b/lib/mls_ds/test/tree_follower.cpp new file mode 100644 index 00000000..4d525458 --- /dev/null +++ b/lib/mls_ds/test/tree_follower.cpp @@ -0,0 +1,138 @@ +#include +#include +#include + +using namespace MLS_NAMESPACE; +using namespace MLS_NAMESPACE::mls_ds; + +class TreeFollowerTest +{ +public: + TreeFollowerTest() + { + for (size_t i = 0; i < group_size; i += 1) { + auto [init_priv, leaf_priv, identity_priv, key_package] = make_client(); + init_privs.push_back(init_priv); + leaf_privs.push_back(leaf_priv); + identity_privs.push_back(identity_priv); + key_packages.push_back(key_package); + } + } + +protected: + const CipherSuite suite{ CipherSuite::ID::P256_AES128GCM_SHA256_P256 }; + + const size_t group_size = 5; + const bytes group_id = { 0, 1, 2, 3 }; + const bytes user_id = { 4, 5, 6, 7 }; + const bytes test_aad = from_hex("01020304"); + const bytes test_message = from_hex("11121314"); + const std::string export_label = "test"; + const bytes export_context = from_hex("05060708"); + const size_t export_size = 16; + + std::vector init_privs; + std::vector leaf_privs; + std::vector identity_privs; + std::vector key_packages; + std::vector states; + + bytes fresh_secret() const { return random_bytes(suite.secret_size()); } + + std::tuple + make_client() + { + auto identity_priv = SignaturePrivateKey::generate(suite); + auto credential = Credential::basic(user_id); + auto init_priv = HPKEPrivateKey::generate(suite); + auto leaf_priv = HPKEPrivateKey::generate(suite); + auto leaf_node = LeafNode{ suite, + leaf_priv.public_key, + identity_priv.public_key, + credential, + Capabilities::create_default(), + Lifetime::create_default(), + {}, + identity_priv }; + auto key_package = + KeyPackage{ suite, init_priv.public_key, leaf_node, {}, identity_priv }; + + return std::make_tuple(init_priv, leaf_priv, identity_priv, key_package); + } +}; + +TEST_CASE_METHOD(TreeFollowerTest, "DS Follows Tree through Group Lifecycle") +{ + // Initialize a one-member group and a tree follower + states.emplace_back(group_id, + suite, + leaf_privs[0], + identity_privs[0], + key_packages[0].leaf_node, + ExtensionList{}); + + auto follower = TreeFollower(key_packages[0]); + + REQUIRE(follower.cipher_suite() == states[0].cipher_suite()); + REQUIRE(follower.tree() == states[0].tree()); + + // Add the remaining members in a single commit + auto adds = std::vector{}; + for (size_t i = 1; i < group_size; i += 1) { + adds.push_back(states[0].add_proposal(key_packages[i])); + } + + auto [commit1, welcome1, new_state1] = + states[0].commit(fresh_secret(), CommitOpts{ adds, true, false, {} }, {}); + silence_unused(commit1); + states[0] = new_state1; + + for (size_t i = 1; i < group_size; i += 1) { + states.push_back({ init_privs[i], + leaf_privs[i], + identity_privs[i], + key_packages[i], + welcome1, + std::nullopt, + {} }); + } + + follower.update(commit1, {}); + REQUIRE(follower.tree() == states[0].tree()); + + // Members 1..4 update, member 0 commits + auto updates = std::vector{}; + for (size_t i = 1; i < group_size; i += 1) { + const auto leaf_priv = HPKEPrivateKey::generate(suite); + const auto update = states[i].update(leaf_priv, {}, {}); + updates.push_back(update); + + for (auto& state : states) { + state.handle(update); + } + } + + auto [commit2, welcome2, new_state2] = + states[0].commit(fresh_secret(), {}, {}); + states[0] = new_state2; + for (size_t i = 1; i < group_size; i += 1) { + states[i] = opt::get(states[i].handle(commit2)); + } + + follower.update(commit2, updates); + REQUIRE(follower.tree() == states[0].tree()); + + // Member 4 removes members 0..3 one by one + for (uint32_t i = 1; i < group_size - 1; i += 1) { + const auto remove = states[group_size - 1].remove_proposal(LeafIndex{ i }); + + auto [commit, welcome, new_state] = states[group_size - 1].commit( + fresh_secret(), CommitOpts{ { remove }, false, false, {} }, {}); + silence_unused(commit); + silence_unused(welcome); + states[group_size - 1] = new_state; + + follower.update(commit, {}); + REQUIRE(follower.tree() == states[group_size - 1].tree()); + } +} diff --git a/lib/tls_syntax/src/tls_syntax.cpp b/lib/tls_syntax/src/tls_syntax.cpp index dccfe16f..4dfae641 100644 --- a/lib/tls_syntax/src/tls_syntax.cpp +++ b/lib/tls_syntax/src/tls_syntax.cpp @@ -76,8 +76,6 @@ operator>>(istream& in, bool& data) uint8_t val = 0; in >> val; - // Linter thinks uint8_t is signed (?) - // NOLINTNEXTLINE(hicpp-signed-bitwise) if ((val & 0xFE) != 0) { throw ReadError("Malformed boolean"); } diff --git a/src/messages.cpp b/src/messages.cpp index 8d483193..ed9352e6 100644 --- a/src/messages.cpp +++ b/src/messages.cpp @@ -213,7 +213,142 @@ Welcome::group_info_key_nonce(CipherSuite suite, return { std::move(key), std::move(nonce) }; } -// Commit +/// +/// AnnotatedWelcome +/// + +AnnotatedWelcome +AnnotatedWelcome::from(Welcome welcome, + const TreeKEMPublicKey& tree, + LeafIndex sender, + LeafIndex joiner) +{ + return { + std::move(welcome), + tree.extract_slice(sender), + tree.extract_slice(joiner), + }; +} + +TreeKEMPublicKey +AnnotatedWelcome::tree() const +{ + auto tree = TreeKEMPublicKey{ welcome.cipher_suite, sender_membership_proof }; + tree.implant_slice(receiver_membership_proof); + return tree; +} + +/// +/// AnnotatedCommit +/// + +AnnotatedCommit +AnnotatedCommit::from(LeafIndex receiver, + const std::vector& proposals, + const MLSMessage& commit_message, + const TreeKEMPublicKey& tree_before, + const TreeKEMPublicKey& tree_after) +{ + // Unpack the commit message + // XXX(RLB) There's some cheating here using authenticated_content() + const auto public_message = var::get(commit_message.message); + const auto content_auth = public_message.authenticated_content(); + switch (content_auth.content.sender.sender_type()) { + case SenderType::member: + case SenderType::new_member_commit: + break; + + default: + throw ProtocolError("Invalid commit sender type"); + } + + const auto& commit = var::get(content_auth.content.content); + + // Compute the list of committed proposals + auto cache = std::map{}; + for (const auto& proposal_msg : proposals) { + const auto& proposal_public_message = + var::get(proposal_msg.message); + const auto proposal_content_auth = + proposal_public_message.authenticated_content(); + const auto proposal = + var::get(proposal_content_auth.content.content); + + const auto ref = tree_before.suite.ref(content_auth); + cache.insert_or_assign(ref, proposal); + } + + const auto committed_proposals = + stdx::transform(commit.proposals, [&](const auto& p_or_r) { + const auto resolve = overloaded{ + [&](const ProposalRef& r) { return cache.at(r); }, + [](const Proposal& p) { return p; }, + }; + return var::visit(resolve, p_or_r.content); + }); + + // Identify the sender + const auto& sender_var = content_auth.content.sender.sender; + const auto external_commit = + var::holds_alternative(sender_var); + + auto sender = LeafIndex{ 0 }; + if (external_commit) { + // The committer's LeafNode is in the commit path + const auto& path = opt::get(commit.path); + sender = opt::get(tree_after.find(path.leaf_node)); + } else { + // Must be member sender + sender = var::get(sender_var).sender; + } + + // Extract the appropriate membership proofs + const auto tree_hash_after = tree_after.root_hash(); + + auto sender_membership_proof_before = std::optional{}; + if (!external_commit) { + sender_membership_proof_before = tree_before.extract_slice(sender); + } + + const auto sender_membership_proof_after = tree_after.extract_slice(sender); + const auto receiver_membership_proof_after = + tree_after.extract_slice(receiver); + + // If there is a path, identify which node the receiver should decrypt + auto resolution_index = std::optional{}; + if (commit.path) { + // Find where the joiners are + const auto add_proposals = + stdx::filter(committed_proposals, [](const auto& p) { + return p.proposal_type() == ProposalType::add; + }); + + const auto joiner_locations = + stdx::transform(add_proposals, [&](const auto& p) { + const auto& add = var::get(p.content); + const auto maybe_loc = tree_after.find(add.key_package.leaf_node); + return opt::get(maybe_loc); + }); + + // Compute the required coordinates + const auto coords = + tree_after.decap_coords(receiver, sender, joiner_locations); + resolution_index = static_cast(coords.resolution_node_index); + } + + return { + commit_message, + sender_membership_proof_before, + resolution_index, + tree_hash_after, + sender_membership_proof_after, + receiver_membership_proof_after, + }; +} + +/// +/// Commit +/// std::optional Commit::valid_external() const { diff --git a/src/state.cpp b/src/state.cpp index 86cd9a31..51e345db 100644 --- a/src/state.cpp +++ b/src/state.cpp @@ -29,13 +29,13 @@ State::State(bytes group_id, } _index = _tree.add_leaf(leaf_node); - _tree.set_hash_all(); _tree_priv = TreeKEMPrivateKey::solo(suite, _index, std::move(enc_priv)); if (!_tree_priv.consistent(_tree)) { throw InvalidParameterError("LeafNode inconsistent with private key"); } // XXX(RLB): Convert KeyScheduleEpoch to take GroupContext? + _tree.set_hash_all(); auto ctx = tls::marshal(group_context()); _key_schedule = KeyScheduleEpoch(_suite, random_bytes(_suite.secret_size()), ctx); @@ -52,6 +52,7 @@ State::import_tree(const bytes& tree_hash, { auto tree = TreeKEMPublicKey(_suite); auto maybe_tree_extn = extensions.find(); + if (external) { tree = opt::get(external); } else if (maybe_tree_extn) { @@ -73,6 +74,12 @@ State::import_tree(const bytes& tree_hash, bool State::validate_tree() const { + // If we don't have a full tree, then we can't verify any properties + // TODO(RLB) We can in fact verify that the leaf node signatures are valid. + if (!_tree.is_complete()) { + return true; + } + // The functionality here is somewhat duplicative of State::valid(const // LeafNode&). Simply calling that method, however, would result in this // method having quadratic scaling, since each call to valid() does a linear @@ -753,8 +760,7 @@ State::welcome(bool inline_tree, const std::vector& joiners, const std::vector>& path_secrets) const { - // TODO(RLB) Suppress external_pub in this GroupInfo - auto group_info_obj = group_info(inline_tree); + auto group_info_obj = group_info(false, inline_tree); auto welcome = Welcome{ _suite, _key_schedule.joiner_secret, psks, group_info_obj }; @@ -1097,6 +1103,166 @@ State::ratchet(TreeKEMPublicKey new_tree, return next; } +/// +/// Light MLS +/// + +void +State::implant_tree_slice(const TreeSlice& slice) +{ + _tree.implant_slice(slice); +} + +State +State::handle(const AnnotatedCommit& annotated_commit) +{ + const auto external_commit = + !bool(annotated_commit.sender_membership_proof_before); + if (!external_commit) { + // If this is not an external commit, verify that the sender is a member of + // the group ... + const auto& proof = + opt::get(annotated_commit.sender_membership_proof_before); + if (proof.tree_hash(_suite) != _tree.root_hash()) { + throw ProtocolError("Invalid sender membership proof"); + } + + // ... and verify that the same leaf is proved before and after + if (proof.leaf_index != + annotated_commit.sender_membership_proof_after.leaf_index) { + throw ProtocolError("Inconsistent sender membership proofs before/after"); + } + + // Then remember that the sender is a part of the tree + _tree.implant_slice(proof); + } + + // Verify the membership proofs are consistent + const auto tree_hash_after = + annotated_commit.sender_membership_proof_after.tree_hash(_suite); + if (tree_hash_after != + annotated_commit.receiver_membership_proof_after.tree_hash(_suite)) { + throw ProtocolError("Inconsistent sender and receiver membership proofs"); + } + + if (_index != annotated_commit.receiver_membership_proof_after.leaf_index) { + throw ProtocolError("Receiver membership proof is not for this node"); + } + + // XXX(RLB) This could fail if the receiver could have sent Update + const auto my_leaf = opt::get(_tree.leaf_node(_index)); + const auto& proof_node = opt::get( + annotated_commit.receiver_membership_proof_after.direct_path_nodes[0].node); + const auto& proof_leaf = var::get(proof_node.node); + if (my_leaf != proof_leaf) { + throw ProtocolError("Incorrect leaf node in receiver membership proof"); + } + + // Unwrap the commit + const auto val_content = unwrap(annotated_commit.commit_message); + const auto& content_auth = val_content.authenticated_content(); + const auto& content = content_auth.content; + const auto& commit = var::get(content.content); + + const auto sender_location = + annotated_commit.sender_membership_proof_after.leaf_index; + if (var::holds_alternative(content.sender.sender) && + var::get(content.sender.sender).sender != sender_location) { + throw ProtocolError("Incorrect commit sender"); + } + + // If this is an external commit, extract the forced init secret + auto force_init_secret = std::optional{}; + if (var::holds_alternative(content.sender.sender)) { + const auto kem_output = commit.valid_external(); + if (!kem_output) { + throw ProtocolError("Invalid external commit"); + } + + force_init_secret = + _key_schedule.receive_external_init(opt::get(kem_output)); + } + + // Update the GroupContext + auto new_tree = + TreeKEMPublicKey(_suite, annotated_commit.sender_membership_proof_after); + new_tree.implant_slice(annotated_commit.receiver_membership_proof_after); + + // Identify the encrypted path secret and how to decrypt it + auto path_secret_decrypt_node = std::optional{}; + auto encrypted_path_secret = std::optional{}; + if (commit.path) { + if (!annotated_commit.resolution_index) { + throw ProtocolError("Commit path present without resolution index"); + } + + const auto& path = opt::get(commit.path); + + if (!valid(path.leaf_node, LeafNodeSource::commit, sender_location)) { + throw ProtocolError("Commit path has invalid leaf node"); + } + + if (!new_tree.parent_hash_valid(sender_location)) { + throw ProtocolError("Commit path has invalid parent hash"); + } + + const auto resolution_node_index = + opt::get(annotated_commit.resolution_index); + const auto coords = new_tree.ancestor_index(_index, sender_location); + + path_secret_decrypt_node = coords.resolution_node; + encrypted_path_secret = path.nodes.at(coords.ancestor_node_index) + .encrypted_path_secret.at(resolution_node_index); + } + + // Update the transcript hash + const auto new_confirmed_transcript_hash = _transcript_hash.new_confirmed( + content_auth.confirmed_transcript_hash_input()); + const auto new_confirmation_tag = + opt::get(content_auth.auth.confirmation_tag); + + // Identify GCE or PSK proposals + const auto proposals = must_resolve(commit.proposals, sender_location); + auto extensions = _extensions; + auto psk_ids = std::vector{}; + for (const auto& p : proposals) { + if (p.proposal.proposal_type() == ProposalType::psk) { + const auto& psk_proposal = var::get(p.proposal.content); + psk_ids.push_back(psk_proposal.psk); + } + + if (p.proposal.proposal_type() == ProposalType::group_context_extensions) { + const auto& gce_proposal = + var::get(p.proposal.content); + extensions = gce_proposal.group_context_extensions; + } + } + + const auto psks = resolve(psk_ids); + + return ratchet(std::move(new_tree), + sender_location, + path_secret_decrypt_node, + encrypted_path_secret, + extensions, + psks, + force_init_secret, + new_confirmed_transcript_hash, + new_confirmation_tag); +} + +void +State::upgrade_to_full_client(TreeKEMPublicKey tree) +{ + // Verify that the tree has the expected tree hash + tree.set_hash_all(); + if (tree.root_hash() != _tree.root_hash()) { + throw ProtocolError("Invalid tree hash"); + } + + _tree = tree; +} + /// /// Subgroup branching /// @@ -2076,12 +2242,12 @@ operator==(const State& lhs, const State& rhs) auto suite = (lhs._suite == rhs._suite); auto group_id = (lhs._group_id == rhs._group_id); auto epoch = (lhs._epoch == rhs._epoch); - auto tree = (lhs._tree == rhs._tree); + auto tree_hash = (lhs._tree.root_hash() == rhs._tree.root_hash()); auto transcript_hash = (lhs._transcript_hash == rhs._transcript_hash); auto key_schedule = (lhs._key_schedule == rhs._key_schedule); auto extensions = (lhs._extensions == rhs._extensions); - return suite && group_id && epoch && tree && transcript_hash && + return suite && group_id && epoch && tree_hash && transcript_hash && key_schedule && extensions; } @@ -2193,6 +2359,12 @@ State::do_export(const std::string& label, GroupInfo State::group_info(bool inline_tree) const +{ + return group_info(true, inline_tree); +} + +GroupInfo +State::group_info(bool external_pub, bool inline_tree) const { auto group_info = GroupInfo{ { @@ -2207,8 +2379,10 @@ State::group_info(bool inline_tree) const _key_schedule.confirmation_tag, }; - group_info.extensions.add( - ExternalPubExtension{ _key_schedule.external_priv.public_key }); + if (external_pub) { + group_info.extensions.add( + ExternalPubExtension{ _key_schedule.external_priv.public_key }); + } if (inline_tree) { group_info.extensions.add(RatchetTreeExtension{ _tree }); diff --git a/src/tree_math.cpp b/src/tree_math.cpp index f1a8be95..869f5b36 100644 --- a/src/tree_math.cpp +++ b/src/tree_math.cpp @@ -47,8 +47,11 @@ LeafCount::full(const LeafCount n) } NodeCount::NodeCount(const LeafCount n) - : UInt32(2 * (n.val - 1) + 1) + : UInt32(0) { + if (n.val > 0) { + val = 2 * (n.val - 1) + 1; + } } LeafIndex::LeafIndex(NodeIndex x) @@ -58,7 +61,7 @@ LeafIndex::LeafIndex(NodeIndex x) throw InvalidParameterError("Only even node indices describe leaves"); } - val = x.val >> 1; // NOLINT(hicpp-signed-bitwise) + val = x.val >> 1; } NodeIndex @@ -165,7 +168,7 @@ NodeIndex::sibling(NodeIndex ancestor) const } std::vector -NodeIndex::dirpath(LeafCount n) +NodeIndex::dirpath(LeafCount n) const { if (val >= NodeCount(n).val) { throw InvalidParameterError("Request for dirpath outside of tree"); @@ -193,7 +196,7 @@ NodeIndex::dirpath(LeafCount n) } std::vector -NodeIndex::copath(LeafCount n) +NodeIndex::copath(LeafCount n) const { auto d = dirpath(n); if (d.empty()) { diff --git a/src/treekem.cpp b/src/treekem.cpp index 6c3fb2c2..245cdb4e 100644 --- a/src/treekem.cpp +++ b/src/treekem.cpp @@ -57,6 +57,15 @@ Node::parent_hash() const return var::visit(get_ph, node); } +/// +/// TreeSlice +/// +bytes +TreeSlice::tree_hash(CipherSuite suite) const +{ + return TreeKEMPublicKey(suite, *this).root_hash(); +} + /// /// TreeKEMPrivateKey /// @@ -116,6 +125,30 @@ TreeKEMPrivateKey::implant(const TreeKEMPublicKey& pub, update_secret = pub.suite.derive_secret(secret, "path"); } +void +TreeKEMPrivateKey::implant_matching(const TreeKEMPublicKey& pub, + NodeIndex start, + const bytes& path_secret) +{ + auto secret = path_secret; + + path_secrets.insert_or_assign(start, secret); + private_key_cache.erase(start); + + const auto dp = start.dirpath(pub.size); + for (const auto& n : dp) { + if (pub.node_at(n).blank()) { + continue; + } + + secret = pub.suite.derive_secret(secret, "path"); + path_secrets.insert_or_assign(n, secret); + private_key_cache.erase(n); + } + + update_secret = pub.suite.derive_secret(secret, "path"); +} + std::optional TreeKEMPrivateKey::private_key(NodeIndex n) const { @@ -209,8 +242,13 @@ TreeKEMPublicKey::dump() const std::cout << "Tree:" << std::endl; auto width = NodeCount(size); for (auto i = NodeIndex{ 0 }; i.val < width.val; i.val++) { + const auto known = nodes.count(i) > 0; + const auto blank = known && node_at(i).blank(); + printf(" %03d : ", i.val); // NOLINT - if (!node_at(i).blank()) { + if (!known) { + std::cout << "????????"; + } else if (!blank) { auto pkRm = to_hex(opt::get(node_at(i).node).public_key().data); std::cout << pkRm.substr(0, 8); } else { @@ -222,7 +260,9 @@ TreeKEMPublicKey::dump() const std::cout << " "; } - if (!node_at(i).blank()) { + if (!known) { + std::cout << "?"; + } else if (!blank) { std::cout << "X"; if (!i.is_leaf()) { @@ -254,7 +294,7 @@ TreeKEMPrivateKey::decap(LeafIndex from, const auto priv = opt::get(private_key(decrypt_node)); const auto path_secret = priv.decrypt( suite, encrypt_label::update_path_node, context, encrypted_path_secret); - implant(pub, overlap_node, path_secret); + implant_matching(pub, overlap_node, path_secret); // Check that the resulting state is consistent with the public key if (!consistent(pub)) { @@ -400,6 +440,14 @@ TreeKEMPublicKey::TreeKEMPublicKey(CipherSuite suite_in) { } +TreeKEMPublicKey::TreeKEMPublicKey(CipherSuite suite_in, const TreeSlice& slice) + : suite(suite_in) + , size(slice.n_leaves) +{ + implant_slice_unchecked(slice); + set_hash_all(); +} + LeafIndex TreeKEMPublicKey::allocate_leaf() { @@ -411,12 +459,17 @@ TreeKEMPublicKey::allocate_leaf() // Extend the tree if necessary if (index.val >= size.val) { + const auto prev_width = NodeCount(size); + if (size.val == 0) { size.val = 1; - nodes.resize(1); } else { size.val *= 2; - nodes.resize(2 * nodes.size() + 1); + } + + const auto new_width = NodeCount(size); + for (auto i = NodeIndex(prev_width.val); i < new_width; i.val++) { + nodes.insert_or_assign(i, OptionalNode{}); } } @@ -576,6 +629,12 @@ TreeKEMPublicKey::parent_hash_valid() const return true; } +bool +TreeKEMPublicKey::is_complete() const +{ + return nodes.size() == NodeCount{ size }.val; +} + std::vector TreeKEMPublicKey::resolve(NodeIndex index) const { @@ -605,6 +664,129 @@ TreeKEMPublicKey::resolve(NodeIndex index) const return l; } +TreeSlice +TreeKEMPublicKey::extract_slice(LeafIndex leaf) const +{ + if (!(leaf < size)) { + throw InvalidParameterError("Invalid leaf index"); + } + + const auto n = NodeIndex(leaf); + auto dirpath = n.dirpath(size); + dirpath.insert(dirpath.begin(), n); + const auto dirpath_nodes = stdx::transform( + dirpath, [this](const auto& n) { return node_at(n); }); + + const auto copath = n.copath(size); + const auto copath_hashes = stdx::transform( + copath, [this](const auto& n) { return hashes.at(n); }); + + return { leaf, size, dirpath_nodes, copath_hashes }; +} + +void +TreeKEMPublicKey::implant_slice(const TreeSlice& slice) +{ + if (slice.n_leaves != size) { + throw InvalidParameterError("Slice tree size does not match tree size"); + } + + if (slice.tree_hash(suite) != root_hash()) { + throw InvalidParameterError("Slice tree hash does not match tree hash"); + } + + implant_slice_unchecked(slice); +} + +std::tuple +TreeKEMPublicKey::slice_path(UpdatePath path, + LeafIndex from, + LeafIndex to) const +{ + const auto toi = NodeIndex(to); + const auto fdp = filtered_direct_path(NodeIndex(from)); + + for (auto i = size_t(0); i < fdp.size(); i++) { + const auto& [dpi, res] = fdp.at(i); + + if (!toi.is_below(dpi)) { + continue; + } + + for (auto j = size_t(0); j < res.size(); j++) { + const auto resi = res.at(j); + if (!toi.is_below(resi)) { + continue; + } + + return { path.nodes.at(i).encrypted_path_secret.at(j), resi }; + } + } + + throw ProtocolError("Decryption node not found"); +} + +void +TreeKEMPublicKey::implant_slice_unchecked(const TreeSlice& slice) +{ + const auto n = NodeIndex(slice.leaf_index); + auto dirpath = n.dirpath(size); + dirpath.insert(dirpath.begin(), n); + const auto copath = n.copath(size); + + if (slice.direct_path_nodes.size() != dirpath.size()) { + throw InvalidParameterError("Malformed tree slice (bad direct path size)"); + } + + if (slice.copath_hashes.size() != copath.size()) { + throw InvalidParameterError("Malformed tree slice (bad copath size)"); + } + + for (auto i = size_t(0); i < dirpath.size(); i++) { + nodes.insert_or_assign(dirpath.at(i), slice.direct_path_nodes.at(i)); + } + + for (auto i = size_t(0); i < copath.size(); i++) { + hashes.insert_or_assign(copath.at(i), slice.copath_hashes.at(i)); + } +} + +TreeKEMPublicKey::AncestorIndex +TreeKEMPublicKey::ancestor_index(LeafIndex to, LeafIndex from) const +{ + // Find the index of the common ancestor in the filtered direct path + // + // XXX(RLB): This calculation is only guaranteed to be accurate immediately + // after a commit from `to`. But it has the advantage of being computable by + // a light client, and is only used when a light client processes a commit. + const auto from_dp = NodeIndex(from).dirpath(size); + const auto fdp = stdx::filter( + from_dp, [&](const auto& n) { return !node_at(n).blank(); }); + + const auto ancestor = to.ancestor(from); + const auto it = stdx::find(fdp, ancestor); + if (it == fdp.end()) { + throw ProtocolError("Blank common ancestor node"); + } + const auto ancestor_node_index = static_cast(it - fdp.begin()); + + // Find the nex non-blank node underneath the ancestor node + const auto to_dp = NodeIndex(to).dirpath(size); + const auto candidates = stdx::filter(to_dp, [&](const auto& n) { + return n.is_below(ancestor) && n != ancestor && !node_at(n).blank(); + }); + + auto resolution_node = NodeIndex(to); + if (!candidates.empty()) { + resolution_node = candidates.back(); + } + + return { + ancestor_node_index, + resolution_node, + }; +} + TreeKEMPublicKey::DecapCoords TreeKEMPublicKey::decap_coords( LeafIndex to, @@ -678,6 +860,11 @@ std::optional TreeKEMPublicKey::find(const LeafNode& leaf) const { for (LeafIndex i{ 0 }; i < size; i.val++) { + if (nodes.count(NodeIndex{ i }) == 0) { + // Unknown leaf node + continue; + } + const auto& node = node_at(i); if (!node.blank() && node.leaf_node() == leaf) { return i; @@ -762,8 +949,9 @@ TreeKEMPublicKey::encap(const TreeKEMPrivateKey& priv, auto ct = stdx::transform(res, [&](auto nr) { const auto& node_pub = opt::get(node_at(nr).node).public_key(); - return node_pub.encrypt( + auto ct = node_pub.encrypt( suite, encrypt_label::update_path_node, context, path_secret); + return ct; }); return UpdatePathNode{ node_priv.public_key, std::move(ct) }; @@ -797,41 +985,32 @@ TreeKEMPublicKey::truncate() return; } - // Remove the right subtree until the tree is of minimal size + // Find the new size of the tree while (size.val / 2 > index.val) { - nodes.resize(nodes.size() / 2); size.val /= 2; } + + // Delete nodes to right of the new smaller edge of the tree + const auto node_size = NodeCount(size); + const auto start = + std::find_if(nodes.begin(), nodes.end(), [node_size](const auto& n) { + return !(n.first < node_size); + }); + if (start != nodes.end()) { + nodes.erase(start, nodes.end()); + } } OptionalNode& TreeKEMPublicKey::node_at(NodeIndex n) { - auto width = NodeCount(size); - if (n.val >= width.val) { - throw InvalidParameterError("Node index not in tree"); - } - - if (n.val >= nodes.size()) { - return blank_node; - } - - return nodes.at(n.val); + return nodes.at(n); } const OptionalNode& TreeKEMPublicKey::node_at(NodeIndex n) const { - auto width = NodeCount(size); - if (n.val >= width.val) { - throw InvalidParameterError("Node index not in tree"); - } - - if (n.val >= nodes.size()) { - return blank_node; - } - - return nodes.at(n.val); + return nodes.at(n); } OptionalNode& @@ -1074,11 +1253,39 @@ TreeKEMPublicKey::original_parent_hash(TreeHashCache& cache, })); } +bool +TreeKEMPublicKey::parent_hash_valid(LeafIndex from) const +{ + // Synthesize a filtered direct path and UpdatePath from the non-blank + // ancestors. Since this is checking for a whole path, we don't need to check + // that the resolution is non-empty. + auto dp = NodeIndex(from).dirpath(size); + auto fdpn = + stdx::filter(dp, [&](auto n) { return !node_at(n).blank(); }); + auto fdp = stdx::transform( + fdpn, [&](auto n) { return std::make_tuple(n, std::vector{}); }); + + auto path_nodes = stdx::transform(fdpn, [&](auto n) { + return UpdatePathNode{ node_at(n).parent_node().public_key, {} }; + }); + auto path = UpdatePath{ node_at(from).leaf_node(), path_nodes }; + + return parent_hash_valid(from, path, fdp); +} + bool TreeKEMPublicKey::parent_hash_valid(LeafIndex from, const UpdatePath& path) const { auto fdp = filtered_direct_path(NodeIndex(from)); + return parent_hash_valid(from, path, fdp); +} + +bool +TreeKEMPublicKey::parent_hash_valid(LeafIndex from, + const UpdatePath& path, + const FilteredDirectPath& fdp) const +{ auto hash_chain = parent_hashes(from, fdp, path.nodes); auto leaf_ph = var::visit(overloaded{ @@ -1130,9 +1337,14 @@ operator<<(tls::ostream& str, const TreeKEMPublicKey& obj) cut.val -= 1; } - const auto begin = obj.nodes.begin(); - const auto end = begin + NodeIndex(cut).val + 1; - const auto view = std::vector(begin, end); + auto node_cut = NodeIndex(cut); + node_cut.val += 1; + + auto view = std::vector(node_cut.val); + for (auto i = NodeIndex(0); i < node_cut; i.val++) { + view.at(i.val) = obj.nodes.at(i); + } + return str << view; } @@ -1140,37 +1352,47 @@ tls::istream& operator>>(tls::istream& str, TreeKEMPublicKey& obj) { // Read the node list - str >> obj.nodes; - if (obj.nodes.empty()) { + std::vector nodes; + str >> nodes; + if (nodes.empty()) { return str; } // Verify that the tree is well-formed and minimal - if (obj.nodes.size() % 2 == 0) { + if (nodes.size() % 2 == 0) { throw ProtocolError("Malformed ratchet tree: even number of nodes"); } - if (obj.nodes.back().blank()) { + if (nodes.back().blank()) { throw ProtocolError("Ratchet tree does not use minimal encoding"); } // Adjust the size value to fit the non-blank nodes obj.size.val = 1; - while (NodeCount(obj.size).val < obj.nodes.size()) { + while (NodeCount(obj.size).val < nodes.size()) { obj.size.val *= 2; } - // Add blank nodes to the end - obj.nodes.resize(NodeCount(obj.size).val); + // Copy nodes to `obj` and add blank nodes to the end + for (uint32_t i = 0; i < nodes.size(); i++) { + obj.nodes.insert_or_assign(NodeIndex(i), std::move(nodes.at(i))); + } + + const auto node_size = NodeCount(obj.size); + const auto provided_node_count = static_cast(nodes.size()); + for (uint32_t i = provided_node_count; i < node_size.val; i++) { + obj.nodes.insert_or_assign(NodeIndex(i), OptionalNode{}); + } // Verify the basic structure of the tree is sane - for (size_t i = 0; i < obj.nodes.size(); i++) { - if (obj.nodes[i].blank()) { + for (auto i = NodeIndex{ 0 }; i < node_size; i.val++) { + const auto& maybe_node = obj.node_at(i); + if (maybe_node.blank()) { continue; } - const auto& node = opt::get(obj.nodes[i].node).node; - auto at_leaf = (i % 2 == 0); + const auto& node = opt::get(maybe_node.node).node; + auto at_leaf = (i.val % 2 == 0); auto holds_leaf = var::holds_alternative(node); auto holds_parent = var::holds_alternative(node); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 06ffbd4c..2576c708 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -4,9 +4,9 @@ set(TEST_APP_NAME "${LIB_NAME}_test") file(GLOB TEST_SOURCES CONFIGURE_DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp) add_executable(${TEST_APP_NAME} ${TEST_SOURCES}) -add_dependencies(${TEST_APP_NAME} ${LIB_NAME} bytes tls_syntax mls_vectors) +add_dependencies(${TEST_APP_NAME} ${LIB_NAME} bytes tls_syntax mls_vectors mls_ds) target_include_directories(${TEST_APP_NAME} PRIVATE ${PROJECT_SOURCE_DIR}/src) -target_link_libraries(${TEST_APP_NAME} PRIVATE mls_vectors Catch2::Catch2WithMain) +target_link_libraries(${TEST_APP_NAME} PRIVATE mls_vectors mls_ds Catch2::Catch2WithMain) # Enable CTest catch_discover_tests(${TEST_APP_NAME}) diff --git a/test/state.cpp b/test/state.cpp index c75b60e9..953a5172 100644 --- a/test/state.cpp +++ b/test/state.cpp @@ -2,8 +2,10 @@ #include #include #include +#include using namespace MLS_NAMESPACE; +using namespace MLS_NAMESPACE::mls_ds; struct CustomExtension { @@ -419,6 +421,260 @@ TEST_CASE_METHOD(StateTest, "Two Person with Replacement") verify_group_functionality(group); } +TEST_CASE_METHOD(StateTest, "Light client can participate") +{ + // Initialize the creator's state + auto first0 = State{ group_id, + suite, + leaf_privs[0], + identity_privs[0], + key_packages[0].leaf_node, + {} }; + + // Add the second participant + auto add1 = first0.add_proposal(key_packages[1]); + auto [commit1, welcome1, first1_] = + first0.commit(fresh_secret(), CommitOpts{ { add1 }, true, false, {} }, {}); + silence_unused(commit1); + auto first1 = first1_; + + // Initialize the second participant from the Welcome. Note that the second + // participant is always a full client, because the membership proofs cover + // the whole tree. + auto second1 = State{ init_privs[1], + leaf_privs[1], + identity_privs[1], + key_packages[1], + welcome1, + std::nullopt, + {} }; + + REQUIRE(second1.is_full_client()); + REQUIRE(first1 == second1); + + // Add the third participant + auto add2 = first0.add_proposal(key_packages[2]); + auto [commit2, welcome2, first2_] = + first1.commit(fresh_secret(), CommitOpts{ { add2 }, false, false, {} }, {}); + auto first2 = first2_; + const auto annotated_welcome = AnnotatedWelcome::from( + welcome2, first2.tree(), LeafIndex{ 0 }, LeafIndex{ 2 }); + + // Handle the Commit at the second participant + auto second2 = opt::get(second1.handle(commit2)); + + // Initialize the third participant as a light client, by only including + // membership proofs in the Welcome, not the full tree + auto third2 = State{ init_privs[2], + leaf_privs[2], + identity_privs[2], + key_packages[2], + annotated_welcome.welcome, + annotated_welcome.tree(), + {} }; + REQUIRE_FALSE(third2.is_full_client()); + + REQUIRE(first2 == second2); + REQUIRE(first2 == third2); + + // Create another commit and handle it at the second client + auto [commit3, welcome3, first3_] = first2.commit(fresh_secret(), {}, {}); + silence_unused(welcome3); + auto first3 = first3_; + auto second3 = opt::get(second2.handle(commit3)); + + // Verify that the light client refuses to process it on its own + REQUIRE_THROWS(third2.handle(commit3)); + + // Convert the Commit to an AnnotatedCommit + auto annotated_commit = AnnotatedCommit::from( + third2.index(), {}, commit3, first2.tree(), first3.tree()); + + // Verify that the light client can process the commit with a commit map + auto third3 = third2.handle(annotated_commit); + + REQUIRE(first3 == second3); + REQUIRE(first3 == third3); + + // Upgrade the third client to be a full client + third3.upgrade_to_full_client(first3.tree()); + REQUIRE(third3.is_full_client()); + + // Verify that all three clients can now process a normal Commit + auto [commit4, welcome4, first4_] = first3.commit(fresh_secret(), {}, {}); + silence_unused(welcome4); + auto first4 = first4_; + auto second4 = opt::get(second3.handle(commit4)); + auto third4 = opt::get(third3.handle(commit4)); + + REQUIRE(first4 == second4); + REQUIRE(first4 == third4); +} + +TEST_CASE_METHOD(StateTest, "Light client can upgrade after several commits") +{ + // Initialize the first two users + auto first0 = State{ group_id, + suite, + leaf_privs[0], + identity_privs[0], + key_packages[0].leaf_node, + {} }; + + auto add1 = first0.add_proposal(key_packages[1]); + auto [commit1, welcome1, first1_] = + first0.commit(fresh_secret(), CommitOpts{ { add1 }, true, false, {} }, {}); + silence_unused(commit1); + auto first1 = first1_; + + auto second1 = State{ init_privs[1], + leaf_privs[1], + identity_privs[1], + key_packages[1], + welcome1, + std::nullopt, + {} }; + + REQUIRE(second1.is_full_client()); + REQUIRE(first1 == second1); + + // Add the third participant as a light client, remembering the tree at this + // point. + auto add2 = first0.add_proposal(key_packages[2]); + auto [commit2, welcome2, first2_] = + first1.commit(fresh_secret(), CommitOpts{ { add2 }, false, false, {} }, {}); + auto first2 = first2_; + const auto annotated_welcome = AnnotatedWelcome::from( + welcome2, first2.tree(), LeafIndex{ 0 }, LeafIndex{ 2 }); + + auto second2 = opt::get(second1.handle(commit2)); + + auto third2 = State{ init_privs[2], + leaf_privs[2], + identity_privs[2], + key_packages[2], + annotated_welcome.welcome, + annotated_welcome.tree(), + {} }; + REQUIRE_FALSE(third2.is_full_client()); + + REQUIRE(first2 == second2); + REQUIRE(first2 == third2); + + const auto tree2 = first2.tree(); + + // Client 1 makes a bunch of commits, and the other two members follow along. + auto first = first2; + auto second = second2; + auto third = third2; + + auto commits = std::vector{}; + const auto n_commits = size_t(5); + for (auto i = size_t(0); i < n_commits; i++) { + const auto [commit, welcome, next_first] = + first.commit(fresh_secret(), {}, {}); + silence_unused(welcome); + const auto annotated_commit = AnnotatedCommit::from( + third.index(), {}, commit, first.tree(), next_first.tree()); + + commits.push_back(commit); + + first = next_first; + second = opt::get(second.handle(commit)); + third = third.handle(annotated_commit); + + REQUIRE(first == second); + REQUIRE(first == third); + } + + // Client 3 finally finishes downloading the tree, fast-forwards it using the + // commit queue, and upgrades to being a full client. + auto follower = TreeFollower(tree2); + for (const auto& commit : commits) { + follower.update(commit, {}); + } + + third.upgrade_to_full_client(follower.tree()); + REQUIRE(third.is_full_client()); + + REQUIRE(first == second); + REQUIRE(first == third); +} + +TEST_CASE_METHOD(StateTest, "Light client can handle an external commit") +{ + // Initialize the first two users + auto first0 = State{ group_id, + suite, + leaf_privs[0], + identity_privs[0], + key_packages[0].leaf_node, + {} }; + + auto add1 = first0.add_proposal(key_packages[1]); + auto [commit1, welcome1, first1_] = + first0.commit(fresh_secret(), CommitOpts{ { add1 }, true, false, {} }, {}); + silence_unused(commit1); + auto first1 = first1_; + + auto second1 = State{ init_privs[1], + leaf_privs[1], + identity_privs[1], + key_packages[1], + welcome1, + std::nullopt, + {} }; + + REQUIRE(second1.is_full_client()); + REQUIRE(first1 == second1); + + // Add the third participant as a light client + auto add2 = first0.add_proposal(key_packages[2]); + auto [commit2, welcome2, first2_] = + first1.commit(fresh_secret(), CommitOpts{ { add2 }, false, false, {} }, {}); + auto first2 = first2_; + const auto annotated_welcome = AnnotatedWelcome::from( + welcome2, first2.tree(), LeafIndex{ 0 }, LeafIndex{ 2 }); + + auto second2 = opt::get(second1.handle(commit2)); + + auto third2 = State{ init_privs[2], + leaf_privs[2], + identity_privs[2], + key_packages[2], + annotated_welcome.welcome, + annotated_welcome.tree(), + {} }; + REQUIRE_FALSE(third2.is_full_client()); + + REQUIRE(first2 == second2); + REQUIRE(first2 == third2); + + // The fourth participant joins via an external commit + const auto group_info = first2.group_info(true); + const auto [commit3, fourth3] = State::external_join(fresh_secret(), + identity_privs[3], + key_packages[3], + group_info, + std::nullopt, + {}, + std::nullopt, + {}); + + // Process the commit at the normal clients + const auto first3 = opt::get(first2.handle(commit3)); + const auto second3 = opt::get(second2.handle(commit3)); + + // Annotate the commit and handle it at the third client + const auto annotated_commit = AnnotatedCommit::from( + third2.index(), {}, commit3, first2.tree(), first3.tree()); + const auto third3 = third2.handle(annotated_commit); + + REQUIRE(first3 == second3); + REQUIRE(first3 == third3); + REQUIRE(first3 == fourth3); +} + TEST_CASE_METHOD(StateTest, "External Join") { // Initialize the creator's state diff --git a/test/treekem.cpp b/test/treekem.cpp index 5cb5ec89..833393f8 100644 --- a/test/treekem.cpp +++ b/test/treekem.cpp @@ -68,7 +68,7 @@ TEST_CASE_METHOD(TreeKEMTest, "TreeKEM Private Key") const auto priv2 = HPKEPrivateKey::generate(suite); const auto hash_size = suite.digest().hash_size; - // Create a tree with N blank leaves + // Create a tree with N leaves, blank otherwise auto pub = TreeKEMPublicKey(suite); for (auto i = uint32_t(0); i < size.val; i++) { auto [_leaf_priv, _sig_priv, leaf_node] = new_leaf_node(); @@ -264,6 +264,32 @@ TEST_CASE_METHOD(TreeKEMTest, "TreeKEM encap/decap") REQUIRE(privs[j].consistent(pubs[j])); } } + + // XXX(RLB) The below test probably doesn't belong here, but we have a nice + // properly-created tree here, so it's convenient to use it for testing tree + // slices. + + // Verify that all slices of the tree have the correct tree hash + auto original = pubs[0]; + original.set_hash_all(); + const auto tree_hash = original.root_hash(); + auto slices = std::vector{}; + for (uint32_t i = 0; i < original.size.val; i++) { + auto slice = original.extract_slice(LeafIndex{ i }); + REQUIRE(slice.tree_hash(suite) == tree_hash); + + slices.push_back(slice); + } + + // Verify that the tree can be reassembled from the slices + auto reconstructed = TreeKEMPublicKey(suite, slices[0]); + for (uint32_t i = 1; i < slices.size(); i++) { + reconstructed.implant_slice(slices[i]); + } + + reconstructed.set_hash_all(); + REQUIRE(reconstructed.root_hash() == tree_hash); + REQUIRE(reconstructed.parent_hash_valid()); } TEST_CASE("TreeKEM Interop", "[.][all]")