diff --git a/include/mls/core_types.h b/include/mls/core_types.h index 533c6a8c..609ce9f2 100644 --- a/include/mls/core_types.h +++ b/include/mls/core_types.h @@ -41,6 +41,9 @@ struct ExtensionType static constexpr Extension::Type external_pub = 4; static constexpr Extension::Type external_senders = 5; + static constexpr Extension::Type flags = 6; + static constexpr Extension::Type membership_proof = 7; + // XXX(RLB) There is no IANA-registered type for this extension yet, so we use // a value from the vendor-specific space static constexpr Extension::Type sframe_parameters = 0xff02; diff --git a/include/mls/messages.h b/include/mls/messages.h index c7442a8f..d356cc73 100644 --- a/include/mls/messages.h +++ b/include/mls/messages.h @@ -27,6 +27,14 @@ struct RatchetTreeExtension TLS_SERIALIZABLE(tree) }; +struct MembershipProofExtension +{ + std::vector slices; + + static const uint16_t type; + TLS_SERIALIZABLE(slices) +}; + struct ExternalSender { SignaturePublicKey signature_key; @@ -43,6 +51,20 @@ struct ExternalSendersExtension TLS_SERIALIZABLE(senders); }; +struct FlagsExtension +{ + std::vector flag_data; + + void set(size_t pos); + void unset(size_t pos); + bool get(size_t pos) const; + + static const uint16_t type; + + // XXX(RLB): This should check for extra zero bytes on deserialize. + TLS_SERIALIZABLE(flag_data); +}; + struct SFrameParameters { uint16_t cipher_suite; @@ -257,6 +279,15 @@ struct Welcome const std::vector& psks); }; +struct LightCommit +{ + GroupContext group_context; + bytes confirmation_tag; + TreeSlice sender_membership_proof; + std::optional encrypted_path_secret; + std::optional decryption_node_index; +}; + /// /// Proposals & Commit /// @@ -623,6 +654,10 @@ struct PublicMessage bytes membership_mac(CipherSuite suite, const bytes& membership_key, const std::optional& context) const; + + // XXX(RLB) This is a hack to avoid unwrapping across epochs. We should do + // something more elegant, like unchecked_content() + friend class State; }; struct PrivateMessage diff --git a/include/mls/state.h b/include/mls/state.h index 1741d10d..99dfbd35 100644 --- a/include/mls/state.h +++ b/include/mls/state.h @@ -19,9 +19,20 @@ struct RosterIndex : public UInt32 struct CommitOpts { + // Include these proposals in the commit by value std::vector extra_proposals; - bool inline_tree; - bool force_path; + + // Send a ratchet_tree extension in the Welcome + bool inline_tree = false; + + // Send an UpdatePath even if none is required + bool force_path = false; + + // Send a membership_proof extension in the Welcome covering the committer and + // the new joiners + bool membership_proof = false; + + // Update the committer's LeafNode in the following way LeafNodeOptions leaf_node_opts; }; @@ -127,6 +138,12 @@ class State std::optional handle(const ValidatedContent& content_auth, std::optional cached_state); + /// + /// Light MLS + /// + LightCommit lighten_for(LeafIndex leaf, const MLSMessage& commit) const; + State handle(const LightCommit& light_commit) const; + /// /// PSK management /// @@ -145,6 +162,7 @@ class State const ExtensionList& extensions() const { return _extensions; } const TreeKEMPublicKey& tree() const { return _tree; } const bytes& resumption_psk() const { return _key_schedule.resumption_psk; } + bool is_full_client() const { return _tree.is_complete(); } bytes do_export(const std::string& label, const bytes& context, diff --git a/include/mls/tree_math.h b/include/mls/tree_math.h index 19e12d7b..8d9d674a 100644 --- a/include/mls/tree_math.h +++ b/include/mls/tree_math.h @@ -99,8 +99,8 @@ struct NodeIndex : public UInt32 // of `ancestor` that is not in the direct path of this node. NodeIndex sibling(NodeIndex ancestor) const; - std::vector dirpath(LeafCount n); - std::vector copath(LeafCount n); + std::vector dirpath(LeafCount n) const; + std::vector copath(LeafCount n) const; uint32_t level() const; }; diff --git a/include/mls/treekem.h b/include/mls/treekem.h index 75a9550f..8b14545a 100644 --- a/include/mls/treekem.h +++ b/include/mls/treekem.h @@ -59,6 +59,19 @@ struct OptionalNode TLS_SERIALIZABLE(node) }; +struct TreeSlice +{ + LeafIndex leaf_index; + LeafCount n_leaves; + std::vector direct_path_nodes; + std::vector copath_hashes; + + bytes tree_hash(CipherSuite suite) const; + void add(const TreeSlice& other); + + TLS_SERIALIZABLE(leaf_index, n_leaves, direct_path_nodes, copath_hashes); +}; + struct TreeKEMPublicKey; struct TreeKEMPrivateKey @@ -107,15 +120,19 @@ struct TreeKEMPrivateKey void implant(const TreeKEMPublicKey& pub, NodeIndex start, const bytes& path_secret); + void implant_matching(const TreeKEMPublicKey& pub, + NodeIndex start, + const bytes& path_secret); }; struct TreeKEMPublicKey { CipherSuite suite; LeafCount size{ 0 }; - std::vector nodes; + std::map nodes; explicit TreeKEMPublicKey(CipherSuite suite); + TreeKEMPublicKey(CipherSuite suite, const TreeSlice& slice); TreeKEMPublicKey() = default; TreeKEMPublicKey(const TreeKEMPublicKey& other) = default; @@ -144,12 +161,19 @@ struct TreeKEMPublicKey bool parent_hash_valid(LeafIndex from, const UpdatePath& path) const; bool parent_hash_valid() const; + bool is_complete() const; bool has_leaf(LeafIndex index) const; std::optional find(const LeafNode& leaf) const; std::optional leaf_node(LeafIndex index) const; std::vector resolve(NodeIndex index) const; + TreeSlice extract_slice(LeafIndex leaf) const; + void implant_slice(const TreeSlice& slice); + std::tuple slice_path(UpdatePath path, + LeafIndex from, + LeafIndex to) const; + template bool all_leaves(const UnaryPredicate& pred) const { @@ -228,6 +252,8 @@ struct TreeKEMPublicKey bool exists_in_tree(const SignaturePublicKey& key, std::optional except) const; + void implant_slice_unchecked(const TreeSlice& slice); + OptionalNode blank_node; friend struct TreeKEMPrivateKey; diff --git a/src/messages.cpp b/src/messages.cpp index 8d483193..74366021 100644 --- a/src/messages.cpp +++ b/src/messages.cpp @@ -12,12 +12,71 @@ namespace MLS_NAMESPACE { const Extension::Type ExternalPubExtension::type = ExtensionType::external_pub; const Extension::Type RatchetTreeExtension::type = ExtensionType::ratchet_tree; +const Extension::Type MembershipProofExtension::type = + ExtensionType::membership_proof; const Extension::Type ExternalSendersExtension::type = ExtensionType::external_senders; +const Extension::Type FlagsExtension::type = ExtensionType::flags; const Extension::Type SFrameParameters::type = ExtensionType::sframe_parameters; const Extension::Type SFrameCapabilities::type = ExtensionType::sframe_parameters; +void +FlagsExtension::set(size_t pos) +{ + const auto byte_pos = pos >> 3; + const auto bit_pos = pos & 0x07; + + // Ensure space + if (byte_pos >= flag_data.size()) { + flag_data.resize(byte_pos + 1); + } + + // Set the bit + flag_data.at(byte_pos) |= uint8_t(1 << bit_pos); +} + +void +FlagsExtension::unset(size_t pos) +{ + const auto byte_pos = pos >> 3; + const auto bit_pos = pos & 0x07; + + if (byte_pos >= flag_data.size()) { + return; + } + + // Unset the bit + flag_data.at(byte_pos) &= ~uint8_t(1 << bit_pos); + + // Trim any zero bytes + auto cut = flag_data.size() - 1; + while (cut > 0 && flag_data.at(cut) == 0) { + cut -= 1; + } + + if (flag_data.at(cut) == 0) { + flag_data.clear(); + return; + } + + flag_data.resize(cut + 1); +} + +bool +FlagsExtension::get(size_t pos) const +{ + const auto byte_pos = pos >> 3; + const auto bit_pos = pos & 0x07; + + if (byte_pos >= flag_data.size()) { + return false; + } + + const auto bit = (flag_data.at(byte_pos) >> bit_pos) & 0x01; + return bit == 1; +} + bool SFrameCapabilities::compatible(const SFrameParameters& params) const { diff --git a/src/session.cpp b/src/session.cpp index e1b615cb..f0aa0363 100644 --- a/src/session.cpp +++ b/src/session.cpp @@ -305,8 +305,10 @@ Session::commit() { auto commit_secret = inner->fresh_secret(); auto encrypt = inner->encrypt_handshake; - auto [commit, welcome, new_state] = inner->history.front().commit( - commit_secret, CommitOpts{ {}, true, encrypt, {} }, { encrypt, {}, 0 }); + auto [commit, welcome, new_state] = + inner->history.front().commit(commit_secret, + CommitOpts{ {}, true, encrypt, false, {} }, + { encrypt, {}, 0 }); auto commit_msg = tls::marshal(commit); auto welcome_msg = tls::marshal(welcome); diff --git a/src/state.cpp b/src/state.cpp index 826503ba..60771eb0 100644 --- a/src/state.cpp +++ b/src/state.cpp @@ -29,13 +29,13 @@ State::State(bytes group_id, } _index = _tree.add_leaf(leaf_node); - _tree.set_hash_all(); _tree_priv = TreeKEMPrivateKey::solo(suite, _index, std::move(enc_priv)); if (!_tree_priv.consistent(_tree)) { throw InvalidParameterError("LeafNode inconsistent with private key"); } // XXX(RLB): Convert KeyScheduleEpoch to take GroupContext? + _tree.set_hash_all(); auto ctx = tls::marshal(group_context()); _key_schedule = KeyScheduleEpoch(_suite, random_bytes(_suite.secret_size()), ctx); @@ -53,10 +53,20 @@ State::import_tree(const bytes& tree_hash, { auto tree = TreeKEMPublicKey(_suite); auto maybe_tree_extn = extensions.find(); + auto maybe_membership_proof_extn = + extensions.find(); + if (external) { tree = opt::get(external); } else if (maybe_tree_extn) { tree = opt::get(maybe_tree_extn).tree; + } else if (maybe_membership_proof_extn) { + const auto& membership_proof = opt::get(maybe_membership_proof_extn); + + tree = TreeKEMPublicKey(_suite, membership_proof.slices.at(0)); + for (auto i = size_t(0); i < membership_proof.slices.size(); i++) { + tree.implant_slice(membership_proof.slices.at(i)); + } } else { throw InvalidParameterError("No tree available"); } @@ -74,6 +84,12 @@ State::import_tree(const bytes& tree_hash, bool State::validate_tree() const { + // If we don't have a full tree, then we can't verify any properties + // TODO(RLB) We can in fact verify that the leaf node signatures are valid. + if (!_tree.is_complete()) { + return true; + } + // The functionality here is somewhat duplicative of State::valid(const // LeafNode&). Simply calling that method, however, would result in this // method having quadratic scaling, since each call to valid() does a linear @@ -640,6 +656,11 @@ State::commit(const bytes& leaf_secret, const MessageOpts& msg_opts, CommitParams params) { + // If we are not a full client, we can't handle commits + if (!is_full_client()) { + throw ProtocolError("Light clients can't create commits"); + } + // Construct a commit from cached proposals // TODO(rlb) ignore some proposals: // * Update after Update @@ -757,6 +778,26 @@ State::commit(const bytes& leaf_secret, auto commit_message = protect(std::move(commit_content_auth), msg_opts.padding_size); + // If we are adding membership proofs, add one covering this client and the + // joiners. + auto group_info_extensions = ExtensionList{}; + if (opts && opt::get(opts).membership_proof) { + auto slices = std::vector{}; + slices.reserve(joiner_locations.size() + 1); + + slices.push_back(next._tree.extract_slice(next._index)); + for (const auto& loc : joiner_locations) { + slices.push_back(next._tree.extract_slice(loc)); + } + + group_info_extensions.add(MembershipProofExtension{ std::move(slices) }); + } + + // If we are sending the whole tree, add that extension + if (opts && opt::get(opts).inline_tree) { + group_info_extensions.add(RatchetTreeExtension{ next._tree }); + } + // Complete the GroupInfo and form the Welcome auto group_info = GroupInfo{ { @@ -767,12 +808,9 @@ State::commit(const bytes& leaf_secret, next._transcript_hash.confirmed, next._extensions, }, - { /* No other extensions */ }, + { group_info_extensions }, { confirmation_tag }, }; - if (opts && opt::get(opts).inline_tree) { - group_info.extensions.add(RatchetTreeExtension{ next._tree }); - } group_info.sign(next._tree, next._index, next._identity_priv); auto welcome = @@ -857,6 +895,11 @@ State::handle(const ValidatedContent& val_content, throw InvalidParameterError("Invalid content type"); } + // If we are not a full client, we can't handle commits + if (!is_full_client()) { + throw ProtocolError("Light clients can't handle commits"); + } + switch (content.sender.sender_type()) { case SenderType::member: case SenderType::new_member_commit: @@ -968,6 +1011,97 @@ State::handle(const ValidatedContent& val_content, return next; } +LightCommit +State::lighten_for(LeafIndex leaf, const MLSMessage& commit_msg) const +{ + // Check that the current epoch is one higher than commit.epoch + if (_epoch != commit_msg.epoch() + 1) { + throw InvalidParameterError("Invalid epoch for lightening operation"); + } + + // Pull the GroupContext for the current state + const auto ctx = group_context(); + + // Make a memberhsip proof + const auto& public_msg = var::get(commit_msg.message); + const auto& sender = + var::get(public_msg.content.sender.sender).sender; + + const auto confirmation_tag = opt::get(public_msg.auth.confirmation_tag); + const auto sender_membership_proof = _tree.extract_slice(sender); + + // Extract the correct path secret for the recipient + const auto& commit = var::get(public_msg.content.content); + auto encrypted_path_secret = std::optional{}; + auto decryption_node_index = std::optional{}; + if (commit.path) { + const auto [secret, index] = + _tree.slice_path(opt::get(commit.path), sender, leaf); + encrypted_path_secret = secret; + decryption_node_index = index; + } + + return { + ctx, + confirmation_tag, + sender_membership_proof, + encrypted_path_secret, + decryption_node_index, + }; +} + +State +State::handle(const LightCommit& light_commit) const +{ + // Verify the membership proof + // TODO(RLB) Also verify the signature (?) + // XXX(RLB) Should this use the new or old tree hash? + if (light_commit.sender_membership_proof.tree_hash(_suite) != + light_commit.group_context.tree_hash) { + throw ProtocolError("Invalid sender membership proof"); + } + + // Import the GroupContext + // TODO(RLB) Verify that version, cipher_suite, group_id, epoch are as + // expected + auto next = successor(); + next._epoch += 1; + next._tree = + TreeKEMPublicKey(next._suite, light_commit.sender_membership_proof); + next._extensions = light_commit.group_context.extensions; + + // Decrypt the commit secret + auto commit_secret = _suite.zero(); + if (light_commit.encrypted_path_secret) { + const auto& encrypted_path_secret = + opt::get(light_commit.encrypted_path_secret); + const auto& decryption_node_index = + opt::get(light_commit.decryption_node_index); + + const auto priv = opt::get(_tree_priv.private_key(decryption_node_index)); + const auto context = tls::marshal(next.group_context()); + const auto path_secret = priv.decrypt(next._suite, + encrypt_label::update_path_node, + context, + encrypted_path_secret); + const auto ancestor = + next._index.ancestor(light_commit.sender_membership_proof.leaf_index); + next._tree_priv.implant_matching(next._tree, ancestor, path_secret); + + commit_secret = next._tree_priv.update_secret; + } + + // Update the key schedule + // TODO(RLB) Need to accommodate PSKs for light clients + next._transcript_hash = + TranscriptHash(next._suite, + light_commit.group_context.confirmed_transcript_hash, + light_commit.confirmation_tag); + next.update_epoch_secrets(commit_secret, {}, std::nullopt); + + return next; +} + /// /// Subgroup branching /// @@ -1007,6 +1141,7 @@ State::create_branch(bytes group_id, proposals, commit_opts.inline_tree, commit_opts.force_path, + commit_opts.membership_proof, commit_opts.leaf_node_opts, }; auto [_commit, welcome, state] = new_group.commit( @@ -1085,6 +1220,7 @@ State::Tombstone::create_welcome(HPKEPrivateKey enc_priv, proposals, commit_opts.inline_tree, commit_opts.force_path, + commit_opts.membership_proof, commit_opts.leaf_node_opts, }; auto [_commit, welcome, state] = new_group.commit( @@ -1985,12 +2121,12 @@ operator==(const State& lhs, const State& rhs) auto suite = (lhs._suite == rhs._suite); auto group_id = (lhs._group_id == rhs._group_id); auto epoch = (lhs._epoch == rhs._epoch); - auto tree = (lhs._tree == rhs._tree); + auto tree_hash = (lhs._tree.root_hash() == rhs._tree.root_hash()); auto transcript_hash = (lhs._transcript_hash == rhs._transcript_hash); auto key_schedule = (lhs._key_schedule == rhs._key_schedule); auto extensions = (lhs._extensions == rhs._extensions); - return suite && group_id && epoch && tree && transcript_hash && + return suite && group_id && epoch && tree_hash && transcript_hash && key_schedule && extensions; } @@ -2013,6 +2149,7 @@ State::update_epoch_secrets(const bytes& commit_secret, _transcript_hash.confirmed, _extensions, }); + _key_schedule = _key_schedule.next(commit_secret, psks, force_init_secret, ctx); _keys = _key_schedule.encryption_keys(_tree.size); diff --git a/src/tree_math.cpp b/src/tree_math.cpp index f1a8be95..524a16b9 100644 --- a/src/tree_math.cpp +++ b/src/tree_math.cpp @@ -47,8 +47,13 @@ LeafCount::full(const LeafCount n) } NodeCount::NodeCount(const LeafCount n) - : UInt32(2 * (n.val - 1) + 1) + : UInt32(0) { + if (n.val == 0) { + val = 0; + } else { + val = 2 * (n.val - 1) + 1; + } } LeafIndex::LeafIndex(NodeIndex x) @@ -165,7 +170,7 @@ NodeIndex::sibling(NodeIndex ancestor) const } std::vector -NodeIndex::dirpath(LeafCount n) +NodeIndex::dirpath(LeafCount n) const { if (val >= NodeCount(n).val) { throw InvalidParameterError("Request for dirpath outside of tree"); @@ -193,7 +198,7 @@ NodeIndex::dirpath(LeafCount n) } std::vector -NodeIndex::copath(LeafCount n) +NodeIndex::copath(LeafCount n) const { auto d = dirpath(n); if (d.empty()) { diff --git a/src/treekem.cpp b/src/treekem.cpp index 53aeea1c..69a3d303 100644 --- a/src/treekem.cpp +++ b/src/treekem.cpp @@ -57,6 +57,15 @@ Node::parent_hash() const return var::visit(get_ph, node); } +/// +/// TreeSlice +/// +bytes +TreeSlice::tree_hash(CipherSuite suite) const +{ + return TreeKEMPublicKey(suite, *this).root_hash(); +} + /// /// TreeKEMPrivateKey /// @@ -116,6 +125,30 @@ TreeKEMPrivateKey::implant(const TreeKEMPublicKey& pub, update_secret = pub.suite.derive_secret(secret, "path"); } +void +TreeKEMPrivateKey::implant_matching(const TreeKEMPublicKey& pub, + NodeIndex start, + const bytes& path_secret) +{ + auto secret = path_secret; + + path_secrets.insert_or_assign(start, secret); + private_key_cache.erase(start); + + const auto dp = start.dirpath(pub.size); + for (const auto& n : dp) { + if (pub.node_at(n).blank()) { + continue; + } + + secret = pub.suite.derive_secret(secret, "path"); + path_secrets.insert_or_assign(n, secret); + private_key_cache.erase(n); + } + + update_secret = pub.suite.derive_secret(secret, "path"); +} + std::optional TreeKEMPrivateKey::private_key(NodeIndex n) const { @@ -209,8 +242,13 @@ TreeKEMPublicKey::dump() const std::cout << "Tree:" << std::endl; auto width = NodeCount(size); for (auto i = NodeIndex{ 0 }; i.val < width.val; i.val++) { + const auto known = nodes.count(i) > 0; + const auto blank = known && node_at(i).blank(); + printf(" %03d : ", i.val); // NOLINT - if (!node_at(i).blank()) { + if (!known) { + std::cout << "????????"; + } else if (!blank) { auto pkRm = to_hex(opt::get(node_at(i).node).public_key().data); std::cout << pkRm.substr(0, 8); } else { @@ -222,7 +260,9 @@ TreeKEMPublicKey::dump() const std::cout << " "; } - if (!node_at(i).blank()) { + if (!known) { + std::cout << "?"; + } else if (!blank) { std::cout << "X"; if (!i.is_leaf()) { @@ -381,6 +421,14 @@ TreeKEMPublicKey::TreeKEMPublicKey(CipherSuite suite_in) { } +TreeKEMPublicKey::TreeKEMPublicKey(CipherSuite suite_in, const TreeSlice& slice) + : suite(suite_in) + , size(slice.n_leaves) +{ + implant_slice_unchecked(slice); + set_hash_all(); +} + LeafIndex TreeKEMPublicKey::allocate_leaf() { @@ -392,12 +440,17 @@ TreeKEMPublicKey::allocate_leaf() // Extend the tree if necessary if (index.val >= size.val) { + const auto prev_width = NodeCount(size); + if (size.val == 0) { size.val = 1; - nodes.resize(1); } else { size.val *= 2; - nodes.resize(2 * nodes.size() + 1); + } + + const auto new_width = NodeCount(size); + for (auto i = NodeIndex(prev_width.val); i < new_width; i.val++) { + nodes.insert_or_assign(i, OptionalNode{}); } } @@ -557,6 +610,12 @@ TreeKEMPublicKey::parent_hash_valid() const return true; } +bool +TreeKEMPublicKey::is_complete() const +{ + return nodes.size() == NodeCount{ size }.val; +} + std::vector TreeKEMPublicKey::resolve(NodeIndex index) const { @@ -586,6 +645,93 @@ TreeKEMPublicKey::resolve(NodeIndex index) const return l; } +TreeSlice +TreeKEMPublicKey::extract_slice(LeafIndex leaf) const +{ + if (!(leaf < size)) { + throw InvalidParameterError("Invalid leaf index"); + } + + const auto n = NodeIndex(leaf); + auto dirpath = n.dirpath(size); + dirpath.insert(dirpath.begin(), n); + const auto dirpath_nodes = stdx::transform( + dirpath, [this](const auto& n) { return node_at(n); }); + + const auto copath = n.copath(size); + const auto copath_hashes = stdx::transform( + copath, [this](const auto& n) { return hashes.at(n); }); + + return { leaf, size, dirpath_nodes, copath_hashes }; +} + +void +TreeKEMPublicKey::implant_slice(const TreeSlice& slice) +{ + if (slice.n_leaves != size) { + throw InvalidParameterError("Slice tree size does not match tree size"); + } + + if (slice.tree_hash(suite) != root_hash()) { + throw InvalidParameterError("Slice tree hash does not match tree hash"); + } + + implant_slice_unchecked(slice); +} + +std::tuple +TreeKEMPublicKey::slice_path(UpdatePath path, + LeafIndex from, + LeafIndex to) const +{ + const auto toi = NodeIndex(to); + const auto fdp = filtered_direct_path(NodeIndex(from)); + + for (auto i = size_t(0); i < fdp.size(); i++) { + const auto& [dpi, res] = fdp.at(i); + + if (!toi.is_below(dpi)) { + continue; + } + + for (auto j = size_t(0); j < res.size(); j++) { + const auto resi = res.at(j); + if (!toi.is_below(resi)) { + continue; + } + + return { path.nodes.at(i).encrypted_path_secret.at(j), resi }; + } + } + + throw ProtocolError("Decryption node not found"); +} + +void +TreeKEMPublicKey::implant_slice_unchecked(const TreeSlice& slice) +{ + const auto n = NodeIndex(slice.leaf_index); + auto dirpath = n.dirpath(size); + dirpath.insert(dirpath.begin(), n); + const auto copath = n.copath(size); + + if (slice.direct_path_nodes.size() != dirpath.size()) { + throw InvalidParameterError("Malformed tree slice (bad direct path size)"); + } + + if (slice.copath_hashes.size() != copath.size()) { + throw InvalidParameterError("Malformed tree slice (bad copath size)"); + } + + for (auto i = size_t(0); i < dirpath.size(); i++) { + nodes.insert_or_assign(dirpath.at(i), slice.direct_path_nodes.at(i)); + } + + for (auto i = size_t(0); i < copath.size(); i++) { + hashes.insert_or_assign(copath.at(i), slice.copath_hashes.at(i)); + } +} + TreeKEMPublicKey::FilteredDirectPath TreeKEMPublicKey::filtered_direct_path(NodeIndex index) const { @@ -617,6 +763,11 @@ std::optional TreeKEMPublicKey::find(const LeafNode& leaf) const { for (LeafIndex i{ 0 }; i < size; i.val++) { + if (nodes.count(NodeIndex{ i }) == 0) { + // Unknown leaf node + continue; + } + const auto& node = node_at(i); if (!node.blank() && node.leaf_node() == leaf) { return i; @@ -736,41 +887,32 @@ TreeKEMPublicKey::truncate() return; } - // Remove the right subtree until the tree is of minimal size + // Find the new size of the tree while (size.val / 2 > index.val) { - nodes.resize(nodes.size() / 2); size.val /= 2; } + + // Delete nodes to right of the new smaller edge of the tree + const auto node_size = NodeCount(size); + const auto start = + std::find_if(nodes.begin(), nodes.end(), [node_size](const auto& n) { + return !(n.first < node_size); + }); + if (start != nodes.end()) { + nodes.erase(start, nodes.end()); + } } OptionalNode& TreeKEMPublicKey::node_at(NodeIndex n) { - auto width = NodeCount(size); - if (n.val >= width.val) { - throw InvalidParameterError("Node index not in tree"); - } - - if (n.val >= nodes.size()) { - return blank_node; - } - - return nodes.at(n.val); + return nodes.at(n); } const OptionalNode& TreeKEMPublicKey::node_at(NodeIndex n) const { - auto width = NodeCount(size); - if (n.val >= width.val) { - throw InvalidParameterError("Node index not in tree"); - } - - if (n.val >= nodes.size()) { - return blank_node; - } - - return nodes.at(n.val); + return nodes.at(n); } OptionalNode& @@ -1068,9 +1210,14 @@ operator<<(tls::ostream& str, const TreeKEMPublicKey& obj) cut.val -= 1; } - const auto begin = obj.nodes.begin(); - const auto end = begin + NodeIndex(cut).val + 1; - const auto view = std::vector(begin, end); + auto node_cut = NodeIndex(cut); + node_cut.val += 1; + + auto view = std::vector(node_cut.val); + for (auto i = NodeIndex(0); i < node_cut; i.val++) { + view.at(i.val) = obj.nodes.at(i); + } + return str << view; } @@ -1078,37 +1225,46 @@ tls::istream& operator>>(tls::istream& str, TreeKEMPublicKey& obj) { // Read the node list - str >> obj.nodes; - if (obj.nodes.empty()) { + std::vector nodes; + str >> nodes; + if (nodes.empty()) { return str; } // Verify that the tree is well-formed and minimal - if (obj.nodes.size() % 2 == 0) { + if (nodes.size() % 2 == 0) { throw ProtocolError("Malformed ratchet tree: even number of nodes"); } - if (obj.nodes.back().blank()) { + if (nodes.back().blank()) { throw ProtocolError("Ratchet tree does not use minimal encoding"); } // Adjust the size value to fit the non-blank nodes obj.size.val = 1; - while (NodeCount(obj.size).val < obj.nodes.size()) { + while (NodeCount(obj.size).val < nodes.size()) { obj.size.val *= 2; } - // Add blank nodes to the end - obj.nodes.resize(NodeCount(obj.size).val); + // Copy nodes to `obj` and add blank nodes to the end + for (uint32_t i = 0; i < nodes.size(); i++) { + obj.nodes.insert_or_assign(NodeIndex(i), std::move(nodes.at(i))); + } + + const auto node_size = NodeCount(obj.size); + for (uint32_t i = nodes.size(); i < node_size.val; i++) { + obj.nodes.insert_or_assign(NodeIndex(i), OptionalNode{}); + } // Verify the basic structure of the tree is sane - for (size_t i = 0; i < obj.nodes.size(); i++) { - if (obj.nodes[i].blank()) { + for (auto i = NodeIndex{ 0 }; i < node_size; i.val++) { + const auto& maybe_node = obj.node_at(i); + if (maybe_node.blank()) { continue; } - const auto& node = opt::get(obj.nodes[i].node).node; - auto at_leaf = (i % 2 == 0); + const auto& node = opt::get(maybe_node.node).node; + auto at_leaf = (i.val % 2 == 0); auto holds_leaf = var::holds_alternative(node); auto holds_parent = var::holds_alternative(node); diff --git a/test/messages.cpp b/test/messages.cpp index a2f514fa..d6db8d8f 100644 --- a/test/messages.cpp +++ b/test/messages.cpp @@ -23,6 +23,35 @@ TEST_CASE("Extensions") REQUIRE(tree0 == tree1); } +TEST_CASE("Flags Extension") +{ + auto flags = FlagsExtension{}; + + const auto flag_data_3 = from_hex("08"); + const auto flag_data_3_5 = from_hex("28"); + const auto flag_data_3_5_23 = from_hex("280080"); + + // Following the example from TLS: + // https://datatracker.ietf.org/doc/html/draft-ietf-tls-tlsflags-13#section-2 + flags.set(3); + REQUIRE(flags.flag_data == flag_data_3); + + flags.set(5); + REQUIRE(flags.flag_data == flag_data_3_5); + + flags.set(23); + REQUIRE(flags.flag_data == flag_data_3_5_23); + + flags.unset(23); + REQUIRE(flags.flag_data == flag_data_3_5); + + flags.unset(5); + REQUIRE(flags.flag_data == flag_data_3); + + flags.unset(3); + REQUIRE(flags.flag_data.empty()); +} + // TODO(RLB) Verify sign/verify on: // * KeyPackage // * GroupInfo diff --git a/test/state.cpp b/test/state.cpp index 0f922ff2..ca0519dc 100644 --- a/test/state.cpp +++ b/test/state.cpp @@ -131,8 +131,8 @@ TEST_CASE_METHOD(StateTest, "Two Person") // Handle the Add proposal and create a Commit auto add = first0.add_proposal(key_packages[1]); - auto [commit, welcome, first1_] = - first0.commit(fresh_secret(), CommitOpts{ { add }, true, false, {} }, {}); + auto [commit, welcome, first1_] = first0.commit( + fresh_secret(), CommitOpts{ { add }, true, false, false, {} }, {}); silence_unused(commit); auto first1 = first1_; @@ -164,7 +164,7 @@ TEST_CASE_METHOD(StateTest, "Two Person with New Member Add") auto add = State::new_member_add(group_id, 0, key_packages[1], identity_privs[1]); first0.handle(add); - auto opts = CommitOpts{ {}, true, false, {} }; + auto opts = CommitOpts{ {}, true, false, false, {} }; auto [commit, welcome, first1_] = first0.commit(fresh_secret(), opts, {}); silence_unused(commit); auto first1 = first1_; @@ -209,7 +209,7 @@ TEST_CASE_METHOD(StateTest, "Two Person with External Proposal") // Handle the Add proposal and create a Commit first0.handle(add); - auto opts = CommitOpts{ {}, true, false, {} }; + auto opts = CommitOpts{ {}, true, false, false, {} }; auto [commit, welcome, first1_] = first0.commit(fresh_secret(), opts, {}); silence_unused(commit); auto first1 = first1_; @@ -243,8 +243,8 @@ TEST_CASE_METHOD(StateTest, "Two Person with custom extensions") // Handle the Add proposal and create a Commit auto add = first0.add_proposal(key_packages[1]); - auto [commit1, welcome1, first1_] = - first0.commit(fresh_secret(), CommitOpts{ { add }, true, false, {} }, {}); + auto [commit1, welcome1, first1_] = first0.commit( + fresh_secret(), CommitOpts{ { add }, true, false, false, {} }, {}); auto first1 = first1_; silence_unused(commit1); @@ -267,8 +267,8 @@ TEST_CASE_METHOD(StateTest, "Two Person with custom extensions") second_exts.add(CustomExtension2{ 0xb0 }); auto gce = first1.group_context_extensions_proposal(second_exts); - auto [commit2, welcome2, first2_] = - first1.commit(fresh_secret(), CommitOpts{ { gce }, false, false, {} }, {}); + auto [commit2, welcome2, first2_] = first1.commit( + fresh_secret(), CommitOpts{ { gce }, false, false, false, {} }, {}); auto second2 = second1.handle(commit2); silence_unused(welcome2); auto first2 = first2_; @@ -289,8 +289,8 @@ TEST_CASE_METHOD(StateTest, "Two Person with external tree for welcome") // Handle the Add proposal and create a Commit auto add = first0.add_proposal(key_packages[1]); // Don't generate RatchetTree extension - auto [commit, welcome_, first1_] = - first0.commit(fresh_secret(), CommitOpts{ { add }, false, false, {} }, {}); + auto [commit, welcome_, first1_] = first0.commit( + fresh_secret(), CommitOpts{ { add }, false, false, false, {} }, {}); auto welcome = welcome_; auto first1 = first1_; silence_unused(commit); @@ -353,7 +353,7 @@ TEST_CASE_METHOD(StateTest, "Two Person with PSK") auto add = first0.add_proposal(key_packages[1]); auto psk = first0.pre_shared_key_proposal(psk_id); auto [commit, welcome, first1_] = first0.commit( - fresh_secret(), CommitOpts{ { add, psk }, true, false, {} }, {}); + fresh_secret(), CommitOpts{ { add, psk }, true, false, false, {} }, {}); silence_unused(commit); auto first1 = first1_; @@ -377,8 +377,8 @@ TEST_CASE_METHOD(StateTest, "Two Person with Replacement") // Handle the Add proposal and create a Commit const auto add1 = first0.add_proposal(key_packages[1]); - const auto [commit1, welcome1, first1_] = - first0.commit(fresh_secret(), CommitOpts{ { add1 }, true, false, {} }, {}); + const auto [commit1, welcome1, first1_] = first0.commit( + fresh_secret(), CommitOpts{ { add1 }, true, false, false, {} }, {}); silence_unused(commit1); auto first1 = first1_; @@ -404,8 +404,10 @@ TEST_CASE_METHOD(StateTest, "Two Person with Replacement") // Create a commit replacing the first member const auto remove2 = second1.remove_proposal(LeafIndex{ 0 }); const auto add2 = second1.add_proposal(key_package); - const auto [commit2, welcome2, second2_] = second1.commit( - fresh_secret(), CommitOpts{ { add2, remove2 }, true, false, {} }, {}); + const auto [commit2, welcome2, second2_] = + second1.commit(fresh_secret(), + CommitOpts{ { add2, remove2 }, true, false, false, {} }, + {}); auto second2 = second2_; silence_unused(commit2); @@ -419,6 +421,78 @@ TEST_CASE_METHOD(StateTest, "Two Person with Replacement") verify_group_functionality(group); } +TEST_CASE_METHOD(StateTest, "Light client can join and participate") +{ + // Initialize the creator's state + auto first0 = State{ group_id, + suite, + leaf_privs[0], + identity_privs[0], + key_packages[0].leaf_node, + {} }; + + // Add the second participant + auto add1 = first0.add_proposal(key_packages[1]); + auto [commit1, welcome1, first1_] = first0.commit( + fresh_secret(), CommitOpts{ { add1 }, true, false, false, {} }, {}); + silence_unused(commit1); + auto first1 = first1_; + + // Initialize the second participant from the Welcome. Note that the second + // participant is always a full client, because the membership proofs cover + // the whole tree. + auto second1 = State{ init_privs[1], + leaf_privs[1], + identity_privs[1], + key_packages[1], + welcome1, + std::nullopt, + {} }; + + REQUIRE(first1 == second1); + + // Add the third participant + auto add2 = first0.add_proposal(key_packages[2]); + auto [commit2, welcome2, first2_] = first1.commit( + fresh_secret(), CommitOpts{ { add2 }, false, false, true, {} }, {}); + auto first2 = first2_; + + // Handle the Commit at the second participant + auto second2 = opt::get(second1.handle(commit2)); + + // Initialize the third participant as a light client, by only including + // membership proofs in the Welcome, not the full tree + auto third2 = State{ init_privs[2], + leaf_privs[2], + identity_privs[2], + key_packages[2], + welcome2, + std::nullopt, + {} }; + REQUIRE_FALSE(third2.is_full_client()); + + REQUIRE(first2 == second2); + REQUIRE(first2 == third2); + + // Create another commit and handle it at the second client + auto [commit3, welcome3, first3_] = first2.commit(fresh_secret(), {}, {}); + silence_unused(welcome3); + auto first3 = first3_; + auto second3 = opt::get(second2.handle(commit3)); + + // Verify that the light client refuses to process it on its own + REQUIRE_THROWS(third2.handle(commit3)); + + // Convert the Commit to a LightCommit + auto light_commit = first3.lighten_for(third2.index(), commit3); + + // Verify that the light client can process the commit with a commit map + auto third3 = third2.handle(light_commit); + + REQUIRE(first3 == second3); + REQUIRE(first3 == third3); +} + TEST_CASE_METHOD(StateTest, "External Join") { // Initialize the creator's state @@ -524,8 +598,8 @@ TEST_CASE_METHOD(StateTest, "External Join with Eviction of Prior Appearance") // Add the second participant auto add = first0.add_proposal(key_packages[1]); - auto [commit1, welcome1, first1] = - first0.commit(fresh_secret(), CommitOpts{ { add }, true, false, {} }, {}); + auto [commit1, welcome1, first1] = first0.commit( + fresh_secret(), CommitOpts{ { add }, true, false, false, {} }, {}); silence_unused(commit1); auto second1 = State{ init_privs[1], leaf_privs[1], @@ -583,8 +657,8 @@ TEST_CASE_METHOD(StateTest, "SFrame Parameter Negotiation") // Add the second member auto add = first0.add_proposal(kp1); - auto [commit, welcome, first1_] = - first0.commit(fresh_secret(), CommitOpts{ { add }, true, false, {} }, {}); + auto [commit, welcome, first1_] = first0.commit( + fresh_secret(), CommitOpts{ { add }, true, false, false, {} }, {}); auto first1 = first1_; silence_unused(commit); @@ -677,8 +751,8 @@ TEST_CASE_METHOD(StateTest, "Add Multiple Members") } // Create a Commit that adds everybody - auto [commit, welcome, new_state] = - states[0].commit(fresh_secret(), CommitOpts{ adds, true, false, {} }, {}); + auto [commit, welcome, new_state] = states[0].commit( + fresh_secret(), CommitOpts{ adds, true, false, false, {} }, {}); silence_unused(commit); states[0] = new_state; @@ -712,7 +786,7 @@ TEST_CASE_METHOD(StateTest, "Full Size Group") auto add = states[sender].add_proposal(key_packages[i]); auto [commit, welcome, new_state] = states[sender].commit( - fresh_secret(), CommitOpts{ { add }, true, false, {} }, {}); + fresh_secret(), CommitOpts{ { add }, true, false, false, {} }, {}); for (size_t j = 0; j < states.size(); j += 1) { if (j == sender) { states[j] = new_state; @@ -757,8 +831,8 @@ class RunningGroupTest : public StateTest adds.push_back(states[0].add_proposal(key_packages[i])); } - auto [commit, welcome, new_state] = - states[0].commit(fresh_secret(), CommitOpts{ adds, true, false, {} }, {}); + auto [commit, welcome, new_state] = states[0].commit( + fresh_secret(), CommitOpts{ adds, true, false, false, {} }, {}); silence_unused(commit); states[0] = new_state; for (size_t i = 1; i < group_size; i += 1) { @@ -840,7 +914,7 @@ TEST_CASE_METHOD(RunningGroupTest, "Add a PSK from Everyone in a Group") auto psk = states[i].pre_shared_key_proposal(psk_id); auto [commit, welcome, new_state] = states[i].commit( - fresh_secret(), CommitOpts{ { psk }, false, false, {} }, {}); + fresh_secret(), CommitOpts{ { psk }, false, false, false, {} }, {}); silence_unused(welcome); for (auto& state : states) { @@ -861,7 +935,7 @@ TEST_CASE_METHOD(RunningGroupTest, "Remove Members from a Group") for (uint32_t i = uint32_t(group_size) - 2; i > 0; i -= 1) { auto remove = states[i].remove_proposal(LeafIndex{ i + 1 }); auto [commit, welcome, new_state] = states[i].commit( - fresh_secret(), CommitOpts{ { remove }, false, false, {} }, {}); + fresh_secret(), CommitOpts{ { remove }, false, false, false, {} }, {}); silence_unused(welcome); states.pop_back(); @@ -887,7 +961,7 @@ TEST_CASE_METHOD(RunningGroupTest, "Roster Updates") // remove member at position 1 auto remove_1 = states[0].remove_proposal(RosterIndex{ 1 }); auto [commit_1, welcome_1, new_state_1_] = states[0].commit( - fresh_secret(), CommitOpts{ { remove_1 }, true, false, {} }, {}); + fresh_secret(), CommitOpts{ { remove_1 }, true, false, false, {} }, {}); auto new_state_1 = new_state_1_; silence_unused(welcome_1); silence_unused(commit_1); @@ -903,7 +977,7 @@ TEST_CASE_METHOD(RunningGroupTest, "Roster Updates") // remove member at position 2 auto remove_2 = new_state_1.remove_proposal(RosterIndex{ 2 }); auto [commit_2, welcome_2, new_state_2_] = new_state_1.commit( - fresh_secret(), CommitOpts{ { remove_2 }, true, false, {} }, {}); + fresh_secret(), CommitOpts{ { remove_2 }, true, false, false, {} }, {}); auto new_state_2 = new_state_2_; silence_unused(welcome_2); // roster should be 0, 2, 4 diff --git a/test/treekem.cpp b/test/treekem.cpp index 37601e53..f9a3cdd3 100644 --- a/test/treekem.cpp +++ b/test/treekem.cpp @@ -68,7 +68,7 @@ TEST_CASE_METHOD(TreeKEMTest, "TreeKEM Private Key") const auto priv2 = HPKEPrivateKey::generate(suite); const auto hash_size = suite.digest().hash_size; - // Create a tree with N blank leaves + // Create a tree with N leaves, blank otherwise auto pub = TreeKEMPublicKey(suite); for (auto i = uint32_t(0); i < size.val; i++) { auto [_leaf_priv, _sig_priv, leaf_node] = new_leaf_node(); @@ -264,6 +264,32 @@ TEST_CASE_METHOD(TreeKEMTest, "TreeKEM encap/decap") REQUIRE(privs[j].consistent(pubs[j])); } } + + // XXX(RLB) The below test probably doesn't belong here, but we have a nice + // properly-created tree here, so it's convenient to use it for testing tree + // slices. + + // Verify that all slices of the tree have the correct tree hash + auto original = pubs[0]; + original.set_hash_all(); + const auto tree_hash = original.root_hash(); + auto slices = std::vector{}; + for (uint32_t i = 0; i < original.size.val; i++) { + auto slice = original.extract_slice(LeafIndex{ i }); + REQUIRE(slice.tree_hash(suite) == tree_hash); + + slices.push_back(slice); + } + + // Verify that the tree can be reassembled from the slices + auto reconstructed = TreeKEMPublicKey(suite, slices[0]); + for (uint32_t i = 1; i < slices.size(); i++) { + reconstructed.implant_slice(slices[i]); + } + + reconstructed.set_hash_all(); + REQUIRE(reconstructed.root_hash() == tree_hash); + REQUIRE(reconstructed.parent_hash_valid()); } TEST_CASE("TreeKEM Interop")